Skip to content

Commit 4d5368b

Browse files
committed
Add flushing for peer connections as separate select branch
This restores the old behavior which was introduced to fix an issue where a peer task could get stuck in a flush call (which was not a select branch), thus preventing the task from ever exitting, causing a buildup of tasks (which could keep open OS resources like file descriptors, causing an eventual resource exhaustion). To do this, introduce a Connection{Read,Write}Half, and add a `split` method to the Connection trait. This mimics the behavior of the old code which wrapped a connection in a framed and then split that entirely. Signed-off-by: Lee Smet <lee.smet@hotmail.com>
1 parent e7b226b commit 4d5368b

File tree

3 files changed

+303
-37
lines changed

3 files changed

+303
-37
lines changed

mycelium/src/connection.rs

Lines changed: 207 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ use std::{
1212
use crate::packet::{self, ControlPacket, DataPacket, Packet};
1313

1414
use bytes::{Bytes, BytesMut};
15-
use futures::{SinkExt, StreamExt};
15+
use futures::{
16+
stream::{SplitSink, SplitStream},
17+
SinkExt, StreamExt,
18+
};
1619
use tokio::io::{AsyncRead, AsyncWrite};
1720

1821
mod tracked;
@@ -42,7 +45,34 @@ const PACKET_PROCESSING_COST_IP6_QUIC: u16 = 7;
4245
// TODO
4346
const PACKET_PROCESSING_COST_IP4_QUIC: u16 = 12;
4447

48+
pub trait ConnectionReadHalf: Send {
49+
/// Receive a packet from the remote end.
50+
fn receive_packet(&mut self) -> impl Future<Output = Option<io::Result<Packet>>> + Send;
51+
}
52+
53+
pub trait ConnectionWriteHalf: Send {
54+
/// Feeds a data packet on the connection. Depending on the connection you might need to call
55+
/// [`Connection::flush`] before the packet is actually sent.
56+
fn feed_data_packet(
57+
&mut self,
58+
packet: DataPacket,
59+
) -> impl Future<Output = io::Result<()>> + Send;
60+
61+
/// Feeds a control packet on the connection. Depending on the connection you might need to call
62+
/// [`Connection::flush`] before the packet is actually sent.
63+
fn feed_control_packet(
64+
&mut self,
65+
packet: ControlPacket,
66+
) -> impl Future<Output = io::Result<()>> + Send;
67+
68+
/// Flush the connection. This sends all buffered packets which haven't beend sent yet.
69+
fn flush(&mut self) -> impl Future<Output = io::Result<()>> + Send;
70+
}
71+
4572
pub trait Connection {
73+
type ReadHalf: ConnectionReadHalf;
74+
type WriteHalf: ConnectionWriteHalf;
75+
4676
/// Feeds a data packet on the connection. Depending on the connection you might need to call
4777
/// [`Connection::flush`] before the packet is actually sent.
4878
fn feed_data_packet(
@@ -68,6 +98,9 @@ pub trait Connection {
6898

6999
/// The static cost of using this connection
70100
fn static_link_cost(&self) -> Result<u16, io::Error>;
101+
102+
/// Split the connection in a read and write half which can be used independently
103+
fn split(self) -> (Self::ReadHalf, Self::WriteHalf);
71104
}
72105

73106
/// A wrapper about an asynchronous (non blocking) tcp stream.
@@ -93,6 +126,9 @@ impl TcpStream {
93126
}
94127

95128
impl Connection for TcpStream {
129+
type ReadHalf = TcpStreamReadHalf;
130+
type WriteHalf = TcpStreamWriteHalf;
131+
96132
async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> {
97133
self.framed.feed(Packet::DataPacket(packet)).await
98134
}
@@ -122,6 +158,43 @@ impl Connection for TcpStream {
122158
SocketAddr::V6(_) => PACKET_PROCESSING_COST_IP6_TCP,
123159
})
124160
}
161+
162+
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
163+
let (tx, rx) = self.framed.split();
164+
165+
(
166+
TcpStreamReadHalf { framed: rx },
167+
TcpStreamWriteHalf { framed: tx },
168+
)
169+
}
170+
}
171+
172+
pub struct TcpStreamReadHalf {
173+
framed: SplitStream<Framed<Tracked<tokio::net::TcpStream>, packet::Codec>>,
174+
}
175+
176+
impl ConnectionReadHalf for TcpStreamReadHalf {
177+
async fn receive_packet(&mut self) -> Option<io::Result<Packet>> {
178+
self.framed.next().await
179+
}
180+
}
181+
182+
pub struct TcpStreamWriteHalf {
183+
framed: SplitSink<Framed<Tracked<tokio::net::TcpStream>, packet::Codec>, packet::Packet>,
184+
}
185+
186+
impl ConnectionWriteHalf for TcpStreamWriteHalf {
187+
async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> {
188+
self.framed.feed(Packet::DataPacket(packet)).await
189+
}
190+
191+
async fn feed_control_packet(&mut self, packet: ControlPacket) -> io::Result<()> {
192+
self.framed.feed(Packet::ControlPacket(packet)).await
193+
}
194+
195+
async fn flush(&mut self) -> io::Result<()> {
196+
self.framed.flush().await
197+
}
125198
}
126199

