From 78bfb004e5c7929d52c28b92fda38b41978c75c9 Mon Sep 17 00:00:00 2001 From: Zava <693320+zavakid@users.noreply.github.com> Date: Wed, 8 Oct 2025 19:29:43 +0800 Subject: [PATCH] Refine body receiver polling --- crates/http/src/protocol/body/body_channel.rs | 101 ++++++++++++++---- 1 file changed, 79 insertions(+), 22 deletions(-) diff --git a/crates/http/src/protocol/body/body_channel.rs b/crates/http/src/protocol/body/body_channel.rs index b963f42..b140b41 100644 --- a/crates/http/src/protocol/body/body_channel.rs +++ b/crates/http/src/protocol/body/body_channel.rs @@ -1,6 +1,6 @@ use crate::protocol::{Message, ParseError, PayloadItem, PayloadSize, RequestHeader}; use bytes::Bytes; -use futures::{SinkExt, Stream, StreamExt, channel::mpsc}; +use futures::{Sink, SinkExt, Stream, StreamExt, channel::mpsc}; use http_body::{Body, Frame, SizeHint}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -120,6 +120,7 @@ pub(crate) struct BodyReceiver { signal_sender: mpsc::Sender, data_receiver: mpsc::Receiver>, payload_size: PayloadSize, + in_flight: bool, } impl BodyReceiver { @@ -128,21 +129,7 @@ impl BodyReceiver { data_receiver: mpsc::Receiver>, payload_size: PayloadSize, ) -> Self { - Self { signal_sender, data_receiver, payload_size } - } -} - -impl BodyReceiver { - pub async fn receive_data(&mut self) -> Result { - if let Err(e) = self.signal_sender.send(BodyRequestSignal::RequestData).await { - error!("failed to send request_more through channel, {}", e); - return Err(ParseError::invalid_body("failed to send signal when receive body data")); - } - - self.data_receiver - .next() - .await - .unwrap_or_else(|| Err(ParseError::invalid_body("body stream should not receive None when receive data"))) + Self { signal_sender, data_receiver, payload_size, in_flight: false } } } @@ -153,14 +140,40 @@ impl Body for BodyReceiver { fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll, Self::Error>>> { let this = self.get_mut(); - tokio::pin! { - let future = this.receive_data(); + if !this.in_flight { + match Pin::new(&mut this.signal_sender).poll_ready(cx) { + Poll::Ready(Ok(())) => { + if let Err(e) = Pin::new(&mut this.signal_sender).start_send(BodyRequestSignal::RequestData) { + error!("failed to send request_more through channel, {}", e); + return Poll::Ready(Some(Err(ParseError::invalid_body("failed to send signal when receive body data")))); + } + this.in_flight = true; + } + Poll::Ready(Err(e)) => { + error!("failed to prepare request_more through channel, {}", e); + return Poll::Ready(Some(Err(ParseError::invalid_body("failed to send signal when receive body data")))); + } + Poll::Pending => return Poll::Pending, + } } - match future.poll(cx) { - Poll::Ready(Ok(PayloadItem::Chunk(bytes))) => Poll::Ready(Some(Ok(Frame::data(bytes)))), - Poll::Ready(Ok(PayloadItem::Eof)) => Poll::Ready(None), - Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + match this.data_receiver.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(PayloadItem::Chunk(bytes)))) => { + this.in_flight = false; + Poll::Ready(Some(Ok(Frame::data(bytes)))) + } + Poll::Ready(Some(Ok(PayloadItem::Eof))) => { + this.in_flight = false; + Poll::Ready(None) + } + Poll::Ready(Some(Err(e))) => { + this.in_flight = false; + Poll::Ready(Some(Err(e))) + } + Poll::Ready(None) => { + this.in_flight = false; + Poll::Ready(Some(Err(ParseError::invalid_body("body stream should not receive None when receive data")))) + } Poll::Pending => Poll::Pending, } } @@ -189,3 +202,47 @@ impl From for SizeHint { } } } + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use futures::channel::mpsc; + use futures::task::noop_waker_ref; + use futures::{FutureExt, StreamExt}; + use std::pin::Pin; + use std::task::{Context, Poll}; + + #[tokio::test] + async fn body_receiver_only_requests_once_until_response() { + let (signal_sender, mut signal_receiver) = mpsc::channel(8); + let (mut data_sender, data_receiver) = mpsc::channel(8); + let mut body_receiver = BodyReceiver::new(signal_sender, data_receiver, PayloadSize::new_chunked()); + + let waker = noop_waker_ref(); + let mut cx = Context::from_waker(waker); + + assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Pending)); + assert!(matches!(signal_receiver.next().await, Some(BodyRequestSignal::RequestData))); + + assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Pending)); + assert!(signal_receiver.next().now_or_never().is_none()); + + data_sender.try_send(Ok(PayloadItem::Chunk(Bytes::from_static(b"hello")))).expect("send chunk"); + + match Pin::new(&mut body_receiver).poll_frame(&mut cx) { + Poll::Ready(Some(Ok(frame))) => { + let data = frame.into_data().expect("expected data frame"); + assert_eq!(data, Bytes::from_static(b"hello")); + } + other => panic!("unexpected poll result: {:?}", other), + } + + assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Pending)); + assert!(matches!(signal_receiver.next().await, Some(BodyRequestSignal::RequestData))); + + data_sender.try_send(Ok(PayloadItem::Eof)).expect("send eof"); + + assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Ready(None))); + } +}