1+ #![ feature( type_alias_impl_trait) ]
2+
13use std:: {
24 io,
5+ future:: Future ,
6+ mem:: MaybeUninit ,
7+ pin:: Pin ,
38 sync:: Arc ,
9+ task:: { Context , Poll } ,
410} ;
511
6- use futures:: Future ;
12+ use bytes:: { Buf , BufMut } ;
13+ use futures:: future:: TryFutureExt ;
714use rustls:: ClientConfig ;
8- use tokio_io :: { AsyncRead , AsyncWrite } ;
15+ use tokio :: io :: { AsyncRead , AsyncWrite } ;
916use tokio_postgres:: tls:: { ChannelBinding , MakeTlsConnect , TlsConnect } ;
1017use tokio_rustls:: { client:: TlsStream , TlsConnector } ;
1118use webpki:: { DNSName , DNSNameRef } ;
@@ -23,13 +30,13 @@ impl MakeRustlsConnect {
2330
2431impl < S > MakeTlsConnect < S > for MakeRustlsConnect
2532where
26- S : AsyncRead + AsyncWrite + Send + ' static
33+ S : AsyncRead + AsyncWrite + Unpin ,
2734{
28- type Stream = TlsStream < S > ;
35+ type Stream = RustlsStream < S > ;
2936 type TlsConnect = RustlsConnect ;
30- type Error = io:: Error ;
37+ type Error = std :: io:: Error ;
3138
32- fn make_tls_connect ( & mut self , hostname : & str ) -> Result < RustlsConnect , Self :: Error > {
39+ fn make_tls_connect ( & mut self , hostname : & str ) -> std :: io :: Result < RustlsConnect > {
3340 DNSNameRef :: try_from_ascii_str ( hostname)
3441 . map ( |dns_name| RustlsConnect {
3542 hostname : dns_name. to_owned ( ) ,
@@ -46,44 +53,84 @@ pub struct RustlsConnect {
4653
4754impl < S > TlsConnect < S > for RustlsConnect
4855where
49- S : AsyncRead + AsyncWrite + Send + ' static
56+ S : AsyncRead + AsyncWrite + Unpin ,
5057{
51- type Stream = TlsStream < S > ;
52- type Error = io:: Error ;
53- type Future = Box < dyn Future < Item = ( Self :: Stream , ChannelBinding ) , Error = Self :: Error > + Send > ;
58+ type Stream = RustlsStream < S > ;
59+ type Error = std :: io:: Error ;
60+ type Future = impl Future < Output = std :: io :: Result < RustlsStream < S > > > ;
5461
5562 fn connect ( self , stream : S ) -> Self :: Future {
56- Box :: new (
57- self . connector . connect ( self . hostname . as_ref ( ) , stream)
58- . map ( |s| ( s, ChannelBinding :: none ( ) ) ) // TODO
59- )
63+ self . connector . connect ( self . hostname . as_ref ( ) , stream)
64+ . map_ok ( |s| RustlsStream ( Box :: pin ( s) ) )
65+ }
66+ }
67+
68+ pub struct RustlsStream < S > ( Pin < Box < TlsStream < S > > > ) ;
69+
70+ impl < S > tokio_postgres:: tls:: TlsStream for RustlsStream < S >
71+ where
72+ S : AsyncRead + AsyncWrite + Unpin ,
73+ {
74+ fn channel_binding ( & self ) -> ChannelBinding {
75+ ChannelBinding :: none ( ) // TODO
76+ }
77+ }
78+
79+ impl < S > AsyncRead for RustlsStream < S >
80+ where
81+ S : AsyncRead + AsyncWrite + Unpin ,
82+ {
83+ fn poll_read ( mut self : Pin < & mut Self > , cx : & mut Context , buf : & mut [ u8 ] ) -> Poll < tokio:: io:: Result < usize > > {
84+ self . 0 . as_mut ( ) . poll_read ( cx, buf)
85+ }
86+
87+ unsafe fn prepare_uninitialized_buffer ( & self , buf : & mut [ MaybeUninit < u8 > ] ) -> bool {
88+ self . 0 . prepare_uninitialized_buffer ( buf)
89+ }
90+
91+ fn poll_read_buf < B : BufMut > ( mut self : Pin < & mut Self > , cx : & mut Context , buf : & mut B ) -> Poll < tokio:: io:: Result < usize > >
92+ where
93+ Self : Sized ,
94+ {
95+ self . 0 . as_mut ( ) . poll_read_buf ( cx, buf)
96+ }
97+ }
98+
99+ impl < S > AsyncWrite for RustlsStream < S >
100+ where
101+ S : AsyncRead + AsyncWrite + Unpin ,
102+ {
103+ fn poll_write ( mut self : Pin < & mut Self > , cx : & mut Context , buf : & [ u8 ] ) -> Poll < tokio:: io:: Result < usize > > {
104+ self . 0 . as_mut ( ) . poll_write ( cx, buf)
105+ }
106+
107+ fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context ) -> Poll < tokio:: io:: Result < ( ) > > {
108+ self . 0 . as_mut ( ) . poll_flush ( cx)
109+ }
110+
111+ fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut Context ) -> Poll < tokio:: io:: Result < ( ) > > {
112+ self . 0 . as_mut ( ) . poll_shutdown ( cx)
113+ }
114+
115+ fn poll_write_buf < B : Buf > ( mut self : Pin < & mut Self > , cx : & mut Context , buf : & mut B ) -> Poll < tokio:: io:: Result < usize > >
116+ where
117+ Self : Sized ,
118+ {
119+ self . 0 . as_mut ( ) . poll_write_buf ( cx, buf)
60120 }
61121}
62122
63123#[ cfg( test) ]
64124mod tests {
65- use futures:: { Future , Stream } ;
66- use tokio:: runtime:: current_thread;
125+ use futures:: future:: TryFutureExt ;
67126
68- #[ test]
69- fn it_works ( ) {
127+ #[ tokio :: test]
128+ async fn it_works ( ) {
70129 let config = rustls:: ClientConfig :: new ( ) ;
71130 let tls = super :: MakeRustlsConnect :: new ( config) ;
72- current_thread:: block_on_all (
73- tokio_postgres:: connect ( "sslmode=require host=localhost user=postgres" , tls)
74- . map ( |( client, connection) | {
75- tokio:: spawn (
76- connection. map_err ( |e| panic ! ( "{:?}" , e) )
77- ) ;
78- client
79- } )
80- . and_then ( |mut client| {
81- client. prepare ( "SELECT 1" )
82- . map ( |s| ( client, s) )
83- } )
84- . and_then ( |( mut client, statement) | {
85- client. query ( & statement, & [ ] ) . collect ( )
86- } )
87- ) . unwrap ( ) ;
131+ let ( client, conn) = tokio_postgres:: connect ( "sslmode=require host=localhost user=postgres" , tls) . await . unwrap ( ) ;
132+ tokio:: spawn ( conn. map_err ( |e| panic ! ( "{:?}" , e) ) ) ;
133+ let stmt = client. prepare ( "SELECT 1" ) . await . unwrap ( ) ;
134+ let _ = client. query ( & stmt, & [ ] ) . await . unwrap ( ) ;
88135 }
89136}
0 commit comments