11use std:: {
2+ convert:: TryFrom ,
23 future:: Future ,
34 io,
45 pin:: Pin ,
@@ -8,11 +9,10 @@ use std::{
89
910use futures:: future:: { FutureExt , TryFutureExt } ;
1011use ring:: digest;
11- use rustls:: { ClientConfig , Session } ;
12+ use rustls:: { ClientConfig , ServerName } ;
1213use tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf } ;
1314use tokio_postgres:: tls:: { ChannelBinding , MakeTlsConnect , TlsConnect } ;
1415use tokio_rustls:: { client:: TlsStream , TlsConnector } ;
15- use webpki:: { DNSName , DNSNameRef } ;
1616
1717#[ derive( Clone ) ]
1818pub struct MakeRustlsConnect {
@@ -36,19 +36,21 @@ where
3636 type Error = io:: Error ;
3737
3838 fn make_tls_connect ( & mut self , hostname : & str ) -> io:: Result < RustlsConnect > {
39- DNSNameRef :: try_from_ascii_str ( hostname)
40- . map ( |dns_name| RustlsConnect ( Some ( RustlsConnectData {
41- hostname : dns_name. to_owned ( ) ,
42- connector : Arc :: clone ( & self . config ) . into ( ) ,
43- } ) ) )
39+ ServerName :: try_from ( hostname)
40+ . map ( |dns_name| {
41+ RustlsConnect ( Some ( RustlsConnectData {
42+ hostname : dns_name,
43+ connector : Arc :: clone ( & self . config ) . into ( ) ,
44+ } ) )
45+ } )
4446 . or ( Ok ( RustlsConnect ( None ) ) )
4547 }
4648}
4749
4850pub struct RustlsConnect ( Option < RustlsConnectData > ) ;
4951
5052struct RustlsConnectData {
51- hostname : DNSName ,
53+ hostname : ServerName ,
5254 connector : TlsConnector ,
5355}
5456
@@ -63,10 +65,11 @@ where
6365 fn connect ( self , stream : S ) -> Self :: Future {
6466 match self . 0 {
6567 None => Box :: pin ( core:: future:: ready ( Err ( io:: ErrorKind :: InvalidInput . into ( ) ) ) ) ,
66- Some ( c) => c. connector
67- . connect ( c. hostname . as_ref ( ) , stream)
68+ Some ( c) => c
69+ . connector
70+ . connect ( c. hostname , stream)
6871 . map_ok ( |s| RustlsStream ( Box :: pin ( s) ) )
69- . boxed ( )
72+ . boxed ( ) ,
7073 }
7174 }
7275}
7982{
8083 fn channel_binding ( & self ) -> ChannelBinding {
8184 let ( _, session) = self . 0 . get_ref ( ) ;
82- match session. get_peer_certificates ( ) {
83- Some ( certs) if certs. len ( ) > 0 => {
85+ match session. peer_certificates ( ) {
86+ Some ( certs) if ! certs. is_empty ( ) => {
8487 let sha256 = digest:: digest ( & digest:: SHA256 , certs[ 0 ] . as_ref ( ) ) ;
8588 ChannelBinding :: tls_server_end_point ( sha256. as_ref ( ) . into ( ) )
8689 }
@@ -100,7 +103,6 @@ where
100103 ) -> Poll < tokio:: io:: Result < ( ) > > {
101104 self . 0 . as_mut ( ) . poll_read ( cx, buf)
102105 }
103-
104106}
105107
106108impl < S > AsyncWrite for RustlsStream < S >
@@ -122,7 +124,6 @@ where
122124 fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut Context ) -> Poll < tokio:: io:: Result < ( ) > > {
123125 self . 0 . as_mut ( ) . poll_shutdown ( cx)
124126 }
125-
126127}
127128
128129#[ cfg( test) ]
@@ -133,12 +134,17 @@ mod tests {
133134 async fn it_works ( ) {
134135 env_logger:: builder ( ) . is_test ( true ) . try_init ( ) . unwrap ( ) ;
135136
136- let config = rustls:: ClientConfig :: new ( ) ;
137+ let config = rustls:: ClientConfig :: builder ( )
138+ . with_safe_defaults ( )
139+ . with_root_certificates ( rustls:: RootCertStore :: empty ( ) )
140+ . with_no_client_auth ( ) ;
137141 let tls = super :: MakeRustlsConnect :: new ( config) ;
138- let ( client, conn) =
139- tokio_postgres:: connect ( "sslmode=require host=localhost port=5432 user=postgres" , tls)
140- . await
141- . expect ( "connect" ) ;
142+ let ( client, conn) = tokio_postgres:: connect (
143+ "sslmode=require host=localhost port=5432 user=postgres" ,
144+ tls,
145+ )
146+ . await
147+ . expect ( "connect" ) ;
142148 tokio:: spawn ( conn. map_err ( |e| panic ! ( "{:?}" , e) ) ) ;
143149 let stmt = client. prepare ( "SELECT 1" ) . await . expect ( "prepare" ) ;
144150 let _ = client. query ( & stmt, & [ ] ) . await . expect ( "query" ) ;
0 commit comments