diff --git a/CHANGELOG.md b/CHANGELOG.md index 85728d69..e3ab9963 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 future, this will be expanded to redirect queries for certain TLD's to alternative backend. +### Changed + +- The Quic connection type now uses quic datagrams to transport __data__ (packets + coming from the TUN device) to the peer. Protocol traffic is still sent over a + bidirectional Quic stream (which supports retransmits). + ### Fixed - Return actuall amount of bytes sent to peers instead of the amount of bytes received diff --git a/mycelium/src/connection.rs b/mycelium/src/connection.rs index 819cd7e2..1f19aef7 100644 --- a/mycelium/src/connection.rs +++ b/mycelium/src/connection.rs @@ -1,15 +1,29 @@ -use std::{io, net::SocketAddr, pin::Pin}; +use std::{ + future::Future, + io, + net::SocketAddr, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, +}; + +use crate::packet::{self, ControlPacket, DataPacket, Packet}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - net::TcpStream, +use bytes::{Bytes, BytesMut}; +use futures::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, }; +use tokio::io::{AsyncRead, AsyncWrite}; mod tracked; +use tokio_util::codec::{Decoder, Encoder, Framed}; pub use tracked::Tracked; #[cfg(feature = "private-network")] -mod tls; +pub mod tls; /// Cost to add to the peer_link_cost for "local processing", when peers are connected over IPv6. /// @@ -31,39 +45,112 @@ const PACKET_PROCESSING_COST_IP6_QUIC: u16 = 7; // TODO const PACKET_PROCESSING_COST_IP4_QUIC: u16 = 12; -pub trait Connection: AsyncRead + AsyncWrite { +pub trait ConnectionReadHalf: Send { + /// Receive a packet from the remote end. + fn receive_packet(&mut self) -> impl Future>> + Send; +} + +pub trait ConnectionWriteHalf: Send { + /// Feeds a data packet on the connection. Depending on the connection you might need to call + /// [`Connection::flush`] before the packet is actually sent. + fn feed_data_packet( + &mut self, + packet: DataPacket, + ) -> impl Future> + Send; + + /// Feeds a control packet on the connection. Depending on the connection you might need to call + /// [`Connection::flush`] before the packet is actually sent. + fn feed_control_packet( + &mut self, + packet: ControlPacket, + ) -> impl Future> + Send; + + /// Flush the connection. This sends all buffered packets which haven't beend sent yet. + fn flush(&mut self) -> impl Future> + Send; +} + +pub trait Connection { + type ReadHalf: ConnectionReadHalf; + type WriteHalf: ConnectionWriteHalf; + + /// Feeds a data packet on the connection. Depending on the connection you might need to call + /// [`Connection::flush`] before the packet is actually sent. + fn feed_data_packet( + &mut self, + packet: DataPacket, + ) -> impl Future> + Send; + + /// Feeds a control packet on the connection. Depending on the connection you might need to call + /// [`Connection::flush`] before the packet is actually sent. + fn feed_control_packet( + &mut self, + packet: ControlPacket, + ) -> impl Future> + Send; + + /// Flush the connection. This sends all buffered packets which haven't beend sent yet. + fn flush(&mut self) -> impl Future> + Send; + + /// Receive a packet from the remote end. + fn receive_packet(&mut self) -> impl Future>> + Send; + /// Get an identifier for this connection, which shows details about the remote fn identifier(&self) -> Result; /// The static cost of using this connection fn static_link_cost(&self) -> Result; + + /// Split the connection in a read and write half which can be used independently + fn split(self) -> (Self::ReadHalf, Self::WriteHalf); } -/// A wrapper around a quic send and quic receive stream, implementing the [`Connection`] trait. -pub struct Quic { - tx: quinn::SendStream, - rx: quinn::RecvStream, - remote: SocketAddr, +/// A wrapper about an asynchronous (non blocking) tcp stream. +pub struct TcpStream { + framed: Framed, packet::Codec>, + local_addr: SocketAddr, + peer_addr: SocketAddr, } -impl Quic { - /// Create a new wrapper around Quic streams. - pub fn new(tx: quinn::SendStream, rx: quinn::RecvStream, remote: SocketAddr) -> Self { - Quic { tx, rx, remote } +impl TcpStream { + /// Create a new wrapped [`TcpStream`] which implements the [`Connection`] trait. + pub fn new( + tcp_stream: tokio::net::TcpStream, + read: Arc, + write: Arc, + ) -> io::Result { + Ok(Self { + local_addr: tcp_stream.local_addr()?, + peer_addr: tcp_stream.peer_addr()?, + framed: Framed::new(Tracked::new(read, write, tcp_stream), packet::Codec::new()), + }) } } impl Connection for TcpStream { + type ReadHalf = TcpStreamReadHalf; + type WriteHalf = TcpStreamWriteHalf; + + async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> { + self.framed.feed(Packet::DataPacket(packet)).await + } + + async fn feed_control_packet(&mut self, packet: ControlPacket) -> io::Result<()> { + self.framed.feed(Packet::ControlPacket(packet)).await + } + + async fn receive_packet(&mut self) -> Option> { + self.framed.next().await + } + + async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } + fn identifier(&self) -> Result { - Ok(format!( - "TCP {} <-> {}", - self.local_addr()?, - self.peer_addr()? - )) + Ok(format!("TCP {} <-> {}", self.local_addr, self.peer_addr)) } fn static_link_cost(&self) -> Result { - Ok(match self.peer_addr()? { + Ok(match self.peer_addr { SocketAddr::V4(_) => PACKET_PROCESSING_COST_IP4_TCP, SocketAddr::V6(ip) if ip.ip().to_ipv4_mapped().is_some() => { PACKET_PROCESSING_COST_IP4_TCP @@ -71,9 +158,80 @@ impl Connection for TcpStream { SocketAddr::V6(_) => PACKET_PROCESSING_COST_IP6_TCP, }) } + + fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { + let (tx, rx) = self.framed.split(); + + ( + TcpStreamReadHalf { framed: rx }, + TcpStreamWriteHalf { framed: tx }, + ) + } } -impl AsyncRead for Quic { +pub struct TcpStreamReadHalf { + framed: SplitStream, packet::Codec>>, +} + +impl ConnectionReadHalf for TcpStreamReadHalf { + async fn receive_packet(&mut self) -> Option> { + self.framed.next().await + } +} + +pub struct TcpStreamWriteHalf { + framed: SplitSink, packet::Codec>, packet::Packet>, +} + +impl ConnectionWriteHalf for TcpStreamWriteHalf { + async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> { + self.framed.feed(Packet::DataPacket(packet)).await + } + + async fn feed_control_packet(&mut self, packet: ControlPacket) -> io::Result<()> { + self.framed.feed(Packet::ControlPacket(packet)).await + } + + async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } +} + +/// A wrapper around a quic send and quic receive stream, implementing the [`Connection`] trait. +pub struct Quic { + framed: Framed, packet::Codec>, + con: quinn::Connection, + read: Arc, + write: Arc, +} + +struct QuicStream { + tx: quinn::SendStream, + rx: quinn::RecvStream, +} + +impl Quic { + /// Create a new wrapper around Quic streams. + pub fn new( + tx: quinn::SendStream, + rx: quinn::RecvStream, + con: quinn::Connection, + read: Arc, + write: Arc, + ) -> Self { + Quic { + framed: Framed::new( + Tracked::new(read.clone(), write.clone(), QuicStream { tx, rx }), + packet::Codec::new(), + ), + con, + read, + write, + } + } +} + +impl AsyncRead for QuicStream { #[inline] fn poll_read( mut self: std::pin::Pin<&mut Self>, @@ -84,7 +242,7 @@ impl AsyncRead for Quic { } } -impl AsyncWrite for Quic { +impl AsyncWrite for QuicStream { #[inline] fn poll_write( mut self: Pin<&mut Self>, @@ -128,12 +286,61 @@ impl AsyncWrite for Quic { } impl Connection for Quic { + type ReadHalf = QuicReadHalf; + + type WriteHalf = QuicWriteHalf; + + async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> { + let mut codec = packet::Codec::new(); + let mut buffer = BytesMut::with_capacity(1500); + codec.encode(Packet::DataPacket(packet), &mut buffer)?; + + let data: Bytes = buffer.into(); + let tx_len = data.len(); + self.write.fetch_add(tx_len as u64, Ordering::Relaxed); + + self.con.send_datagram(data).map_err(io::Error::other) + } + + async fn feed_control_packet(&mut self, packet: ControlPacket) -> io::Result<()> { + self.framed.feed(Packet::ControlPacket(packet)).await + } + + async fn receive_packet(&mut self) -> Option> { + tokio::select! { + datagram = self.con.read_datagram() => { + let datagram_bytes = match datagram { + Ok(buffer) => buffer, + Err(e) => return Some(Err(e.into())), + }; + let recv_len = datagram_bytes.len(); + self.read.fetch_add(recv_len as u64, Ordering::Relaxed); + let mut codec = packet::Codec::new(); + match codec.decode(&mut datagram_bytes.into()) { + Ok(Some(packet)) => Some(Ok(packet)), + // Partial? packet read. We consider this to be a stream hangup + // TODO: verify + Ok(None) => None, + Err(e) => Some(Err(e)), + } + }, + packet = self.framed.next() => { + packet + } + + } + } + + async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } + fn identifier(&self) -> Result { - Ok(format!("QUIC -> {}", self.remote)) + Ok(format!("QUIC -> {}", self.con.remote_address())) } fn static_link_cost(&self) -> Result { - Ok(match self.remote { + Ok(match self.con.remote_address() { SocketAddr::V4(_) => PACKET_PROCESSING_COST_IP4_QUIC, SocketAddr::V6(ip) if ip.ip().to_ipv4_mapped().is_some() => { PACKET_PROCESSING_COST_IP4_QUIC @@ -141,13 +348,130 @@ impl Connection for Quic { SocketAddr::V6(_) => PACKET_PROCESSING_COST_IP6_QUIC, }) } + + fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { + let Self { + framed, + con, + read, + write, + } = self; + + let (tx, rx) = framed.split(); + + ( + QuicReadHalf { + framed: rx, + con: con.clone(), + read, + }, + QuicWriteHalf { + framed: tx, + con, + write, + }, + ) + } +} + +pub struct QuicReadHalf { + framed: SplitStream, packet::Codec>>, + con: quinn::Connection, + read: Arc, +} + +pub struct QuicWriteHalf { + framed: SplitSink, packet::Codec>, packet::Packet>, + con: quinn::Connection, + write: Arc, +} + +impl ConnectionReadHalf for QuicReadHalf { + async fn receive_packet(&mut self) -> Option> { + tokio::select! { + datagram = self.con.read_datagram() => { + let datagram_bytes = match datagram { + Ok(buffer) => buffer, + Err(e) => return Some(Err(e.into())), + }; + let recv_len = datagram_bytes.len(); + self.read.fetch_add(recv_len as u64, Ordering::Relaxed); + let mut codec = packet::Codec::new(); + match codec.decode(&mut datagram_bytes.into()) { + Ok(Some(packet)) => Some(Ok(packet)), + // Partial? packet read. We consider this to be a stream hangup + // TODO: verify + Ok(None) => None, + Err(e) => Some(Err(e)), + } + }, + packet = self.framed.next() => { + packet + } + + } + } +} + +impl ConnectionWriteHalf for QuicWriteHalf { + async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> { + let mut codec = packet::Codec::new(); + let mut buffer = BytesMut::with_capacity(1500); + codec.encode(Packet::DataPacket(packet), &mut buffer)?; + + let data: Bytes = buffer.into(); + let tx_len = data.len(); + self.write.fetch_add(tx_len as u64, Ordering::Relaxed); + + self.con.send_datagram(data).map_err(io::Error::other) + } + + async fn feed_control_packet(&mut self, packet: ControlPacket) -> io::Result<()> { + self.framed.feed(Packet::ControlPacket(packet)).await + } + + async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } } #[cfg(test)] -use tokio::io::DuplexStream; +/// Wrapper for an in-memory pipe implementing the [`Connection`] trait. +pub struct DuplexStream { + framed: Framed, +} + +#[cfg(test)] +impl DuplexStream { + /// Create a new in memory duplex stream. + pub fn new(duplex: tokio::io::DuplexStream) -> Self { + Self { + framed: Framed::new(duplex, packet::Codec::new()), + } + } +} #[cfg(test)] impl Connection for DuplexStream { + type ReadHalf = DuplexStreamReadHalf; + type WriteHalf = DuplexStreamWriteHalf; + + async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> { + self.framed.feed(Packet::DataPacket(packet)).await + } + + async fn feed_control_packet(&mut self, packet: ControlPacket) -> io::Result<()> { + self.framed.feed(Packet::ControlPacket(packet)).await + } + + async fn receive_packet(&mut self) -> Option> { + self.framed.next().await + } + + async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } + fn identifier(&self) -> Result { Ok("Memory pipe".to_string()) } @@ -155,4 +479,45 @@ impl Connection for DuplexStream { fn static_link_cost(&self) -> Result { Ok(1) } + + fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { + let (tx, rx) = self.framed.split(); + + ( + DuplexStreamReadHalf { framed: rx }, + DuplexStreamWriteHalf { framed: tx }, + ) + } +} + +#[cfg(test)] +pub struct DuplexStreamReadHalf { + framed: SplitStream>, +} + +#[cfg(test)] +pub struct DuplexStreamWriteHalf { + framed: SplitSink, packet::Packet>, +} + +#[cfg(test)] +impl ConnectionReadHalf for DuplexStreamReadHalf { + async fn receive_packet(&mut self) -> Option> { + self.framed.next().await + } +} + +#[cfg(test)] +impl ConnectionWriteHalf for DuplexStreamWriteHalf { + async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> { + self.framed.feed(Packet::DataPacket(packet)).await + } + + async fn feed_control_packet(&mut self, packet: ControlPacket) -> io::Result<()> { + self.framed.feed(Packet::ControlPacket(packet)).await + } + + async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } } diff --git a/mycelium/src/connection/tls.rs b/mycelium/src/connection/tls.rs index a73571e8..7eb73504 100644 --- a/mycelium/src/connection/tls.rs +++ b/mycelium/src/connection/tls.rs @@ -1,18 +1,72 @@ -use std::{io, net::SocketAddr}; +use std::{ + io, + net::SocketAddr, + sync::{atomic::AtomicU64, Arc}, +}; +use futures::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; use tokio::net::TcpStream; +use tokio_util::codec::Framed; + +use crate::{ + connection::Tracked, + packet::{self, Packet}, +}; + +/// A wrapper around an asynchronous TLS stream. +pub struct TlsStream { + framed: Framed>, packet::Codec>, + local_addr: SocketAddr, + peer_addr: SocketAddr, +} + +impl TlsStream { + /// Create a new wrapped [`TlsStream`] which implements the [`Connection`](super::Connection) trait. + pub fn new( + tls_stream: tokio_openssl::SslStream, + read: Arc, + write: Arc, + ) -> io::Result { + Ok(Self { + local_addr: tls_stream.get_ref().local_addr()?, + peer_addr: tls_stream.get_ref().peer_addr()?, + framed: Framed::new(Tracked::new(read, write, tls_stream), packet::Codec::new()), + }) + } +} + +impl super::Connection for TlsStream { + type ReadHalf = TlsStreamReadHalf; + type WriteHalf = TlsStreamWriteHalf; + + async fn feed_data_packet(&mut self, packet: crate::packet::DataPacket) -> io::Result<()> { + self.framed.feed(Packet::DataPacket(packet)).await + } + + async fn feed_control_packet( + &mut self, + packet: crate::packet::ControlPacket, + ) -> io::Result<()> { + self.framed.feed(Packet::ControlPacket(packet)).await + } + + async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } + + async fn receive_packet(&mut self) -> Option> { + self.framed.next().await + } -impl super::Connection for tokio_openssl::SslStream { fn identifier(&self) -> Result { - Ok(format!( - "TLS {} <-> {}", - self.get_ref().local_addr()?, - self.get_ref().peer_addr()? - )) + Ok(format!("TLS {} <-> {}", self.local_addr, self.peer_addr)) } fn static_link_cost(&self) -> Result { - Ok(match self.get_ref().peer_addr()? { + Ok(match self.peer_addr { SocketAddr::V4(_) => super::PACKET_PROCESSING_COST_IP4_TCP, SocketAddr::V6(ip) if ip.ip().to_ipv4_mapped().is_some() => { super::PACKET_PROCESSING_COST_IP4_TCP @@ -20,4 +74,47 @@ impl super::Connection for tokio_openssl::SslStream { SocketAddr::V6(_) => super::PACKET_PROCESSING_COST_IP6_TCP, }) } + + fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { + let (tx, rx) = self.framed.split(); + + ( + TlsStreamReadHalf { framed: rx }, + TlsStreamWriteHalf { framed: tx }, + ) + } +} + +pub struct TlsStreamReadHalf { + framed: SplitStream>, packet::Codec>>, +} + +pub struct TlsStreamWriteHalf { + framed: SplitSink< + Framed>, packet::Codec>, + packet::Packet, + >, +} + +impl super::ConnectionReadHalf for TlsStreamReadHalf { + async fn receive_packet(&mut self) -> Option> { + self.framed.next().await + } +} + +impl super::ConnectionWriteHalf for TlsStreamWriteHalf { + async fn feed_data_packet(&mut self, packet: crate::packet::DataPacket) -> io::Result<()> { + self.framed.feed(Packet::DataPacket(packet)).await + } + + async fn feed_control_packet( + &mut self, + packet: crate::packet::ControlPacket, + ) -> io::Result<()> { + self.framed.feed(Packet::ControlPacket(packet)).await + } + + async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } } diff --git a/mycelium/src/connection/tracked.rs b/mycelium/src/connection/tracked.rs index 74f12f59..6c8ecac9 100644 --- a/mycelium/src/connection/tracked.rs +++ b/mycelium/src/connection/tracked.rs @@ -9,8 +9,6 @@ use std::{ use tokio::io::{AsyncRead, AsyncWrite}; -use super::Connection; - /// Wrapper which keeps track of how much bytes have been read and written from a connection. pub struct Tracked { /// Bytes read counter @@ -23,7 +21,7 @@ pub struct Tracked { impl Tracked where - C: Connection + Unpin, + C: AsyncRead + AsyncWrite + Unpin, { /// Create a new instance of a tracked connections. Counters are passed in so they can be /// reused accross connections. @@ -32,21 +30,6 @@ where } } -impl Connection for Tracked -where - C: Connection + Unpin, -{ - #[inline] - fn identifier(&self) -> Result { - self.con.identifier() - } - - #[inline] - fn static_link_cost(&self) -> Result { - self.con.static_link_cost() - } -} - impl AsyncRead for Tracked where C: AsyncRead + Unpin, diff --git a/mycelium/src/crypto.rs b/mycelium/src/crypto.rs index affd07d6..5bf123b9 100644 --- a/mycelium/src/crypto.rs +++ b/mycelium/src/crypto.rs @@ -15,7 +15,7 @@ use serde::{de::Visitor, Deserialize, Serialize}; /// const generic argument which is then expanded with the needed extra space for the buffer, /// however as it stands const generics can only be used standalone and not in a constant /// expression. This _is_ possible on nightly rust, with a feature gate (generic_const_exprs). -const PACKET_SIZE: usize = 1400; +const PACKET_SIZE: usize = 1_400; /// Size of an AES_GCM tag in bytes. const AES_TAG_SIZE: usize = 16; diff --git a/mycelium/src/peer.rs b/mycelium/src/peer.rs index 5b769b4c..fd6404da 100644 --- a/mycelium/src/peer.rs +++ b/mycelium/src/peer.rs @@ -1,9 +1,8 @@ -use futures::{SinkExt, StreamExt}; use std::{ error::Error, io, sync::{ - atomic::{AtomicBool, AtomicU64, Ordering}, + atomic::{AtomicBool, Ordering}, Arc, RwLock, Weak, }, }; @@ -11,12 +10,11 @@ use tokio::{ select, sync::{mpsc, Notify}, }; -use tokio_util::codec::Framed; use tracing::{debug, error, info, trace}; use crate::{ - connection::{self, Connection}, - packet::{self, Packet}, + connection::{Connection, ConnectionReadHalf, ConnectionWriteHalf}, + packet::Packet, }; use crate::{ packet::{ControlPacket, DataPacket}, @@ -65,12 +63,7 @@ impl Peer { router_control_tx: mpsc::UnboundedSender<(ControlPacket, Peer)>, connection: C, dead_peer_sink: mpsc::Sender, - bytes_written: Arc, - bytes_read: Arc, ) -> Result { - // Wrap connection so we can get access to the counters. - let connection = connection::Tracked::new(bytes_read, bytes_written, connection); - // Data channel for peer let (to_peer_data, mut from_routing_data) = mpsc::unbounded_channel::(); // Control channel for peer @@ -90,21 +83,18 @@ impl Peer { }), }; - // Framed for peer - // Used to send and receive packets from a TCP stream - let framed = Framed::with_capacity(connection, packet::Codec::new(), 128 << 10); - let (mut sink, mut stream) = framed.split(); - { let peer = peer.clone(); + let (mut stream, mut sink) = connection.split(); + + let mut needs_flush = false; + tokio::spawn(async move { - let mut needs_flush = false; loop { select! { - // Received over the TCP stream - frame = stream.next() => { - match frame { + packet = stream.receive_packet() => { + match packet { Some(Ok(packet)) => { match packet { Packet::DataPacket(packet) => { @@ -137,14 +127,13 @@ impl Peer { } } - rv = from_routing_data.recv(), if !needs_flush => { + rv = from_routing_data.recv(), if !needs_flush => { match rv { None => break, Some(packet) => { - needs_flush = true; - if let Err(e) = sink.feed(Packet::DataPacket(packet)).await { + if let Err(e) = sink.feed_data_packet(packet).await { error!("Failed to feed data packet to connection: {e}"); break } @@ -154,13 +143,13 @@ impl Peer { // There can be 2 cases of errors here, empty channel and no more // senders. In both cases we don't really care at this point. if let Ok(packet) = from_routing_data.try_recv() { - if let Err(e) = sink.feed(Packet::DataPacket(packet)).await { + if let Err(e) = sink.feed_data_packet(packet).await { error!("Failed to feed data packet to connection: {e}"); break } trace!("Instantly queued ready packet to transfer to peer"); } else { - // No packets ready, flush currently buffered ones + // no packets ready, flush currently buffered ones break } } @@ -172,10 +161,9 @@ impl Peer { match rv { None => break, Some(packet) => { - needs_flush = true; - if let Err(e) = sink.feed(Packet::ControlPacket(packet)).await { + if let Err(e) = sink.feed_control_packet(packet).await { error!("Failed to feed control packet to connection: {e}"); break } @@ -184,7 +172,7 @@ impl Peer { // There can be 2 cases of errors here, empty channel and no more // senders. In both cases we don't really care at this point. if let Ok(packet) = from_routing_control.try_recv() { - if let Err(e) = sink.feed(Packet::ControlPacket(packet)).await { + if let Err(e) = sink.feed_control_packet(packet).await { error!("Failed to feed data packet to connection: {e}"); break } @@ -198,8 +186,8 @@ impl Peer { } r = sink.flush(), if needs_flush => { - if let Err(err) = r { - error!("Failed to flush peer connection: {err}"); + if let Err(e) = r { + error!("Failed to flush buffered peer connection packets: {e}"); break } needs_flush = false; @@ -207,8 +195,8 @@ impl Peer { _ = death_watcher.notified() => { // Attempt gracefull shutdown - let mut framed = sink.reunite(stream).expect("SplitSink and SplitStream here can only be part of the same original Framned; Qed"); - let _ = framed.close().await; + // let mut framed = sink.reunite(stream).expect("SplitSink and SplitStream here can only be part of the same original Framned; Qed"); + // let _ = framed.close().await; break; } } diff --git a/mycelium/src/peer_manager.rs b/mycelium/src/peer_manager.rs index 26ab00f6..5b28d96f 100644 --- a/mycelium/src/peer_manager.rs +++ b/mycelium/src/peer_manager.rs @@ -1,4 +1,6 @@ -use crate::connection::Quic; +#[cfg(feature = "private-network")] +use crate::connection::tls::TlsStream; +use crate::connection::{Quic, TcpStream}; use crate::endpoint::{Endpoint, Protocol}; use crate::metrics::Metrics; use crate::peer::{Peer, PeerRef}; @@ -9,7 +11,7 @@ use futures::{FutureExt, StreamExt}; #[cfg(feature = "private-network")] use openssl::ssl::{Ssl, SslAcceptor, SslConnector, SslMethod}; use quinn::crypto::rustls::QuicClientConfig; -use quinn::{MtuDiscoveryConfig, ServerConfig, TransportConfig}; +use quinn::{congestion, MtuDiscoveryConfig, ServerConfig, TransportConfig}; use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName, UnixTime}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; @@ -24,7 +26,6 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Duration; use std::{collections::hash_map::Entry, future::IntoFuture}; -use tokio::net::TcpStream; use tokio::net::{TcpListener, UdpSocket}; use tokio::task::AbortHandle; use tokio::time::{Instant, MissedTickBehavior}; @@ -496,7 +497,7 @@ where None }; - match TcpStream::connect(endpoint.address()) + match tokio::net::TcpStream::connect(endpoint.address()) .map(|result| result.and_then(|socket| set_fw_mark(socket, self.firewall_mark))) .await { @@ -545,35 +546,57 @@ where } debug!("Completed TLS handshake"); + let tls_stream = match TlsStream::new(ssl_stream, ct.rx_bytes, ct.tx_bytes) + { + Ok(tls_stream) => tls_stream, + Err(err) => { + error!(%err, "Failed to create wrapped Tls stream"); + return (endpoint, None); + } + }; + Peer::new( router_data_tx, router_control_tx, - ssl_stream, + tls_stream, dead_peer_sink, - ct.tx_bytes, - ct.rx_bytes, ) } else { + let peer_stream = + match TcpStream::new(peer_stream, ct.rx_bytes, ct.tx_bytes) { + Ok(ps) => ps, + Err(err) => { + error!(%err, "Failed to create wrapped tcp stream"); + return (endpoint, None); + } + }; + Peer::new( router_data_tx, router_control_tx, peer_stream, dead_peer_sink, - ct.tx_bytes, - ct.rx_bytes, ) } }; #[cfg(not(feature = "private-network"))] - let res = Peer::new( - router_data_tx, - router_control_tx, - peer_stream, - dead_peer_sink, - ct.tx_bytes, - ct.rx_bytes, - ); + let res = { + let peer_stream = match TcpStream::new(peer_stream, ct.rx_bytes, ct.tx_bytes) { + Ok(ps) => ps, + Err(err) => { + error!(%err, "Failed to create wrapped tcp stream"); + return (endpoint, None); + } + }; + + Peer::new( + router_data_tx, + router_control_tx, + peer_stream, + dead_peer_sink, + ) + }; match res { Ok(new_peer) => { @@ -630,29 +653,23 @@ where transport_config.mtu_discovery_config(Some(MtuDiscoveryConfig::default())); transport_config.keep_alive_interval(Some(Duration::from_secs(20))); // we don't use datagrams. - transport_config.datagram_receive_buffer_size(None); - transport_config.datagram_send_buffer_size(0); + transport_config.datagram_receive_buffer_size(Some(16 << 20)); + transport_config.datagram_send_buffer_size(16 << 20); + transport_config.initial_mtu(1500); config.transport_config(Arc::new(transport_config)); match quic_socket.connect_with(config, endpoint.address(), "dummy.mycelium") { Ok(connecting) => match connecting.await { Ok(con) => match con.open_bi().await { Ok((tx, rx)) => { - let q_con = Quic::new(tx, rx, endpoint.address()); + let q_con = Quic::new(tx, rx, con, ct.tx_bytes, ct.rx_bytes); let res = { let router = self.router.lock().unwrap(); let router_data_tx = router.router_data_tx(); let router_control_tx = router.router_control_tx(); let dead_peer_sink = router.dead_peer_sink().clone(); - Peer::new( - router_data_tx, - router_control_tx, - q_con, - dead_peer_sink, - ct.tx_bytes, - ct.rx_bytes, - ) + Peer::new(router_data_tx, router_control_tx, q_con, dead_peer_sink) }; match res { Ok(new_peer) => { @@ -760,42 +777,71 @@ where } debug!(%remote, "Accepted TLS handshake"); + let tls_stream = match TlsStream::new( + ssl_stream, + rx_bytes.clone(), + tx_bytes.clone(), + ) { + Ok(tls_stream) => tls_stream, + Err(err) => { + error!(%err, "Failed to create wrapped Tls stream"); + continue; + } + }; + Peer::new( router_data_tx.clone(), router_control_tx.clone(), - ssl_stream, + tls_stream, dead_peer_sink.clone(), - tx_bytes.clone(), - rx_bytes.clone(), ) } else { + let new_stream = + match TcpStream::new(stream, rx_bytes.clone(), tx_bytes.clone()) { + Ok(ns) => ns, + Err(err) => { + error!(%err, "Failed to create wrapped tcp stream"); + continue; + } + }; + Peer::new( router_data_tx.clone(), router_control_tx.clone(), - stream, + new_stream, dead_peer_sink.clone(), - tx_bytes.clone(), - rx_bytes.clone(), ) }; #[cfg(not(feature = "private-network"))] - let new_peer = Peer::new( - router_data_tx.clone(), - router_control_tx.clone(), - stream, - dead_peer_sink.clone(), - tx_bytes.clone(), - rx_bytes.clone(), - ); + let new_peer = { + let new_stream = + match TcpStream::new(stream, rx_bytes.clone(), tx_bytes.clone()) { + Ok(ns) => ns, + Err(err) => { + error!(%err, "Failed to create wrapped tcp stream"); + continue; + } + }; - let new_peer = match new_peer { - Ok(peer) => peer, - Err(e) => { - error!(err=%e, "Failed to spawn peer"); - continue; + Peer::new( + router_data_tx.clone(), + router_control_tx.clone(), + new_stream, + dead_peer_sink.clone(), + ) + }; + + let new_peer = { + match new_peer { + Ok(peer) => peer, + Err(e) => { + error!(err=%e, "Failed to spawn peer"); + continue; + } } }; + info!("Accepted new inbound peer"); self.add_peer( Endpoint::new( @@ -863,24 +909,25 @@ where return; } }; + let remote_address = con.remote_address(); + + + let tx_bytes = Arc::new(AtomicU64::new(0)); + let rx_bytes = Arc::new(AtomicU64::new(0)); let quic_peer = match con.accept_bi().await { - Ok((tx, rx)) => Quic::new(tx, rx, con.remote_address()), + Ok((tx, rx)) => Quic::new(tx, rx, con, rx_bytes.clone(), tx_bytes.clone()), Err(e) => { debug!(err=%e, "Failed to accept bidirectional quic stream"); return; } }; - let tx_bytes = Arc::new(AtomicU64::new(0)); - let rx_bytes = Arc::new(AtomicU64::new(0)); let new_peer = match Peer::new( router_data_tx.clone(), router_control_tx.clone(), quic_peer, dead_peer_sink.clone(), - tx_bytes.clone(), - rx_bytes.clone(), ) { Ok(peer) => peer, Err(e) => { @@ -888,9 +935,9 @@ where return; } }; - info!(remote=%con.remote_address(), "Accepted new inbound quic peer"); + info!(remote=%remote_address, "Accepted new inbound quic peer"); self.add_peer( - Endpoint::new(Protocol::Quic, con.remote_address()), + Endpoint::new(Protocol::Quic, remote_address), PeerType::Inbound, ConnectionTraffic { tx_bytes, rx_bytes }, Some(new_peer), @@ -1206,16 +1253,21 @@ fn make_quic_endpoint( transport_config.max_idle_timeout(Some(Duration::from_secs(60).try_into()?)); transport_config.mtu_discovery_config(Some(MtuDiscoveryConfig::default())); transport_config.keep_alive_interval(Some(Duration::from_secs(20))); - // we don't use datagrams. - transport_config.datagram_receive_buffer_size(None); - transport_config.datagram_send_buffer_size(0); - // TODO: further tweak this. + transport_config.datagram_receive_buffer_size(Some(16 << 20)); + transport_config.datagram_send_buffer_size(16 << 20); + transport_config.initial_mtu(1500); + transport_config.enable_segmentation_offload(true); + transport_config.send_window((8 * (10u32 << 20)).into()); + transport_config.stream_receive_window((10u32 << 20).into()); + let mut congestion_controller = congestion::CubicConfig::default(); + congestion_controller.initial_window(1 << 22); // 4MiB + // TODO: further tweak this. let socket = std::net::UdpSocket::bind(("::", quic_listen_port)) .and_then(|socket| set_fw_mark(socket, firewall_mark))?; debug!("Bound UDP socket for Quic"); - //TODO tweak or confirm + // TODO: tweak or confirm let endpoint = quinn::Endpoint::new( quinn::EndpointConfig::default(), Some(server_config), diff --git a/mycelium/src/router.rs b/mycelium/src/router.rs index e3ebab96..c091969b 100644 --- a/mycelium/src/router.rs +++ b/mycelium/src/router.rs @@ -2091,15 +2091,14 @@ fn advertised_update_interval(sre: &RouteEntry) -> Duration { mod tests { use std::{ net::{IpAddr, Ipv6Addr}, - sync::{atomic::AtomicU64, Arc}, time::Duration, }; use tokio::sync::mpsc; use crate::{ - babel::Update, crypto::PublicKey, metric::Metric, peer::Peer, router_id::RouterId, - sequence_number::SeqNo, source_table::SourceKey, subnet::Subnet, + babel::Update, connection::DuplexStream, crypto::PublicKey, metric::Metric, peer::Peer, + router_id::RouterId, sequence_number::SeqNo, source_table::SourceKey, subnet::Subnet, }; #[test] @@ -2156,10 +2155,8 @@ mod tests { let neighbor = Peer::new( router_data_tx, router_control_tx, - con1, + DuplexStream::new(con1), dead_peer_sink, - Arc::new(AtomicU64::new(0)), - Arc::new(AtomicU64::new(0)), ) .expect("Can create a dummy peer"); let subnet = Subnet::new(IpAddr::V6(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 0)), 64) diff --git a/mycelium/src/source_table.rs b/mycelium/src/source_table.rs index 2839065f..562fcb71 100644 --- a/mycelium/src/source_table.rs +++ b/mycelium/src/source_table.rs @@ -169,6 +169,7 @@ mod tests { use crate::{ babel, + connection::DuplexStream, crypto::SecretKey, metric::Metric, peer::Peer, @@ -178,11 +179,7 @@ mod tests { source_table::{FeasibilityDistance, SourceKey, SourceTable}, subnet::Subnet, }; - use std::{ - net::Ipv6Addr, - sync::{atomic::AtomicU64, Arc}, - time::Duration, - }; + use std::{net::Ipv6Addr, time::Duration}; /// A retraction is always considered to be feasible. #[tokio::test] @@ -378,10 +375,8 @@ mod tests { let neighbor = Peer::new( router_data_tx, router_control_tx, - con1, + DuplexStream::new(con1), dead_peer_sink, - Arc::new(AtomicU64::new(0)), - Arc::new(AtomicU64::new(0)), ) .expect("Can create a dummy peer"); @@ -423,10 +418,8 @@ mod tests { let neighbor = Peer::new( router_data_tx, router_control_tx, - con1, + DuplexStream::new(con1), dead_peer_sink, - Arc::new(AtomicU64::new(0)), - Arc::new(AtomicU64::new(0)), ) .expect("Can create a dummy peer"); @@ -468,10 +461,8 @@ mod tests { let neighbor = Peer::new( router_data_tx, router_control_tx, - con1, + DuplexStream::new(con1), dead_peer_sink, - Arc::new(AtomicU64::new(0)), - Arc::new(AtomicU64::new(0)), ) .expect("Can create a dummy peer");