Skip to content

Commit fea8a9f

Browse files
committed
add connect_raw
1 parent 046b5b2 commit fea8a9f

File tree

2 files changed

+75
-35
lines changed

2 files changed

+75
-35
lines changed

src/lib.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,39 @@ where
6262
T: TlsConnect<AsyncStream>,
6363
{
6464
let stream = connect_stream(&config).await?;
65+
connect_raw(stream, config, tls).await
66+
}
67+
68+
/// Connect to postgres server with a tls connector.
69+
///
70+
/// ```rust
71+
/// use async_postgres::connect;
72+
///
73+
/// use std::error::Error;
74+
/// use async_std::task::spawn;
75+
///
76+
/// async fn play() -> Result<(), Box<dyn Error>> {
77+
/// let url = "host=localhost user=postgres";
78+
/// let (client, conn) = connect(url.parse()?).await?;
79+
/// spawn(conn);
80+
/// let row = client.query_one("SELECT * FROM user WHERE id=$1", &[&0]).await?;
81+
/// let value: &str = row.get(0);
82+
/// println!("value: {}", value);
83+
/// Ok(())
84+
/// }
85+
/// ```
86+
#[inline]
87+
pub async fn connect_raw<S, T>(
88+
stream: S,
89+
config: Config,
90+
tls: T,
91+
) -> io::Result<(Client, Connection<AsyncStream, T::Stream>)>
92+
where
93+
S: Into<AsyncStream>,
94+
T: TlsConnect<AsyncStream>,
95+
{
6596
config
66-
.connect_raw(stream, tls)
97+
.connect_raw(stream.into(), tls)
6798
.await
6899
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
69100
}

src/stream.rs

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#[cfg(unix)]
2+
use async_std::os::unix::net::UnixStream;
3+
14
use async_std::io::{self, Read, Write};
25
use async_std::net::TcpStream;
36
use std::mem::MaybeUninit;
@@ -6,11 +9,25 @@ use std::task::{Context, Poll};
69
use tokio::io::{AsyncRead, AsyncWrite};
710
use tokio_postgres::config::{Config, Host};
811

9-
/// Default port of postgres.
12+
/// Default socket port of postgres.
1013
const DEFAULT_PORT: u16 = 5432;
1114

12-
/// A wrapper for async_std::net::TcpStream, implementing tokio::io::{AsyncRead, AsyncWrite}.
13-
pub struct AsyncStream(TcpStream);
15+
/// A alias for 'static + Unpin + Send + Read + Write
16+
pub trait AsyncReadWriter: 'static + Unpin + Send + Read + Write {}
17+
18+
impl<T> AsyncReadWriter for T where T: 'static + Unpin + Send + Read + Write {}
19+
20+
/// A adaptor between futures::io::{AsyncRead, AsyncWrite} and tokio::io::{AsyncRead, AsyncWrite}.
21+
pub struct AsyncStream(Box<dyn AsyncReadWriter>);
22+
23+
impl<T> From<T> for AsyncStream
24+
where
25+
T: AsyncReadWriter,
26+
{
27+
fn from(stream: T) -> Self {
28+
Self(Box::new(stream))
29+
}
30+
}
1431

1532
impl AsyncRead for AsyncStream {
1633
#[inline]
@@ -56,39 +73,31 @@ impl AsyncWrite for AsyncStream {
5673
}
5774

5875
/// Establish connection to postgres server by AsyncStream.
76+
///
77+
///
5978
#[inline]
6079
pub async fn connect_stream(config: &Config) -> io::Result<AsyncStream> {
61-
let host = try_tcp_host(&config)?;
62-
let port = config
63-
.get_ports()
64-
.iter()
65-
.copied()
66-
.next()
67-
.unwrap_or(DEFAULT_PORT);
68-
69-
let tcp_stream = TcpStream::connect((host, port)).await?;
70-
Ok(AsyncStream(tcp_stream))
71-
}
72-
73-
/// Try to get TCP hostname from postgres config.
74-
#[inline]
75-
fn try_tcp_host(config: &Config) -> io::Result<&str> {
76-
match config
77-
.get_hosts()
78-
.iter()
79-
.filter_map(|host| {
80-
if let Host::Tcp(value) = host {
81-
Some(value)
82-
} else {
83-
None
80+
let mut error = io::Error::new(io::ErrorKind::Other, "host missing");
81+
let mut ports = config.get_ports().iter().cloned();
82+
for host in config.get_hosts() {
83+
let result = match host {
84+
#[cfg(unix)]
85+
Host::Unix(path) => UnixStream::connect(path).await.map(Into::into),
86+
Host::Tcp(tcp) => {
87+
let port = ports.next().unwrap_or(DEFAULT_PORT);
88+
TcpStream::connect((tcp.as_str(), port))
89+
.await
90+
.map(Into::into)
91+
}
92+
#[cfg(not(unix))]
93+
Host::Unix(_) => {
94+
io::Error::new(io::ErrorKind::Other, "unix domain socket is unsupported")
8495
}
85-
})
86-
.next()
87-
{
88-
Some(host) => Ok(host),
89-
None => Err(io::Error::new(
90-
io::ErrorKind::Other,
91-
"At least one tcp hostname is required",
92-
)),
96+
};
97+
match result {
98+
Err(err) => error = err,
99+
stream => return stream,
100+
}
93101
}
102+
Err(error)
94103
}

0 commit comments

Comments
 (0)