127200
/// A wrapper around a quic send and quic receive stream, implementing the [`Connection`] trait.
@@ -213,6 +286,10 @@ impl AsyncWrite for QuicStream {
213286
}
214287

215288
impl Connection for Quic {
289+
type ReadHalf = QuicReadHalf;
290+
291+
type WriteHalf = QuicWriteHalf;
292+
216293
async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> {
217294
let mut codec = packet::Codec::new();
218295
let mut buffer = BytesMut::with_capacity(1500);
@@ -271,6 +348,91 @@ impl Connection for Quic {
271348
SocketAddr::V6(_) => PACKET_PROCESSING_COST_IP6_QUIC,
272349
})
273350
}
351+
352+
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
353+
let Self {
354+
framed,
355+
con,
356+
read,
357+
write,
358+
} = self;
359+
360+
let (tx, rx) = framed.split();
361+
362+
(
363+
QuicReadHalf {
364+
framed: rx,
365+
con: con.clone(),
366+
read,
367+
},
368+
QuicWriteHalf {
369+
framed: tx,
370+
con,
371+
write,
372+
},
373+
)
374+
}
375+
}
376+
377+
pub struct QuicReadHalf {
378+
framed: SplitStream<Framed<Tracked<QuicStream>, packet::Codec>>,
379+
con: quinn::Connection,
380+
read: Arc<AtomicU64>,
381+
}
382+
383+
pub struct QuicWriteHalf {
384+
framed: SplitSink<Framed<Tracked<QuicStream>, packet::Codec>, packet::Packet>,
385+
con: quinn::Connection,
386+
write: Arc<AtomicU64>,
387+
}
388+
389+
impl ConnectionReadHalf for QuicReadHalf {
390+
async fn receive_packet(&mut self) -> Option<io::Result<Packet>> {
391+
tokio::select! {
392+
datagram = self.con.read_datagram() => {
393+
let datagram_bytes = match datagram {
394+
Ok(buffer) => buffer,
395+
Err(e) => return Some(Err(e.into())),
396+
};
397+
let recv_len = datagram_bytes.len();
398+
self.read.fetch_add(recv_len as u64, Ordering::Relaxed);
399+
let mut codec = packet::Codec::new();
400+
match codec.decode(&mut datagram_bytes.into()) {
401+
Ok(Some(packet)) => Some(Ok(packet)),
402+
// Partial? packet read. We consider this to be a stream hangup
403+
// TODO: verify
404+
Ok(None) => None,
405+
Err(e) => Some(Err(e)),
406+
}
407+
},
408+
packet = self.framed.next() => {
409+
packet
410+
}
411+
412+
}
413+
}
414+
}
415+
416+
impl ConnectionWriteHalf for QuicWriteHalf {
417+
async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> {
418+
let mut codec = packet::Codec::new();
419+
let mut buffer = BytesMut::with_capacity(1500);
420+
codec.encode(Packet::DataPacket(packet), &mut buffer)?;
421+
422+
let data: Bytes = buffer.into();
423+
let tx_len = data.len();
424+
self.write.fetch_add(tx_len as u64, Ordering::Relaxed);
425+
426+
self.con.send_datagram(data).map_err(io::Error::other)
427+
}
428+
429+
async fn feed_control_packet(&mut self, packet: ControlPacket) -> io::Result<()> {
430+
self.framed.feed(Packet::ControlPacket(packet)).await
431+
}
432+
433+
async fn flush(&mut self) -> io::Result<()> {
434+
self.framed.flush().await
435+
}
274436
}
275437

