Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ Async PostgreSQL client in Nim.
- Unix socket connection
- Multi-host failover
- Target session attributes (any, read-write, read-only, primary, standby, prefer-standby)
- `load_balance_hosts=random` shuffles the configured host list per connection to
spread a pool across replicas (reorders the multi-host list only, not multiple
addresses behind a single DNS name)

### Queries & Statements
- `sql` macro — compile-time `{expr}` placeholder extraction with automatic parameterization
Expand Down
11 changes: 7 additions & 4 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@
## `nextMessage`, `recvMessage`,
## `sendMsg`), TCP keepalive,
## `closeTransport`, notification/notice
## dispatch and `isConnected` /
## `socketHasFin`.
## dispatch, `isConnected` /
## `socketHasFin`, and the `getHosts`
## host helper.
## - `pg_connection/ssl` — SSL negotiation (`negotiateSSL`) for
## chronos+BearSSL and asyncdispatch+OpenSSL.
## - `pg_connection/cache` — client-side prepared-statement LRU.
## - `pg_connection/simple_query` — simple-query / simple-exec / ping,
## `cancel` / `invalidateOnTimeout`,
## `checkSessionAttrs`, `quoteIdentifier`,
## `QueryResult` helpers.
## - `pg_connection/lifecycle` — `connect` / `connectToHost` / `close`
## and the SCRAM/require_auth helpers.
## - `pg_connection/lifecycle` — `connect` / `connectToHost` / `close`,
## `orderedHosts` (load-balanced host
## ordering) and the SCRAM/require_auth
## helpers.
## - `pg_connection/notify` — LISTEN/NOTIFY pump, `waitNotification`,
## `reconnectInPlace`.
## - `pg_connection/type_lookup` — `lookupTypeOids` generic helper to
Expand Down
13 changes: 13 additions & 0 deletions async_postgres/pg_connection/dsn.nim
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ proc parseTargetSessionAttrs*(s: string): TargetSessionAttrs =
else:
raise newException(PgError, "Invalid target_session_attrs: " & s)

proc parseLoadBalanceHosts*(s: string): LoadBalanceHosts =
case s
of "disable":
lbhDisable
of "random":
lbhRandom
else:
raise newException(PgError, "Invalid load_balance_hosts: " & s)

