diff --git a/Cargo.toml b/Cargo.toml index a21cd0fdf..28e3a9c77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,19 +19,25 @@ rustdoc-args = ["--cfg", "docsrs", "--cfg", "reqwest_unstable"] targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"] [package.metadata.playground] -features = [ - "blocking", - "cookies", - "json", - "multipart", -] +features = ["blocking", "cookies", "json", "multipart"] [features] -default = ["default-tls", "charset", "http2", "macos-system-configuration"] +default = [ + "default-tls", + "charset", + "http2", + "macos-system-configuration", + "hickory-dns", +] # Note: this doesn't enable the 'native-tls' feature, which adds specific # functionality for it. -default-tls = ["dep:hyper-tls", "dep:native-tls-crate", "__tls", "dep:tokio-native-tls"] +default-tls = [ + "dep:hyper-tls", + "dep:native-tls-crate", + "__tls", + "dep:tokio-native-tls", +] http2 = ["h2", "hyper/http2", "hyper-util/http2"] @@ -45,7 +51,13 @@ rustls-tls-manual-roots = ["__rustls"] rustls-tls-webpki-roots = ["dep:webpki-roots", "__rustls"] rustls-tls-native-roots = ["dep:rustls-native-certs", "__rustls"] -blocking = ["futures-channel/sink", "futures-util/io", "futures-util/sink", "tokio/rt-multi-thread", "tokio/sync"] +blocking = [ + "futures-channel/sink", + "futures-util/io", + "futures-util/sink", + "tokio/rt-multi-thread", + "tokio/sync", +] charset = ["dep:encoding_rs"] @@ -53,7 +65,11 @@ cookies = ["dep:cookie_crate", "dep:cookie_store"] gzip = ["dep:async-compression", "async-compression?/gzip", "dep:tokio-util"] -brotli = ["dep:async-compression", "async-compression?/brotli", "dep:tokio-util"] +brotli = [ + "dep:async-compression", + "async-compression?/brotli", + "dep:tokio-util", +] deflate = ["dep:async-compression", "async-compression?/zlib", "dep:tokio-util"] @@ -74,7 +90,13 @@ macos-system-configuration = ["dep:system-configuration"] # Experimental HTTP/3 client. # Disabled while waiting for quinn to upgrade. -#http3 = ["rustls-tls-manual-roots", "dep:h3", "dep:h3-quinn", "dep:quinn", "dep:futures-channel"] +http3 = [ + "rustls-tls-manual-roots", + "dep:h3", + "dep:h3-quinn", + "dep:quinn", + "dep:futures-channel", +] # Internal (PRIVATE!) features used to aid testing. # Don't rely on these whatsoever. They may disappear at anytime. @@ -84,7 +106,14 @@ __tls = ["dep:rustls-pemfile", "tokio/io-util"] # Enables common rustls code. # Equivalent to rustls-tls-manual-roots but shorter :) -__rustls = ["dep:hyper-rustls", "dep:tokio-rustls", "dep:rustls", "__tls", "dep:rustls-pemfile", "rustls-pki-types"] +__rustls = [ + "dep:hyper-rustls", + "dep:tokio-rustls", + "dep:rustls", + "__tls", + "dep:rustls-pemfile", + "rustls-pki-types", +] # When enabled, disable using the cached SYS_PROXIES. __internal_proxy_sys_no_cache = [] @@ -100,7 +129,7 @@ tower-service = "0.3" futures-core = { version = "0.3.0", default-features = false } futures-util = { version = "0.3.0", default-features = false } sync_wrapper = "0.1.2" - +rand = "0.8" # Optional deps... ## json @@ -112,14 +141,25 @@ mime_guess = { version = "2.0", default-features = false, optional = true } encoding_rs = { version = "0.8", optional = true } http-body = "1" http-body-util = "0.1" -hyper = { version = "1", features = ["http1", "client"] } -hyper-util = { version = "0.1.3", features = ["http1", "client", "client-legacy", "tokio"] } +hyper = { git = "https://github.com/getsentry/hyper", rev = "9efa5aedce8a36e9a86115b8c104da44fcc43d7e", features = [ + "http1", + "client", +] } +hyper-util = { git = "https://github.com/getsentry/hyper-util", rev = "cff84af9d682add40b92c5a57ad43f73bb714081", features = [ + "http1", + "client", + "client-legacy", + "tokio", +] } h2 = { version = "0.4", optional = true } once_cell = "1" log = "0.4" mime = "0.3.16" percent-encoding = "2.1" -tokio = { version = "1.0", default-features = false, features = ["net", "time"] } +tokio = { version = "1.0", default-features = false, features = [ + "net", + "time", +] } pin-project-lite = "0.2.0" ipnet = "2.3" @@ -127,14 +167,14 @@ ipnet = "2.3" rustls-pemfile = { version = "2", optional = true } ## default-tls -hyper-tls = { version = "0.6", optional = true } +hyper-tls = { git = "https://github.com/getsentry/hyper-tls", rev = "7968d40a8b00842803adae4f65963bbf0b126dca", optional = true } native-tls-crate = { version = "0.2.10", optional = true, package = "native-tls" } tokio-native-tls = { version = "0.3.0", optional = true } # rustls-tls hyper-rustls = { version = "0.26.0", default-features = false, optional = true } rustls = { version = "0.22.2", optional = true } -rustls-pki-types = { version = "1.1.0", features = ["alloc"] ,optional = true } +rustls-pki-types = { version = "1.1.0", features = ["alloc"], optional = true } tokio-rustls = { version = "0.25", optional = true } webpki-roots = { version = "0.26.0", optional = true } rustls-native-certs = { version = "0.7", optional = true } @@ -144,32 +184,61 @@ cookie_crate = { version = "0.17.0", package = "cookie", optional = true } cookie_store = { version = "0.20.0", optional = true } ## compression -async-compression = { version = "0.4.0", default-features = false, features = ["tokio"], optional = true } -tokio-util = { version = "0.7.1", default-features = false, features = ["codec", "io"], optional = true } +async-compression = { version = "0.4.0", default-features = false, features = [ + "tokio", +], optional = true } +tokio-util = { version = "0.7.1", default-features = false, features = [ + "codec", + "io", +], optional = true } ## socks tokio-socks = { version = "0.5.1", optional = true } ## hickory-dns -hickory-resolver = { version = "0.24", optional = true, features = ["tokio-runtime"] } +hickory-resolver = { version = "0.24", optional = true, features = [ + "tokio-runtime", +] } # HTTP/3 experimental support h3 = { version = "0.0.4", optional = true } h3-quinn = { version = "0.0.5", optional = true } -quinn = { version = "0.10", default-features = false, features = ["tls-rustls", "ring", "runtime-tokio"], optional = true } +quinn = { version = "0.10", default-features = false, features = [ + "tls-rustls", + "ring", + "runtime-tokio", +], optional = true } futures-channel = { version = "0.3", optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] env_logger = "0.10" -hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] } -hyper-util = { version = "0.1", features = ["http1", "http2", "client", "client-legacy", "server-auto", "tokio"] } +hyper = { git = "https://github.com/getsentry/hyper", rev = "9efa5aedce8a36e9a86115b8c104da44fcc43d7e", default-features = false, features = [ + "http1", + "http2", + "client", + "server", +] } +hyper-util = { git = "https://github.com/getsentry/hyper-util", rev = "cff84af9d682add40b92c5a57ad43f73bb714081", features = [ + "http1", + "http2", + "client", + "client-legacy", + "server-auto", + "tokio", +] } serde = { version = "1.0", features = ["derive"] } libflate = "1.0" brotli_crate = { package = "brotli", version = "3.3.0" } doc-comment = "0.3" -tokio = { version = "1.0", default-features = false, features = ["macros", "rt-multi-thread"] } -futures-util = { version = "0.3.0", default-features = false, features = ["std", "alloc"] } +tokio = { version = "1.0", default-features = false, features = [ + "macros", + "rt-multi-thread", +] } +futures-util = { version = "0.3.0", default-features = false, features = [ + "std", + "alloc", +] } [target.'cfg(windows)'.dependencies] winreg = "0.52.0" @@ -203,7 +272,7 @@ features = [ "ServiceWorkerGlobalScope", "RequestCredentials", "File", - "ReadableStream" + "ReadableStream", ] [target.'cfg(target_arch = "wasm32")'.dev-dependencies] diff --git a/src/async_impl/body.rs b/src/async_impl/body.rs index cd9658c64..ea36d635d 100644 --- a/src/async_impl/body.rs +++ b/src/async_impl/body.rs @@ -1,11 +1,10 @@ +use bytes::Bytes; +use http_body::Body as HttpBody; +use http_body_util::combinators::BoxBody; use std::fmt; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; - -use bytes::Bytes; -use http_body::Body as HttpBody; -use http_body_util::combinators::BoxBody; //use sync_wrapper::SyncWrapper; #[cfg(feature = "stream")] use tokio::fs::File; diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 0d2361a92..213de6f67 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -2,10 +2,11 @@ use std::any::Any; use std::net::IpAddr; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant, SystemTime}; use std::{collections::HashMap, convert::TryInto, net::SocketAddr}; use std::{fmt, str}; +use crate::tls::TlsInfo; use bytes::Bytes; use http::header::{ Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, @@ -13,6 +14,7 @@ use http::header::{ }; use http::uri::Scheme; use http::Uri; +use hyper::stats::RequestId; use hyper_util::client::legacy::connect::HttpConnector; #[cfg(feature = "default-tls")] use native_tls_crate::TlsConnector; @@ -35,6 +37,8 @@ use crate::connect::Connector; use crate::cookie; #[cfg(feature = "hickory-dns")] use crate::dns::hickory::HickoryDnsResolver; +use hickory_resolver; + use crate::dns::{gai::GaiResolver, DnsResolverWithOverrides, DynResolver, Resolve}; use crate::error; use crate::into_url::try_uri; @@ -53,6 +57,7 @@ use quinn::TransportConfig; use quinn::VarInt; type HyperResponseFuture = hyper_util::client::legacy::ResponseFuture; +const DEFAULT_DNS_PORT: u16 = 53; /// An asynchronous `Client` to make Requests with. /// @@ -152,6 +157,10 @@ struct Config { #[cfg(feature = "cookies")] cookie_store: Option>, hickory_dns: bool, + #[cfg(feature = "hickory-dns")] + dns_nameservers: Option>, + #[cfg(feature = "hickory-dns")] + ip_filter: fn(IpAddr) -> bool, error: Option, https_only: bool, #[cfg(feature = "http3")] @@ -248,6 +257,8 @@ impl ClientBuilder { interface: None, nodelay: true, hickory_dns: cfg!(feature = "hickory-dns"), + #[cfg(feature = "hickory-dns")] + ip_filter: |_| true, #[cfg(feature = "cookies")] cookie_store: None, https_only: false, @@ -263,6 +274,8 @@ impl ClientBuilder { #[cfg(feature = "http3")] quic_send_window: None, dns_resolver: None, + #[cfg(feature = "hickory-dns")] + dns_nameservers: None, }, } } @@ -299,7 +312,39 @@ impl ClientBuilder { let mut resolver: Arc = match config.hickory_dns { false => Arc::new(GaiResolver::new()), #[cfg(feature = "hickory-dns")] - true => Arc::new(HickoryDnsResolver::default()), + true => { + let mut resolver = HickoryDnsResolver::new(config.ip_filter); + if let Some(nameservers) = config.dns_nameservers { + let mut hickory_config = hickory_resolver::config::ResolverConfig::new(); + for ip in nameservers { + hickory_config.add_name_server( + hickory_resolver::config::NameServerConfig { + socket_addr: (ip, DEFAULT_DNS_PORT).into(), + protocol: hickory_resolver::config::Protocol::Udp, + tls_dns_name: None, + trust_negative_responses: false, + bind_addr: None, + }, + ); + hickory_config.add_name_server( + hickory_resolver::config::NameServerConfig { + socket_addr: (ip, DEFAULT_DNS_PORT).into(), + protocol: hickory_resolver::config::Protocol::Tcp, + tls_dns_name: None, + trust_negative_responses: false, + bind_addr: None, + }, + ); + } + + let mut opts = hickory_resolver::config::ResolverOpts::default(); + opts.use_hosts_file = false; + + resolver = resolver.with_config(hickory_config, opts); + } + + Arc::new(resolver) + } #[cfg(not(feature = "hickory-dns"))] true => unreachable!("hickory-dns shouldn't be enabled unless the feature is"), }; @@ -743,6 +788,7 @@ impl ClientBuilder { proxies, proxies_maybe_http_auth, https_only: config.https_only, + ip_filter: config.ip_filter, }), }) } @@ -1686,6 +1732,24 @@ impl ClientBuilder { self } + /// Configure custom DNS nameservers for this client when using hickory_dns + /// + /// # Example + /// ``` + /// # use std::net::{IpAddr, Ipv4Addr}; + /// let client = reqwest::Client::builder() + /// .dns_nameservers(vec![IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))]) + /// .build()?; + /// ``` + #[cfg(feature = "hickory-dns")] + pub fn dns_nameservers(mut self, nameservers: I) -> ClientBuilder + where + I: IntoIterator, + { + self.config.dns_nameservers = Some(nameservers.into_iter().collect()); + self + } + /// Disables the hickory-dns async resolver. /// /// This method exists even if the optional `hickory-dns` feature is not enabled. @@ -1721,6 +1785,17 @@ impl ClientBuilder { } } + /// Adds a filter for valid IP addresses during DNS lookup. + /// + /// # Optional + /// + /// This requires the optional `hickory-dns` feature to be enabled. + #[cfg(feature = "hickory-dns")] + pub fn ip_filter(mut self, filter: fn(std::net::IpAddr) -> bool) -> ClientBuilder { + self.config.ip_filter = filter; + self + } + /// Override DNS resolution for specific domains to a particular IP address. /// /// Warning @@ -1936,7 +2011,7 @@ impl Client { } pub(super) fn execute_request(&self, req: Request) -> Pending { - let (method, url, mut headers, body, timeout, version) = req.pieces(); + let (method, url, mut headers, body, timeout, version, req_id) = req.pieces(); if url.scheme() != "http" && url.scheme() != "https" { return Pending::new_err(error::url_bad_scheme(url)); } @@ -1972,6 +2047,11 @@ impl Client { } } + if let Err(err) = validate_url(self.inner.ip_filter, &url) { + return Pending { + inner: PendingInner::Error(Some(err)), + }; + } let uri = match try_uri(&url) { Ok(uri) => uri, _ => return Pending::new_err(error::url_invalid_uri(url)), @@ -2002,7 +2082,7 @@ impl Client { _ => { let mut req = builder.body(body).expect("valid request parts"); *req.headers_mut() = headers.clone(); - ResponseFuture::Default(self.inner.hyper.request(req)) + ResponseFuture::Default(self.inner.hyper.request(req, req_id.clone())) } }; @@ -2017,6 +2097,7 @@ impl Client { url, headers, body: reusable, + req_id, urls: Vec::new(), @@ -2026,6 +2107,9 @@ impl Client { in_flight, timeout, + + poll_start: None, + poll_start_timestamp: None, }), } } @@ -2234,6 +2318,7 @@ struct ClientRef { proxies: Arc>, proxies_maybe_http_auth: bool, https_only: bool, + ip_filter: fn(IpAddr) -> bool, } impl ClientRef { @@ -2299,6 +2384,10 @@ pin_project! { in_flight: ResponseFuture, #[pin] timeout: Option>>, + + poll_start: Option, + poll_start_timestamp: Option, + req_id: RequestId, } } @@ -2377,7 +2466,12 @@ impl PendingRequest { .body(body) .expect("valid request parts"); *req.headers_mut() = self.headers.clone(); - ResponseFuture::Default(self.client.hyper.request(req)) + // TODO klochek If we ever implement retries, this is where we generate the new id. + ResponseFuture::Default( + self.client + .hyper + .request(req, hyper::stats::next_request_id()), + ) } }; @@ -2460,6 +2554,19 @@ impl Future for PendingRequest { } loop { + if self.poll_start.is_none() { + self.poll_start = Some(std::time::Instant::now()); + self.poll_start_timestamp = Some( + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or(Duration::from_secs(0)) + .as_micros(), + ); + + hyper::stats::get_request_stats(&self.req_id) + .set_poll_start(self.poll_start.unwrap(), self.poll_start_timestamp.unwrap()); + } + let res = match self.as_mut().in_flight().get_mut() { ResponseFuture::Default(r) => match Pin::new(r).poll(cx) { Poll::Ready(Err(e)) => { @@ -2579,12 +2686,17 @@ impl Future for PendingRequest { loc, ))); } - + let old_url = self.url.clone(); self.url = loc; let mut headers = std::mem::replace(self.as_mut().headers(), HeaderMap::new()); remove_sensitive_headers(&mut headers, &self.url, &self.urls); + + if let Err(err) = validate_url(self.client.ip_filter, &self.url) { + return Poll::Ready(Err(err)); + } + let uri = try_uri(&self.url)?; let body = match self.body { Some(Some(ref body)) => Body::reusable(body.clone()), @@ -2623,7 +2735,42 @@ impl Future for PendingRequest { .expect("valid request parts"); *req.headers_mut() = headers.clone(); std::mem::swap(self.as_mut().headers(), &mut headers); - ResponseFuture::Default(self.client.hyper.request(req)) + + let request_body_size = self + .body + .as_ref() + .map(|o| o.as_ref().map(|b| b.len())) + .flatten() + .unwrap_or(0) + as u32; + let now = Instant::now(); + + let certificate = res + .extensions() + .get::() + .and_then(|info| info.peer_certificate()) + .and_then(|bytes| Some(bytes.to_vec())); + + let next_req_id = hyper::stats::next_request_id(); + + hyper::stats::get_request_stats(&self.req_id) + .set_redirect(next_req_id.clone()) + .set_finished(now) + .set_status_code(res.status().as_u16()) + .set_url( + try_uri(&old_url) + .expect("Uri already successfully parsed."), + ) + .set_request_body_size(request_body_size) + .set_certificate(certificate.clone()); + + self.req_id = next_req_id.clone(); + self.poll_start = None; + self.poll_start_timestamp = None; + + ResponseFuture::Default( + self.client.hyper.request(req, next_req_id), + ) } }; @@ -2639,6 +2786,31 @@ impl Future for PendingRequest { } } + let finish = Instant::now(); + + let certificate = res + .extensions() + .get::() + .and_then(|info| info.peer_certificate()) + .and_then(|bytes| Some(bytes.to_vec())); + let status = res.status().as_u16(); + let request_body_size = self + .body + .as_ref() + .map(|o| o.as_ref().map(|b| b.len())) + .flatten() + .unwrap_or(0) as u32; + + let mut req_stats = hyper::stats::get_request_stats(&self.req_id); + + req_stats + .set_poll_start(self.poll_start.unwrap(), self.poll_start_timestamp.unwrap()) + .set_finished(finish) + .set_status_code(status) + .set_url(try_uri(&self.url).expect("Uri already successfully parsed.")) + .set_request_body_size(request_body_size) + .set_certificate(certificate.clone()); + let res = Response::new( res, self.url.clone(), @@ -2682,6 +2854,20 @@ fn add_cookie_header(headers: &mut HeaderMap, cookie_store: &dyn cookie::CookieS } } +fn validate_url(ip_filter: fn(IpAddr) -> bool, url: &Url) -> Result<(), crate::Error> { + let is_valid_ip = match url.host() { + Some(url::Host::Ipv4(ip)) => (ip_filter)(IpAddr::V4(ip)), + Some(url::Host::Ipv6(ip)) => (ip_filter)(IpAddr::V6(ip)), + _ => true, + }; + + if !is_valid_ip { + let e = hickory_resolver::error::ResolveError::from("destination is restricted"); + return Err(crate::Error::new(crate::error::Kind::Request, Some(e))); + } + Ok(()) +} + #[cfg(test)] mod tests { #[tokio::test] diff --git a/src/async_impl/request.rs b/src/async_impl/request.rs index 665710430..350f739a9 100644 --- a/src/async_impl/request.rs +++ b/src/async_impl/request.rs @@ -3,6 +3,7 @@ use std::fmt; use std::future::Future; use std::time::Duration; +use hyper::stats::RequestId; use serde::Serialize; #[cfg(feature = "json")] use serde_json; @@ -26,6 +27,7 @@ pub struct Request { body: Option, timeout: Option, version: Version, + req_id: RequestId, } /// A builder to construct the properties of a `Request`. @@ -48,6 +50,7 @@ impl Request { body: None, timeout: None, version: Version::default(), + req_id: hyper::stats::next_request_id(), } } @@ -123,6 +126,11 @@ impl Request { &mut self.version } + /// Gets the unique request id for this request. + pub fn req_id(&self) -> &RequestId { + &self.req_id + } + /// Attempt to clone the request. /// /// `None` is returned if the request can not be cloned, i.e. if the body is a stream. @@ -148,6 +156,7 @@ impl Request { Option, Option, Version, + RequestId, ) { ( self.method, @@ -156,6 +165,7 @@ impl Request { self.body, self.timeout, self.version, + self.req_id, ) } } @@ -618,6 +628,7 @@ where body: Some(body.into()), timeout: None, version, + req_id: hyper::stats::next_request_id(), }) } } diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index d2ddfc3a1..41028ecf0 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -1,7 +1,3 @@ -use std::fmt; -use std::net::SocketAddr; -use std::pin::Pin; - use bytes::Bytes; use http_body_util::BodyExt; use hyper::{HeaderMap, StatusCode, Version}; @@ -10,6 +6,9 @@ use hyper_util::client::legacy::connect::HttpInfo; use serde::de::DeserializeOwned; #[cfg(feature = "json")] use serde_json; +use std::fmt; +use std::net::SocketAddr; +use std::pin::Pin; use tokio::time::Sleep; use url::Url; diff --git a/src/async_impl/upgrade.rs b/src/async_impl/upgrade.rs index 3b599d0ad..64de450ce 100644 --- a/src/async_impl/upgrade.rs +++ b/src/async_impl/upgrade.rs @@ -1,9 +1,8 @@ +use futures_util::TryFutureExt; +use hyper_util::rt::TokioIo; use std::pin::Pin; use std::task::{self, Poll}; use std::{fmt, io}; - -use futures_util::TryFutureExt; -use hyper_util::rt::TokioIo; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; /// An upgraded HTTP connection. diff --git a/src/connect.rs b/src/connect.rs index ff76c57f8..6c784e38c 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -3,6 +3,7 @@ use http::header::HeaderValue; use http::uri::{Authority, Scheme}; use http::Uri; use hyper::rt::{Read, ReadBufCursor, Write}; +use hyper::stats::RequestId; use hyper_util::client::legacy::connect::{Connected, Connection}; #[cfg(any(feature = "socks", feature = "__tls"))] use hyper_util::rt::TokioIo; @@ -271,11 +272,16 @@ impl Connector { }) } - async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result { + async fn connect_with_maybe_proxy( + self, + dst: Uri, + req_id: RequestId, + is_proxy: bool, + ) -> Result { match self.inner { #[cfg(not(feature = "__tls"))] Inner::Http(mut http) => { - let io = http.call(dst).await?; + let io = http.call((dst, req_id)).await?; Ok(Conn { inner: self.verbose.wrap(io), is_proxy, @@ -295,7 +301,7 @@ impl Connector { let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); - let io = http.call(dst).await?; + let io = http.call((dst, req_id)).await?; if let hyper_tls::MaybeHttpsStream::Https(stream) = io { if !self.nodelay { @@ -359,6 +365,7 @@ impl Connector { async fn connect_via_proxy( self, dst: Uri, + req_id: RequestId, proxy_scheme: ProxyScheme, ) -> Result { log::debug!("proxy({proxy_scheme:?}) intercepts '{dst:?}'"); @@ -382,7 +389,7 @@ impl Connector { let http = http.clone(); let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); - let conn = http.call(proxy_dst).await?; + let conn = http.call((proxy_dst, req_id)).await?; log::trace!("tunneling HTTPS over proxy"); let tunneled = tunnel( conn, @@ -396,6 +403,7 @@ impl Connector { let io = tls_connector .connect(host.ok_or("no host in url")?, TokioIo::new(tunneled)) .await?; + return Ok(Conn { inner: self.verbose.wrap(NativeTlsConn { inner: TokioIo::new(io), @@ -444,7 +452,7 @@ impl Connector { Inner::Http(_) => (), } - self.connect_with_maybe_proxy(proxy_dst, true).await + self.connect_with_maybe_proxy(proxy_dst, req_id, true).await } pub fn set_keepalive(&mut self, dur: Option) { @@ -484,7 +492,7 @@ where } } -impl Service for Connector { +impl Service<(Uri, RequestId)> for Connector { type Response = Conn; type Error = BoxError; type Future = Connecting; @@ -493,20 +501,20 @@ impl Service for Connector { Poll::Ready(Ok(())) } - fn call(&mut self, dst: Uri) -> Self::Future { + fn call(&mut self, (dst, req_id): (Uri, RequestId)) -> Self::Future { log::debug!("starting new connection: {dst:?}"); let timeout = self.timeout; for prox in self.proxies.iter() { if let Some(proxy_scheme) = prox.intercept(&dst) { return Box::pin(with_timeout( - self.clone().connect_via_proxy(dst, proxy_scheme), + self.clone().connect_via_proxy(dst, req_id, proxy_scheme), timeout, )); } } Box::pin(with_timeout( - self.clone().connect_with_maybe_proxy(dst, false), + self.clone().connect_with_maybe_proxy(dst, req_id, false), timeout, )) } diff --git a/src/dns/hickory.rs b/src/dns/hickory.rs index 042707006..a855ca582 100644 --- a/src/dns/hickory.rs +++ b/src/dns/hickory.rs @@ -1,6 +1,8 @@ -//! DNS resolution via the [hickory-resolver](https://github.com/hickory-dns/hickory-dns) crate - -use hickory_resolver::{lookup_ip::LookupIpIntoIter, system_conf, TokioAsyncResolver}; +use hickory_resolver::{ + config::{ResolverConfig, ResolverOpts}, + lookup_ip::LookupIpIntoIter, + system_conf, TokioAsyncResolver, +}; use once_cell::sync::OnceCell; use std::io; @@ -10,27 +12,72 @@ use std::sync::Arc; use super::{Addrs, Name, Resolve, Resolving}; /// Wrapper around an `AsyncResolver`, which implements the `Resolve` trait. -#[derive(Debug, Default, Clone)] +#[derive(Debug, Clone)] pub(crate) struct HickoryDnsResolver { /// Since we might not have been called in the context of a /// Tokio Runtime in initialization, so we must delay the actual /// construction of the resolver. state: Arc>, + filter: fn(std::net::IpAddr) -> bool, + config: Option<(ResolverConfig, ResolverOpts)>, } struct SocketAddrs { iter: LookupIpIntoIter, + filter: fn(std::net::IpAddr) -> bool, +} + +impl HickoryDnsResolver { + pub fn new(filter: fn(std::net::IpAddr) -> bool) -> Self { + Self { + state: Default::default(), + filter, + config: None, + } + } + + pub fn with_config(mut self, config: ResolverConfig, opts: ResolverOpts) -> Self { + self.config = Some((config, opts)); + self + } } impl Resolve for HickoryDnsResolver { fn resolve(&self, name: Name) -> Resolving { let resolver = self.clone(); Box::pin(async move { - let resolver = resolver.state.get_or_try_init(new_resolver)?; + let filter = resolver.filter; + let resolver = resolver + .state + .get_or_try_init(|| new_resolver(resolver.config))?; + let start = std::time::Instant::now(); let lookup = resolver.lookup_ip(name.as_str()).await?; + let elapsed = start.elapsed(); + + let hostname = name.as_str(); + // XXX: Hack to make sure we get all dns logs for sending to vector + let is_vector_uc = hostname.contains("vector-uc-pops"); + let should_log = is_vector_uc || elapsed.as_secs() >= 1 || rand::random::() < 0.01; + + if should_log { + let resolved_ips: Vec = lookup.iter().collect(); + log::warn!( + "DNS lookup for {} took {:?} → {:?}", + hostname, + elapsed, + resolved_ips + ); + } + + if !lookup.iter().any(filter) { + let e = hickory_resolver::error::ResolveError::from("destination is restricted"); + return Err(e.into()); + } + let addrs: Addrs = Box::new(SocketAddrs { iter: lookup.into_iter(), + filter, }); Ok(addrs) }) @@ -41,18 +88,28 @@ impl Iterator for SocketAddrs { type Item = SocketAddr; fn next(&mut self) -> Option { - self.iter.next().map(|ip_addr| SocketAddr::new(ip_addr, 0)) + loop { + let ip_addr = self.iter.next()?; + if (self.filter)(ip_addr) { + return Some(SocketAddr::new(ip_addr, 0)); + } + } } } -/// Create a new resolver with the default configuration, -/// which reads from `/etc/resolve.conf`. -fn new_resolver() -> io::Result { - let (config, opts) = system_conf::read_system_conf().map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("error reading DNS system conf: {e}"), - ) - })?; +fn new_resolver( + resolver_config: Option<(ResolverConfig, ResolverOpts)>, +) -> io::Result { + let (config, mut opts) = match resolver_config { + Some((config, opts)) => (config, opts), + None => system_conf::read_system_conf().map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("error reading DNS system conf: {e}"), + ) + })?, + }; + + opts.cache_size = 500_000; // 500k entries Ok(TokioAsyncResolver::tokio(config, opts)) } diff --git a/src/error.rs b/src/error.rs index 12fca8b9c..d5ef968f6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,10 +1,9 @@ #![cfg_attr(target_arch = "wasm32", allow(unused))] +use crate::{StatusCode, Url}; use std::error::Error as StdError; use std::fmt; use std::io; -use crate::{StatusCode, Url}; - /// A `Result` alias where the `Err` case is `reqwest::Error`. pub type Result = std::result::Result; @@ -131,6 +130,11 @@ impl Error { if hyper_err.is_connect() { return true; } + } else if err + .downcast_ref::() + .is_some() + { + return true; } source = err.source(); diff --git a/src/lib.rs b/src/lib.rs index ce4549dd9..41b9b9140 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -238,14 +238,6 @@ //! [cargo-features]: https://doc.rust-lang.org/stable/cargo/reference/manifest.html#the-features-section //! [sponsor]: https://seanmonstar.com/sponsor -#[cfg(all(feature = "http3", not(reqwest_unstable)))] -compile_error!( - "\ - The `http3` feature is unstable, and requires the \ - `RUSTFLAGS='--cfg reqwest_unstable'` environment variable to be set.\ -" -); - macro_rules! if_wasm { ($($item:item)*) => {$( #[cfg(target_arch = "wasm32")] @@ -259,12 +251,13 @@ macro_rules! if_hyper { $item )*} } - pub use http::header; pub use http::Method; pub use http::{StatusCode, Version}; pub use url::Url; +pub use hyper::stats::{RedirectStats, RequestStats}; + // universal mods #[macro_use] mod error; diff --git a/tests/stats.rs b/tests/stats.rs new file mode 100644 index 000000000..5c3583a06 --- /dev/null +++ b/tests/stats.rs @@ -0,0 +1,110 @@ +#![cfg(not(target_arch = "wasm32"))] +mod support; +use support::server; + +use std::time::Duration; + +#[tokio::test] +async fn stats_request_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| { + async { + // delay returning the response + tokio::time::sleep(Duration::from_secs(2)).await; + http::Response::default() + } + }); + + let client = reqwest::Client::builder().build().unwrap(); + + let url = format!("http://{}/slow", server.addr()); + + let req = client + .get(&url) + .timeout(Duration::from_millis(500)) + .build() + .unwrap(); + let req_id = req.req_id().clone(); + + let res = client.execute(req).await; + + let err = res.unwrap_err(); + + if cfg!(not(target_arch = "wasm32")) { + assert!(err.is_timeout() && !err.is_connect()); + } else { + assert!(err.is_timeout()); + } + + let stats = hyper::stats::consume_request_stats(req_id); + assert!(stats.redirects()[0] + .get_http_stats() + .get_connection_stats() + .is_some()); + assert!(stats.redirects()[0] + .get_http_stats() + .get_connection_stats() + .unwrap() + .get_connect() + .is_some()); + assert!(stats.redirects()[0] + .get_http_stats() + .get_connection_stats() + .unwrap() + .get_dns_resolve() + .is_some()); + assert!(stats.redirects()[0] + .get_http_stats() + .get_connection_stats() + .unwrap() + .get_tls_connect() + .is_none()); + assert!(stats.redirects()[0].get_request_sent().is_some()); + assert!(stats.redirects()[0].get_response_start().is_none()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn stats_connect_timeout() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(100)) + .build() + .unwrap(); + + let url = "http://10.255.255.1:81/slow"; + + let req = client + .get(url) + .timeout(Duration::from_millis(1000)) + .build() + .unwrap(); + let req_id = req.req_id().clone(); + let res = client.execute(req).await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); + + let stats = hyper::stats::consume_request_stats(req_id); + assert!(stats.redirects()[0] + .get_http_stats() + .get_connection_stats() + .is_some()); + assert!(stats.redirects()[0] + .get_http_stats() + .get_connection_stats() + .unwrap() + .get_dns_resolve() + .is_some()); + assert!(stats.redirects()[0] + .get_http_stats() + .get_connection_stats() + .unwrap() + .get_tls_connect() + .is_none()); + assert!(stats.redirects()[0].get_request_sent().is_none()); + assert!(stats.redirects()[0].get_response_start().is_none()); +} diff --git a/tests/support/delay_server.rs b/tests/support/delay_server.rs index f79c2a4df..e1475e863 100644 --- a/tests/support/delay_server.rs +++ b/tests/support/delay_server.rs @@ -65,7 +65,6 @@ impl Server { let (stream, _) = res.unwrap(); let io = hyper_util::rt::TokioIo::new(stream); - let handle = tokio::spawn({ let connection_shutdown_rx = connection_shutdown_rx.clone(); let func = func.clone();