Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 79 additions & 22 deletions crates/http/src/protocol/body/body_channel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::protocol::{Message, ParseError, PayloadItem, PayloadSize, RequestHeader};
use bytes::Bytes;
use futures::{SinkExt, Stream, StreamExt, channel::mpsc};
use futures::{Sink, SinkExt, Stream, StreamExt, channel::mpsc};
use http_body::{Body, Frame, SizeHint};
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -120,6 +120,7 @@ pub(crate) struct BodyReceiver {
signal_sender: mpsc::Sender<BodyRequestSignal>,
data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
payload_size: PayloadSize,
in_flight: bool,
}

impl BodyReceiver {
Expand All @@ -128,21 +129,7 @@ impl BodyReceiver {
data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
payload_size: PayloadSize,
) -> Self {
Self { signal_sender, data_receiver, payload_size }
}
}

impl BodyReceiver {
pub async fn receive_data(&mut self) -> Result<PayloadItem, ParseError> {
if let Err(e) = self.signal_sender.send(BodyRequestSignal::RequestData).await {
error!("failed to send request_more through channel, {}", e);
return Err(ParseError::invalid_body("failed to send signal when receive body data"));
}

self.data_receiver
.next()
.await
.unwrap_or_else(|| Err(ParseError::invalid_body("body stream should not receive None when receive data")))
Self { signal_sender, data_receiver, payload_size, in_flight: false }
}
}

Expand All @@ -153,14 +140,40 @@ impl Body for BodyReceiver {
fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let this = self.get_mut();

tokio::pin! {
let future = this.receive_data();
if !this.in_flight {
match Pin::new(&mut this.signal_sender).poll_ready(cx) {
Poll::Ready(Ok(())) => {
if let Err(e) = Pin::new(&mut this.signal_sender).start_send(BodyRequestSignal::RequestData) {
error!("failed to send request_more through channel, {}", e);
return Poll::Ready(Some(Err(ParseError::invalid_body("failed to send signal when receive body data"))));
}
this.in_flight = true;
}
Poll::Ready(Err(e)) => {
error!("failed to prepare request_more through channel, {}", e);
return Poll::Ready(Some(Err(ParseError::invalid_body("failed to send signal when receive body data"))));
}
Poll::Pending => return Poll::Pending,
}
}

match future.poll(cx) {
Poll::Ready(Ok(PayloadItem::Chunk(bytes))) => Poll::Ready(Some(Ok(Frame::data(bytes)))),
Poll::Ready(Ok(PayloadItem::Eof)) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
match this.data_receiver.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(PayloadItem::Chunk(bytes)))) => {
this.in_flight = false;
Poll::Ready(Some(Ok(Frame::data(bytes))))
}
Poll::Ready(Some(Ok(PayloadItem::Eof))) => {
this.in_flight = false;
Poll::Ready(None)
}
Poll::Ready(Some(Err(e))) => {
this.in_flight = false;
Poll::Ready(Some(Err(e)))
}
Poll::Ready(None) => {
this.in_flight = false;
Poll::Ready(Some(Err(ParseError::invalid_body("body stream should not receive None when receive data"))))
}
Poll::Pending => Poll::Pending,
}
}
Expand Down Expand Up @@ -189,3 +202,47 @@ impl From<PayloadSize> for SizeHint {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use futures::channel::mpsc;
use futures::task::noop_waker_ref;
use futures::{FutureExt, StreamExt};
use std::pin::Pin;
use std::task::{Context, Poll};

#[tokio::test]
async fn body_receiver_only_requests_once_until_response() {
let (signal_sender, mut signal_receiver) = mpsc::channel(8);
let (mut data_sender, data_receiver) = mpsc::channel(8);
let mut body_receiver = BodyReceiver::new(signal_sender, data_receiver, PayloadSize::new_chunked());

let waker = noop_waker_ref();
let mut cx = Context::from_waker(waker);

assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Pending));
assert!(matches!(signal_receiver.next().await, Some(BodyRequestSignal::RequestData)));

assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Pending));
assert!(signal_receiver.next().now_or_never().is_none());

data_sender.try_send(Ok(PayloadItem::Chunk(Bytes::from_static(b"hello")))).expect("send chunk");

match Pin::new(&mut body_receiver).poll_frame(&mut cx) {
Poll::Ready(Some(Ok(frame))) => {
let data = frame.into_data().expect("expected data frame");
assert_eq!(data, Bytes::from_static(b"hello"));
}
other => panic!("unexpected poll result: {:?}", other),
}

assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Pending));
assert!(matches!(signal_receiver.next().await, Some(BodyRequestSignal::RequestData)));

data_sender.try_send(Ok(PayloadItem::Eof)).expect("send eof");

assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Ready(None)));
}
}