proc parsePort*(s: string): int =
try:
result = parseInt(s)
Expand Down Expand Up @@ -207,6 +216,8 @@ proc applyParam*(result: var ConnConfig, key, val: string) =
raise newException(PgError, "keepalives_count must be non-negative: " & val)
of "target_session_attrs":
result.targetSessionAttrs = parseTargetSessionAttrs(val)
of "load_balance_hosts":
result.loadBalanceHosts = parseLoadBalanceHosts(val)
of "max_message_size":
try:
result.maxMessageSize = parseInt(val)
Expand Down Expand Up @@ -451,6 +462,7 @@ proc initConnConfig*(
keepAliveCount = 0,
hosts: seq[HostEntry] = @[],
targetSessionAttrs = tsaAny,
loadBalanceHosts = lbhDisable,
requireAuth: set[AuthMethod] = {},
extraParams: seq[(string, string)] = @[],
maxMessageSize = 0,
Expand All @@ -475,6 +487,7 @@ proc initConnConfig*(
keepAliveCount: keepAliveCount,
hosts: hosts,
targetSessionAttrs: targetSessionAttrs,
loadBalanceHosts: loadBalanceHosts,
requireAuth: requireAuth,
extraParams: extraParams,
maxMessageSize: maxMessageSize,
Expand Down
49 changes: 43 additions & 6 deletions async_postgres/pg_connection/lifecycle.nim
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
## auth loop → ParameterStatus/BackendKeyData → extension OID discovery.
## - `connect` — the public entry: multi-host failover, `targetSessionAttrs`
## handling, optional `connectTimeout`, top-level connect tracing.
## - `orderedHosts` — the host list to try, reordered per `loadBalanceHosts`
## (`lbhRandom` shuffles it once per connection).
## - `close` — idempotent teardown: stop background listen pump, send
## `Terminate`, drop transport handles.
##
## Imports `simple_query` for `checkSessionAttrs` (failover probe).
## Re-exported through `pg_connection.nim`.

import std/[options, strutils, tables]
import std/[options, random, strutils, sysrand, tables]

import ../[async_backend, pg_errors, pg_protocol, pg_auth]
import pkg/nimcrypto/utils as ncutils
Expand Down Expand Up @@ -505,14 +507,47 @@ proc attemptHostTimed(
else:
return await attemptHost(config, entry, attrs).wait(config.connectTimeout)

proc orderedHosts*(config: ConnConfig): seq[HostEntry] =
## `getHosts`, reordered per `config.loadBalanceHosts`. With `lbhRandom`
## (libpq `load_balance_hosts=random`) the configured host list is shuffled
## once per call, so a pool of connections spreads across hosts. With
## `lbhDisable` (default) the configured order is preserved. Only the
## multi-host list is reordered — multiple addresses behind a single host
## name are not shuffled (`attemptHost` dials the first resolved address).
##
## The shuffle is seeded from the OS secure random source (`std/sysrand`)
## into a local `std/random` RNG, so it is safe under `--threads:on`, does
## not require the application to call `randomize()`, and keeps no
## module-level state.
result = config.getHosts()
if config.loadBalanceHosts == lbhRandom and result.len > 1:
let bytes =
try:
urandom(8)
except OSError as e:
raise newException(
PgConnectionError,
"Failed to read entropy for load_balance_hosts=random: " & e.msg,
)
var seed: uint64
for b in bytes:
seed = (seed shl 8) or b
# initRand expects a signed seed; use a bit-preserving cast so any
# 64-bit random value is valid, not just values <= int64.high.
var rng = initRand(cast[int64](seed))
rng.shuffle(result)

proc connect*(config: ConnConfig): Future[PgConnection] =
## Establish a new connection to a PostgreSQL server.
## Supports multi-host failover: tries each host in order.
## Supports multi-host failover: tries each host in order, or in a random
## order when `loadBalanceHosts == lbhRandom` (libpq `load_balance_hosts`).
## Respects `targetSessionAttrs` to select the appropriate server type.
## `connectTimeout` is applied per host (libpq semantics): each host attempt
## gets its own budget, so the total wait may reach `connectTimeout * hosts`.
proc perform(): Future[PgConnection] {.async.} =
let hosts = config.getHosts()
proc perform(hosts: seq[HostEntry]): Future[PgConnection] {.async.} =
# `hosts` is already ordered by the caller (shuffled under lbhRandom), so
# both the preferStandby two-pass loop and the single-pass loop below share
# one order.
var errors: seq[string]
# With a single host there is no failover. Preserve the contract that its
# `connectTimeout` surfaces as a raw `AsyncTimeoutError` (callers and the
Expand Down Expand Up @@ -565,7 +600,9 @@ proc connect*(config: ConnConfig): Future[PgConnection] =
)

proc wrapped(): Future[PgConnection] {.async.} =
let hosts = config.getHosts()
# Compute the ordered host list once so the trace and the actual connection
# attempts see the same order under lbhRandom.
let hosts = config.orderedHosts()
var conn: PgConnection
withTracing(
config.tracer,
Expand All @@ -577,7 +614,7 @@ proc connect*(config: ConnConfig): Future[PgConnection] =
):
# `connectTimeout` is enforced per host inside `attemptHostTimed`, so
# `perform()` is awaited directly here — no outer total-timeout wrapper.
conn = await perform()
conn = await perform(hosts)
conn.tracer = config.tracer
return conn

Expand Down
14 changes: 14 additions & 0 deletions async_postgres/pg_connection/types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ type
tsaStandby ## Standby server
tsaPreferStandby ## Prefer standby, fall back to any

LoadBalanceHosts* = enum
## Host connection ordering for a multi-host connection (libpq compatible).
lbhDisable ## Try hosts in the configured order (default)
lbhRandom
## Shuffle the configured host list once per connection so a pool of
## connections spreads across hosts (e.g. read replicas). Only the
## multi-host list is reordered — multiple addresses behind a single host
## name are not shuffled. See `orderedHosts` for the seeding and
## thread-safety details.

HostEntry* = object ## A single host:port entry for multi-host connection.
host*: string ## Host name (or Unix socket dir); used for SSL verification
hostaddr*: string
Expand Down Expand Up @@ -127,6 +137,10 @@ type
keepAliveCount*: int ## Number of probes before giving up (0 = OS default)
hosts*: seq[HostEntry] ## Multiple hosts for failover (empty = use host/port)
targetSessionAttrs*: TargetSessionAttrs ## Target server type (default tsaAny)
loadBalanceHosts*: LoadBalanceHosts
## Host ordering for multi-host connections (libpq `load_balance_hosts`);
## see `LoadBalanceHosts`. `lbhDisable` (default) preserves the configured
## order.
extraParams*: seq[(string, string)] ## Additional startup parameters
maxMessageSize*: int
## Upper bound (in bytes) on a single backend message including
Expand Down
83 changes: 82 additions & 1 deletion tests/test_dsn.nim
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import std/unittest
import std/[random, unittest]

import ../async_postgres/[async_backend, pg_connection]

Expand Down Expand Up @@ -476,6 +476,87 @@ suite "parseDsn":
expect PgError:
discard parseDsn("postgresql://h/db?target_session_attrs=bogus")

test "load_balance_hosts all values":
check parseDsn("postgresql://h/db?load_balance_hosts=disable").loadBalanceHosts ==
lbhDisable
check parseDsn("postgresql://h/db?load_balance_hosts=random").loadBalanceHosts ==
lbhRandom

test "load_balance_hosts defaults to disable":
check parseDsn("postgresql://h1,h2/db").loadBalanceHosts == lbhDisable

test "load_balance_hosts keyword=value form":
check parseDsn("host=h1,h2 load_balance_hosts=random").loadBalanceHosts == lbhRandom

test "error: invalid load_balance_hosts":
expect PgError:
discard parseDsn("postgresql://h/db?load_balance_hosts=bogus")

test "orderedHosts: lbhDisable preserves configured order":
let cfg = parseDsn("postgresql://h1,h2,h3/db?load_balance_hosts=disable")
check cfg.orderedHosts() == cfg.getHosts()

test "orderedHosts: lbhRandom is a permutation and reorders":
let cfg = parseDsn("postgresql://h1,h2,h3,h4,h5/db?load_balance_hosts=random")
let base = cfg.getHosts()
var sawReorder = false
for _ in 0 ..< 100:
let o = cfg.orderedHosts()
check o.len == base.len
for h in base: # same multiset of hosts, just reordered
check h in o
if o != base:
sawReorder = true
check sawReorder

test "orderedHosts: single host is never shuffled under lbhRandom":
let cfg = parseDsn("postgresql://only/db?load_balance_hosts=random")
check cfg.orderedHosts() == cfg.getHosts()

test "orderedHosts: exactly two hosts are eligible for shuffling":
# Boundary of the `result.len > 1` guard: two hosts (the smallest pool that
# can be balanced) must be reorderable, so a `> 2` off-by-one regression
# that silently disabled load balancing for 2-host pools would be caught.
let cfg = parseDsn("postgresql://h1,h2/db?load_balance_hosts=random")
let base = cfg.getHosts()
var sawReorder = false
for _ in 0 ..< 100:
let o = cfg.orderedHosts()
check o.len == 2
for h in base:
check h in o
if o != base:
sawReorder = true
check sawReorder

test "orderedHosts: lbhRandom is independent of the global RNG seed":
# The shuffle must come from a local RNG seeded by the OS entropy source,
# not std/random's global RNG. Pinning the global RNG to the same seed
# before every call must NOT make the order reproducible: a buggy
# implementation that shuffled via the global RNG would return an identical
# order each time the global seed was reset to the same value.
let cfg = parseDsn("postgresql://h1,h2,h3,h4,h5/db?load_balance_hosts=random")
let base = cfg.getHosts()
randomize(123456789)
let first = cfg.orderedHosts()
var sawDifferent = false
for _ in 0 ..< 50:
randomize(123456789) # reset the global RNG to an identical state each time
let o = cfg.orderedHosts()
check o.len == base.len
for h in base:
check h in o
if o != first:
sawDifferent = true
# With a global-RNG implementation every `o` would equal `first`;
# independence means we still observe variation despite the fixed seed.
check sawDifferent
randomize() # restore a non-deterministic global RNG (avoid cross-test leak)

# The invariant that connect() shuffles once and reuses that single order for
# both the connect-start trace and every host attempt (including both
# prefer-standby passes) is covered end-to-end in tests/test_tracing.nim.

test "single host backward compat - hosts has one entry":
let cfg = parseDsn("postgresql://myhost:5433/db")
check cfg.hosts.len == 1
Expand Down
95 changes: 95 additions & 0 deletions tests/test_tracing.nim
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,101 @@ suite "Tracing: connect":

waitFor t()

test "lbhRandom connect: connect-start trace reflects the real attempt order":
# Real end-to-end check (replaces the earlier tautological proxy): connect()
# must compute the shuffled order once and use that single order for BOTH the
# connect-start trace and the actual host attempts. All hosts dial a closed
# loopback port so the attempt loop runs without a live server; the
# aggregated error message begins with the first host actually dialled.
proc t() {.async.} =
let log = newTraceLog()
let tracer = buildTracer(log)
var cfg = parseDsn(
"host=lbh1,lbh2,lbh3,lbh4,lbh5 " &
"hostaddr=127.0.0.1,127.0.0.1,127.0.0.1,127.0.0.1,127.0.0.1 " &
"port=19999 dbname=db sslmode=disable load_balance_hosts=random"
)
cfg.tracer = tracer

var errMsg = ""
try:
let conn = await connect(cfg)
await conn.close()
except CatchableError as e:
errMsg = e.msg

doAssert errMsg.len > 0, "expected the multi-host connect to fail"
doAssert log.connectStarts.len == 1
let traced = log.connectStarts[0].hosts
doAssert traced.len == 5

# The tracer sees a permutation of all configured hosts.
for name in ["lbh1", "lbh2", "lbh3", "lbh4", "lbh5"]:
var found = false
for h in traced:
if h.host == name:
found = true
doAssert found, "trace missing host " & name

# The first host actually dialled (the aggregated error starts with it)
# must equal the first host in the traced order — i.e. perform() uses the
# same single shuffled order that was reported to the tracer, not a fresh
# shuffle. Only the leading host name is parsed, so this stays robust
# across backends (asyncdispatch interleaves multi-line tracebacks into
# later parts of the message). "127.0.0.1" never collides with "lbh".
let idx = errMsg.find("lbh")
doAssert idx >= 0
doAssert errMsg[idx ..< idx + 4] == traced[0].host

waitFor t()

test "lbhRandom + prefer-standby reuse one order for both attempt passes":
# connect() runs the prefer-standby passes (standby-first, then any) over the
# single shuffled list. On chronos the aggregated error lists every attempt
# in a clean message, so both passes must spell out the same traced order.
# (asyncdispatch embeds async tracebacks in the message, so the full order is
# asserted only on the clean-message backend.)
when hasChronos:
proc t() {.async.} =
let log = newTraceLog()
let tracer = buildTracer(log)
var cfg = parseDsn(
"host=lbh1,lbh2,lbh3,lbh4,lbh5 " &
"hostaddr=127.0.0.1,127.0.0.1,127.0.0.1,127.0.0.1,127.0.0.1 " &
"port=19999 dbname=db sslmode=disable " &
"load_balance_hosts=random target_session_attrs=prefer-standby"
)
cfg.tracer = tracer

var errMsg = ""
try:
let conn = await connect(cfg)
await conn.close()
except CatchableError as e:
errMsg = e.msg

doAssert log.connectStarts.len == 1
var traced: seq[string]
for h in log.connectStarts[0].hosts:
traced.add(h.host)
doAssert traced.len == 5

var attempted: seq[string]
var i = 0
while true:
let idx = errMsg.find("lbh", i)
if idx < 0:
break
attempted.add(errMsg[idx ..< idx + 4]) # "lbhN"
i = idx + 4

doAssert attempted.len == 10,
"expected 5 hosts x 2 prefer-standby passes, got " & $attempted.len
doAssert attempted[0 ..< 5] == traced # standby-first pass
doAssert attempted[5 ..< 10] == traced # any pass — same single order

waitFor t()

suite "Tracing: exec":
test "onQueryStart(isExec=true) and onQueryEnd with commandTag":
proc t() {.async.} =
Expand Down