@@ -3,9 +3,11 @@ use async_std::os::unix::net::UnixStream;
33
44use async_std:: io:: { self , Read , Write } ;
55use async_std:: net:: TcpStream ;
6+ use std:: future:: Future ;
67use std:: mem:: MaybeUninit ;
78use std:: pin:: Pin ;
89use std:: task:: { Context , Poll } ;
10+ use std:: time:: Duration ;
911use tokio:: io:: { AsyncRead , AsyncWrite } ;
1012use tokio_postgres:: config:: { Config , Host } ;
1113
@@ -81,24 +83,39 @@ pub async fn connect_socket(config: &Config) -> io::Result<Socket> {
8183 let mut ports = config. get_ports ( ) . iter ( ) . cloned ( ) ;
8284 for host in config. get_hosts ( ) {
8385 let port = ports. next ( ) . unwrap_or ( DEFAULT_PORT ) ;
86+ let dur = config. get_connect_timeout ( ) ;
8487 let result = match host {
8588 #[ cfg( unix) ]
8689 Host :: Unix ( path) => {
8790 let sock = path. join ( format ! ( ".s.PGSQL.{}" , port) ) ;
88- UnixStream :: connect ( sock) . await . map ( Into :: into)
91+ let fut = UnixStream :: connect ( sock) ;
92+ timeout ( dur, fut) . await . map ( Into :: into)
93+ }
94+ Host :: Tcp ( tcp) => {
95+ let fut = TcpStream :: connect ( ( tcp. as_str ( ) , port) ) ;
96+ timeout ( dur, fut) . await . map ( Into :: into)
8997 }
90- Host :: Tcp ( tcp) => TcpStream :: connect ( ( tcp. as_str ( ) , port) )
91- . await
92- . map ( Into :: into) ,
9398 #[ cfg( not( unix) ) ]
9499 Host :: Unix ( _) => {
95100 io:: Error :: new ( io:: ErrorKind :: Other , "unix domain socket is unsupported" )
96101 }
97102 } ;
103+
98104 match result {
99105 Err ( err) => error = err,
100106 stream => return stream,
101107 }
102108 }
103109 Err ( error)
104110}
111+
112+ async fn timeout < F , T > ( dur : Option < & Duration > , fut : F ) -> io:: Result < T >
113+ where
114+ F : Future < Output = io:: Result < T > > ,
115+ {
116+ if let Some ( timeout) = dur {
117+ io:: timeout ( timeout. clone ( ) , fut) . await
118+ } else {
119+ fut. await
120+ }
121+ }
0 commit comments