|
1 | 1 | use futures_util::FutureExt; |
2 | | -use hyper::client::connect::{self, Connect}; |
3 | 2 | #[cfg(feature = "tokio-runtime")] |
4 | | -use hyper::client::HttpConnector; |
5 | | -use rustls::{ClientConfig, Session}; |
| 3 | +use hyper::client::connect::HttpConnector; |
| 4 | +use hyper::{client::connect::Connection, service::Service, Uri}; |
| 5 | +use rustls::ClientConfig; |
6 | 6 | use std::future::Future; |
7 | 7 | use std::pin::Pin; |
8 | 8 | use std::sync::Arc; |
| 9 | +use std::task::{Context, Poll}; |
9 | 10 | use std::{fmt, io}; |
| 11 | +use tokio::io::{AsyncRead, AsyncWrite}; |
10 | 12 | use tokio_rustls::TlsConnector; |
11 | 13 | use webpki::DNSNameRef; |
12 | 14 |
|
13 | 15 | use crate::stream::MaybeHttpsStream; |
14 | 16 |
|
| 17 | +type BoxError = Box<dyn std::error::Error + Send + Sync>; |
| 18 | + |
15 | 19 | /// A Connector for the `https` scheme. |
16 | 20 | #[derive(Clone)] |
17 | 21 | pub struct HttpsConnector<T> { |
@@ -70,59 +74,55 @@ impl<T> From<(T, Arc<ClientConfig>)> for HttpsConnector<T> { |
70 | 74 | } |
71 | 75 | } |
72 | 76 |
|
73 | | -impl<T> Connect for HttpsConnector<T> |
| 77 | +impl<T> Service<Uri> for HttpsConnector<T> |
74 | 78 | where |
75 | | - T: Connect<Error = io::Error>, |
76 | | - T::Transport: 'static, |
77 | | - T::Future: 'static, |
| 79 | + T: Service<Uri>, |
| 80 | + T::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static, |
| 81 | + T::Future: Send + 'static, |
| 82 | + T::Error: Into<BoxError>, |
78 | 83 | { |
79 | | - type Transport = MaybeHttpsStream<T::Transport>; |
80 | | - type Error = io::Error; |
| 84 | + type Response = MaybeHttpsStream<T::Response>; |
| 85 | + type Error = BoxError; |
81 | 86 |
|
82 | 87 | #[allow(clippy::type_complexity)] |
83 | | - type Future = Pin< |
84 | | - Box< |
85 | | - dyn Future< |
86 | | - Output = Result< |
87 | | - (MaybeHttpsStream<T::Transport>, connect::Connected), |
88 | | - io::Error, |
89 | | - >, |
90 | | - > + Send, |
91 | | - >, |
92 | | - >; |
93 | | - |
94 | | - fn connect(&self, dst: connect::Destination) -> Self::Future { |
95 | | - let is_https = dst.scheme() == "https"; |
| 88 | + type Future = |
| 89 | + Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T::Response>, BoxError>> + Send>>; |
| 90 | + |
| 91 | + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 92 | + match self.http.poll_ready(cx) { |
| 93 | + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), |
| 94 | + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), |
| 95 | + Poll::Pending => Poll::Pending, |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + fn call(&mut self, dst: Uri) -> Self::Future { |
| 100 | + let is_https = dst.scheme_str() == Some("https"); |
96 | 101 |
|
97 | 102 | if !is_https { |
98 | | - let connecting_future = self.http.connect(dst); |
| 103 | + let connecting_future = self.http.call(dst); |
99 | 104 |
|
100 | 105 | let f = async move { |
101 | | - let (tcp, conn) = connecting_future.await?; |
| 106 | + let tcp = connecting_future.await.map_err(Into::into)?; |
102 | 107 |
|
103 | | - Ok((MaybeHttpsStream::Http(tcp), conn)) |
| 108 | + Ok(MaybeHttpsStream::Http(tcp)) |
104 | 109 | }; |
105 | 110 | f.boxed() |
106 | 111 | } else { |
107 | 112 | let cfg = self.tls_config.clone(); |
108 | | - let hostname = dst.host().to_string(); |
109 | | - let connecting_future = self.http.connect(dst); |
| 113 | + let hostname = dst.host().unwrap_or_default().to_string(); |
| 114 | + let connecting_future = self.http.call(dst); |
110 | 115 |
|
111 | 116 | let f = async move { |
112 | | - let (tcp, conn) = connecting_future.await?; |
| 117 | + let tcp = connecting_future.await.map_err(Into::into)?; |
113 | 118 | let connector = TlsConnector::from(cfg); |
114 | 119 | let dnsname = DNSNameRef::try_from_ascii_str(&hostname) |
115 | 120 | .map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid dnsname"))?; |
116 | 121 | let tls = connector |
117 | 122 | .connect(dnsname, tcp) |
118 | 123 | .await |
119 | 124 | .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; |
120 | | - let connected = if tls.get_ref().1.get_alpn_protocol() == Some(b"h2") { |
121 | | - conn.negotiated_h2() |
122 | | - } else { |
123 | | - conn |
124 | | - }; |
125 | | - Ok((MaybeHttpsStream::Https(tls), connected)) |
| 125 | + Ok(MaybeHttpsStream::Https(tls)) |
126 | 126 | }; |
127 | 127 | f.boxed() |
128 | 128 | } |
|
0 commit comments