276438
#[cfg(test)]
@@ -291,6 +453,9 @@ impl DuplexStream {
291453

292454
#[cfg(test)]
293455
impl Connection for DuplexStream {
456+
type ReadHalf = DuplexStreamReadHalf;
457+
type WriteHalf = DuplexStreamWriteHalf;
458+
294459
async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> {
295460
self.framed.feed(Packet::DataPacket(packet)).await
296461
}
@@ -314,4 +479,45 @@ impl Connection for DuplexStream {
314479
fn static_link_cost(&self) -> Result<u16, io::Error> {
315480
Ok(1)
316481
}
482+
483+
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
484+
let (tx, rx) = self.framed.split();
485+
486+
(
487+
DuplexStreamReadHalf { framed: rx },
488+
DuplexStreamWriteHalf { framed: tx },
489+
)
490+
}
491+
}
492+
493+
#[cfg(test)]
494+
pub struct DuplexStreamReadHalf {
495+
framed: SplitStream<Framed<tokio::io::DuplexStream, packet::Codec>>,
496+
}
497+
498+
#[cfg(test)]
499+
pub struct DuplexStreamWriteHalf {
500+
framed: SplitSink<Framed<tokio::io::DuplexStream, packet::Codec>, packet::Packet>,
501+
}
502+
503+
#[cfg(test)]
504+
impl ConnectionReadHalf for DuplexStreamReadHalf {
505+
async fn receive_packet(&mut self) -> Option<io::Result<Packet>> {
506+
self.framed.next().await
507+
}
508+
}
509+
510+
#[cfg(test)]
511+
impl ConnectionWriteHalf for DuplexStreamWriteHalf {
512+
async fn feed_data_packet(&mut self, packet: DataPacket) -> io::Result<()> {
513+
self.framed.feed(Packet::DataPacket(packet)).await
514+
}
515+
516+
async fn feed_control_packet(&mut self, packet: ControlPacket) -> io::Result<()> {
517+
self.framed.feed(Packet::ControlPacket(packet)).await
518+
}
519+
520+
async fn flush(&mut self) -> io::Result<()> {
521+
self.framed.flush().await
522+
}
317523
}

mycelium/src/connection/tls.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ use std::{
44
sync::{atomic::AtomicU64, Arc},
55
};
66

7-
use futures::{SinkExt, StreamExt};
7+
use futures::{
8+
stream::{SplitSink, SplitStream},
9+
SinkExt, StreamExt,
10+
};
811
use tokio::net::TcpStream;
912
use tokio_util::codec::Framed;
1013

@@ -36,6 +39,9 @@ impl TlsStream {
3639
}
3740

3841
impl super::Connection for TlsStream {
42+
type ReadHalf = TlsStreamReadHalf;
43+
type WriteHalf = TlsStreamWriteHalf;
44+
3945
async fn feed_data_packet(&mut self, packet: crate::packet::DataPacket) -> io::Result<()> {
4046
self.framed.feed(Packet::DataPacket(packet)).await
4147
}
@@ -68,4 +74,47 @@ impl super::Connection for TlsStream {
6874
SocketAddr::V6(_) => super::PACKET_PROCESSING_COST_IP6_TCP,
6975
})
7076
}
77+
78+
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
79+
let (tx, rx) = self.framed.split();
80+
81+
(
82+
TlsStreamReadHalf { framed: rx },
83+
TlsStreamWriteHalf { framed: tx },
84+
)
85+
}
86+
}
87+
88+
pub struct TlsStreamReadHalf {
89+
framed: SplitStream<Framed<Tracked<tokio_openssl::SslStream<TcpStream>>, packet::Codec>>,
90+
}
91+
92+
pub struct TlsStreamWriteHalf {
93+
framed: SplitSink<
94+
Framed<Tracked<tokio_openssl::SslStream<TcpStream>>, packet::Codec>,
95+
packet::Packet,
96+
>,
97+
}
98+
99+
impl super::ConnectionReadHalf for TlsStreamReadHalf {
100+
async fn receive_packet(&mut self) -> Option<io::Result<crate::packet::Packet>> {
101+
self.framed.next().await
102+
}
103+
}
104+
105+
impl super::ConnectionWriteHalf for TlsStreamWriteHalf {
106+
async fn feed_data_packet(&mut self, packet: crate::packet::DataPacket) -> io::Result<()> {
107+
self.framed.feed(Packet::DataPacket(packet)).await
108+
}
109+
110+
async fn feed_control_packet(
111+
&mut self,
112+
packet: crate::packet::ControlPacket,
113+
) -> io::Result<()> {
114+
self.framed.feed(Packet::ControlPacket(packet)).await
115+
}
116+
117+
async fn flush(&mut self) -> io::Result<()> {
118+
self.framed.flush().await
119+
}
71120
}

0 commit comments

Comments
 (0)