From 706354667efb02ffbc03f71d39493b7ad54e1cfa Mon Sep 17 00:00:00 2001 From: link2xt Date: Sun, 20 Jul 2025 14:01:14 +0000 Subject: [PATCH] refactor!: use try_next() more I have looked through the code to make sure errors are not ignored, especially in loops. Ignoring errors in loops when reading from streams may result in infinite loop reading the same error over and over again. No bugs found, but I refactored reading from streams to use `try_next()` more to bubble up the errors with `?` as soon as possible. This is a breaking change since `read_response()` is resultified to return `Result>` instead of `Option>`. `read_response()` is a public interface that is used by library users to read the banner. --- examples/src/bin/integration.rs | 4 +- src/client.rs | 125 +++++++++++++++----------------- src/extensions/id.rs | 5 +- src/extensions/idle.rs | 3 +- src/extensions/quota.rs | 10 +-- src/parse.rs | 26 +++---- 6 files changed, 78 insertions(+), 95 deletions(-) diff --git a/examples/src/bin/integration.rs b/examples/src/bin/integration.rs index 42d5ea6..ff15fb1 100644 --- a/examples/src/bin/integration.rs +++ b/examples/src/bin/integration.rs @@ -31,7 +31,7 @@ async fn session(user: &str) -> Result Result<()> { let mut client = async_imap::Client::new(tcp_stream); let _greeting = client .read_response() - .await + .await? .context("unexpected end of stream, expected greeting")?; client.run_command_and_check_ok("STARTTLS", None).await?; let stream = client.into_inner(); diff --git a/src/client.rs b/src/client.rs index 05ecfd6..c91cd17 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,7 +10,7 @@ use async_std::io::{Read, Write, WriteExt}; use base64::Engine as _; use extensions::id::{format_identification, parse_id}; use extensions::quota::parse_get_quota_root; -use futures::{io, Stream, StreamExt}; +use futures::{io, Stream, TryStreamExt}; use imap_proto::{Metadata, RequestId, Response}; #[cfg(feature = "runtime-tokio")] use tokio::io::{AsyncRead as Read, AsyncWrite as Write, AsyncWriteExt}; @@ -122,7 +122,7 @@ macro_rules! ok_or_unauth_client_err { ($r:expr, $self:expr) => { match $r { Ok(o) => o, - Err(e) => return Err((e, $self)), + Err(e) => return Err((e.into(), $self)), } }; } @@ -262,42 +262,37 @@ impl Client { // explicit match blocks neccessary to convert error to tuple and not bind self too // early (see also comment on `login`) loop { - if let Some(res) = self.read_response().await { - let res = ok_or_unauth_client_err!(res.map_err(Into::into), self); - match res.parsed() { - Response::Continue { information, .. } => { - let challenge = if let Some(text) = information { - ok_or_unauth_client_err!( - base64::engine::general_purpose::STANDARD - .decode(text.as_ref()) - .map_err(|e| Error::Parse(ParseError::Authentication( - (*text).to_string(), - Some(e) - ))), - self - ) - } else { - Vec::new() - }; - let raw_response = &mut authenticator.process(&challenge); - let auth_response = - base64::engine::general_purpose::STANDARD.encode(raw_response); - - ok_or_unauth_client_err!( - self.conn.run_command_untagged(&auth_response).await, - self - ); - } - _ => { + let Some(res) = ok_or_unauth_client_err!(self.read_response().await, self) else { + return Err((Error::ConnectionLost, self)); + }; + match res.parsed() { + Response::Continue { information, .. } => { + let challenge = if let Some(text) = information { ok_or_unauth_client_err!( - self.check_done_ok_from(&id, None, res).await, + base64::engine::general_purpose::STANDARD + .decode(text.as_ref()) + .map_err(|e| Error::Parse(ParseError::Authentication( + (*text).to_string(), + Some(e) + ))), self - ); - return Ok(Session::new(self.conn)); - } + ) + } else { + Vec::new() + }; + let raw_response = &mut authenticator.process(&challenge); + let auth_response = + base64::engine::general_purpose::STANDARD.encode(raw_response); + + ok_or_unauth_client_err!( + self.conn.run_command_untagged(&auth_response).await, + self + ); + } + _ => { + ok_or_unauth_client_err!(self.check_done_ok_from(&id, None, res).await, self); + return Ok(Session::new(self.conn)); } - } else { - return Err((Error::ConnectionLost, self)); } } } @@ -975,12 +970,13 @@ impl Session { mailbox_pattern.unwrap_or("\"\"") )) .await?; - - Ok(parse_names( + let names = parse_names( &mut self.conn.stream, self.unsolicited_responses_tx.clone(), id, - )) + ); + + Ok(names) } /// The [`LSUB` command](https://tools.ietf.org/html/rfc3501#section-6.3.9) returns a subset of @@ -1136,23 +1132,20 @@ impl Session { )) .await?; - match self.read_response().await { - Some(Ok(res)) => { - if let Response::Continue { .. } = res.parsed() { - self.stream.as_mut().write_all(content).await?; - self.stream.as_mut().write_all(b"\r\n").await?; - self.stream.flush().await?; - self.conn - .check_done_ok(&id, Some(self.unsolicited_responses_tx.clone())) - .await?; - Ok(()) - } else { - Err(Error::Append) - } - } - Some(Err(err)) => Err(err.into()), - _ => Err(Error::Append), - } + let Some(res) = self.read_response().await? else { + return Err(Error::Append); + }; + let Response::Continue { .. } = res.parsed() else { + return Err(Error::Append); + }; + + self.stream.as_mut().write_all(content).await?; + self.stream.as_mut().write_all(b"\r\n").await?; + self.stream.flush().await?; + self.conn + .check_done_ok(&id, Some(self.unsolicited_responses_tx.clone())) + .await?; + Ok(()) } /// The [`SEARCH` command](https://tools.ietf.org/html/rfc3501#section-6.4.4) searches the @@ -1352,7 +1345,7 @@ impl Session { } /// Read the next response on the connection. - pub async fn read_response(&mut self) -> Option> { + pub async fn read_response(&mut self) -> io::Result> { self.conn.read_response().await } } @@ -1377,8 +1370,8 @@ impl Connection { } /// Read the next response on the connection. - pub async fn read_response(&mut self) -> Option> { - self.stream.next().await + pub async fn read_response(&mut self) -> io::Result> { + self.stream.try_next().await } pub(crate) async fn run_command_untagged(&mut self, command: &str) -> Result<()> { @@ -1415,8 +1408,8 @@ impl Connection { id: &RequestId, unsolicited: Option>, ) -> Result<()> { - if let Some(first_res) = self.stream.next().await { - self.check_done_ok_from(id, unsolicited, first_res?).await + if let Some(first_res) = self.stream.try_next().await? { + self.check_done_ok_from(id, unsolicited, first_res).await } else { Err(Error::ConnectionLost) } @@ -1447,11 +1440,10 @@ impl Connection { handle_unilateral(response, unsolicited); } - if let Some(res) = self.stream.next().await { - response = res?; - } else { + let Some(res) = self.stream.try_next().await? else { return Err(Error::ConnectionLost); - } + }; + response = res; } } @@ -1495,6 +1487,7 @@ mod tests { use std::future::Future; use async_std::sync::{Arc, Mutex}; + use futures::StreamExt; use imap_proto::Status; macro_rules! mock_client { @@ -1555,7 +1548,7 @@ mod tests { async fn readline_eof() { let mock_stream = MockStream::default().with_eof(); let mut client = mock_client!(mock_stream); - let res = client.read_response().await; + let res = client.read_response().await.unwrap(); assert!(res.is_none()); } @@ -2117,7 +2110,7 @@ mod tests { .unwrap(); // Unexpected EOF. - let err = fetch_result.next().await.unwrap().unwrap_err(); + let err = fetch_result.try_next().await.unwrap_err(); let Error::Io(io_err) = err else { panic!("Unexpected error type: {err}") }; diff --git a/src/extensions/id.rs b/src/extensions/id.rs index 00e6c84..a9e8024 100644 --- a/src/extensions/id.rs +++ b/src/extensions/id.rs @@ -43,10 +43,9 @@ pub(crate) async fn parse_id> + Unpin> let mut id = None; while let Some(resp) = stream .take_while(|res| filter(res, &command_tag)) - .next() - .await + .try_next() + .await? { - let resp = resp?; match resp.parsed() { Response::Id(res) => { id = res.as_ref().map(|m| { diff --git a/src/extensions/idle.rs b/src/extensions/idle.rs index 70e38e1..072849d 100644 --- a/src/extensions/idle.rs +++ b/src/extensions/idle.rs @@ -182,8 +182,7 @@ impl Handle { pub async fn init(&mut self) -> Result<()> { let id = self.session.run_command("IDLE").await?; self.id = Some(id); - while let Some(res) = self.session.stream.next().await { - let res = res?; + while let Some(res) = self.session.stream.try_next().await? { match res.parsed() { Response::Continue { .. } => { return Ok(()); diff --git a/src/extensions/quota.rs b/src/extensions/quota.rs index 8b62535..0bb8d52 100644 --- a/src/extensions/quota.rs +++ b/src/extensions/quota.rs @@ -23,10 +23,9 @@ pub(crate) async fn parse_get_quota> + let mut quota = None; while let Some(resp) = stream .take_while(|res| filter(res, &command_tag)) - .next() - .await + .try_next() + .await? { - let resp = resp?; match resp.parsed() { Response::Quota(q) => quota = Some(q.clone().into()), _ => { @@ -53,10 +52,9 @@ pub(crate) async fn parse_get_quota_root { roots.push(qr.clone().into()); diff --git a/src/parse.rs b/src/parse.rs index e99f10c..f1daed5 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -99,8 +99,7 @@ pub(crate) async fn parse_status> + Un ) -> Result { let mut mbox = Mailbox::default(); - while let Some(resp) = stream.next().await { - let resp = resp?; + while let Some(resp) = stream.try_next().await? { match resp.parsed() { Response::Done { tag, @@ -192,10 +191,9 @@ pub(crate) async fn parse_capabilities while let Some(resp) = stream .take_while(|res| filter(res, &command_tag)) - .next() - .await + .try_next() + .await? { - let resp = resp?; match resp.parsed() { Response::Capabilities(cs) => { for c in cs { @@ -218,10 +216,9 @@ pub(crate) async fn parse_noop> + Unpi ) -> Result<()> { while let Some(resp) = stream .take_while(|res| filter(res, &command_tag)) - .next() - .await + .try_next() + .await? { - let resp = resp?; handle_unilateral(resp, unsolicited.clone()); } @@ -235,8 +232,7 @@ pub(crate) async fn parse_mailbox> + U ) -> Result { let mut mailbox = Mailbox::default(); - while let Some(resp) = stream.next().await { - let resp = resp?; + while let Some(resp) = stream.try_next().await? { match resp.parsed() { Response::Done { tag, @@ -345,10 +341,9 @@ pub(crate) async fn parse_ids> + Unpin while let Some(resp) = stream .take_while(|res| filter(res, &command_tag)) - .next() - .await + .try_next() + .await? { - let resp = resp?; match resp.parsed() { Response::MailboxData(MailboxDatum::Search(cs)) => { for c in cs { @@ -374,10 +369,9 @@ pub(crate) async fn parse_metadata> + let mut res_values = Vec::new(); while let Some(resp) = stream .take_while(|res| filter(res, &command_tag)) - .next() - .await + .try_next() + .await? { - let resp = resp?; match resp.parsed() { // METADATA Response with Values //