Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 125 additions & 36 deletions async_postgres/pg_protocol.nim
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,6 @@ type
data: RowData
rowIdx: int32

func initRow*(data: RowData, rowIdx: int32): Row {.inline.} =
## Create a Row view into the given RowData at the specified row index.
Row(data: data, rowIdx: rowIdx)

func data*(row: Row): RowData {.inline.} = ## The underlying RowData buffer.
row.data

func rowIdx*(row: Row): int32 {.inline.} = ## The row index within the RowData buffer.
row.rowIdx

const
syncMsg* = [byte('S'), 0'u8, 0'u8, 0'u8, 4'u8] ## Pre-built Sync message bytes.
flushMsg* = [byte('H'), 0'u8, 0'u8, 0'u8, 4'u8] ## Pre-built Flush message bytes.
Expand Down Expand Up @@ -289,12 +279,40 @@ const
pgCopyBinaryTrailer*: array[2, byte] = [0xFF'u8, 0xFF'u8]
## PGCOPY binary format trailer (int16(-1) sentinel).

maxInt16Count = int(high(int16))
## Maximum element count encodable in a wire Int16 count field (32767).
## Parameter-type, format-code, and parameter-value counts in Parse/Bind
## messages are all Int16, so they cannot represent more elements than this.

maxInt32Len = int(high(int32))
## Maximum byte length encodable in a wire Int32 length field (2147483647).
## Parameter values, SASL data, binary COPY fields, and every message length
## are Int32. PostgreSQL further caps a single value at `MaxAllocSize`
## (~1 GiB - 1), so legitimate payloads stay well below this.

DefaultMaxBackendMessageLen* = 1024 * 1024 * 1024
## Default upper bound on a single backend message (header + body), 1 GiB.
## Prevents a malicious or broken server from causing unbounded recv-buffer
## growth (and OOM) by advertising an int32-max length. PostgreSQL itself
## bounds individual values by `MaxAllocSize` (~1 GiB - 1), so legitimate
## traffic stays well below this cap.

func makeBinarySafeLookup(): array[BinarySafeMaxOid + 1, bool] {.compileTime.} =
for oid in BinarySafeOids:
result[oid] = true

const binarySafeLookup = makeBinarySafeLookup()

func initRow*(data: RowData, rowIdx: int32): Row =
## Create a Row view into the given RowData at the specified row index.
Row(data: data, rowIdx: rowIdx)

func data*(row: Row): RowData = ## The underlying RowData buffer.
row.data

func rowIdx*(row: Row): int32 = ## The row index within the RowData buffer.
row.rowIdx

func isBinarySafeOid*(oid: int32): bool =
## Check if a type OID can be safely requested in binary format.
oid >= 0 and oid <= BinarySafeMaxOid and binarySafeLookup[oid]
Expand All @@ -313,14 +331,33 @@ proc encodeInt32*(val: int32): array[4, byte] =
result[2] = byte((val shr 8) and 0xFF)
result[3] = byte(val and 0xFF)

proc addInt16*(buf: var seq[byte], val: int16) {.inline.} =
proc addInt16*(buf: var seq[byte], val: int16) =
## Append a 16-bit integer in big-endian format to the buffer.
let oldLen = buf.len
buf.setLen(oldLen + 2)
buf[oldLen] = byte((val shr 8) and 0xFF)
buf[oldLen + 1] = byte(val and 0xFF)

proc addInt32*(buf: var seq[byte], val: int32) {.inline.} =
proc addCount16*(buf: var seq[byte], n: int, what: string) =
## Append an Int16 count field, rejecting counts that overflow the wire's
## signed 16-bit range. Without this guard the `int16(n)` conversion raises an
## uncatchable `RangeDefect` on default builds, or silently wraps to a bogus
## (often negative) count that desyncs the protocol stream under `-d:danger`.
## The check is always active so callers get a catchable `ValueError` instead.
##
## `ValueError` is used here (and in `addLen32`) because the count is supplied
## directly by the caller; message-level length overflows detected after the
## message has been assembled raise `PgProtocolError` instead.
if n < 0:
raise newException(ValueError, what & " count " & $n & " is negative")
if n > maxInt16Count:
raise newException(
ValueError,
what & " count " & $n & " exceeds protocol maximum of " & $maxInt16Count,
)
buf.addInt16(int16(n))

proc addInt32*(buf: var seq[byte], val: int32) =
## Append a 32-bit integer in big-endian format to the buffer.
let oldLen = buf.len
buf.setLen(oldLen + 4)
Expand All @@ -329,7 +366,26 @@ proc addInt32*(buf: var seq[byte], val: int32) {.inline.} =
buf[oldLen + 2] = byte((val shr 8) and 0xFF)
buf[oldLen + 3] = byte(val and 0xFF)

proc addInt64*(buf: var seq[byte], val: int64) {.inline.} =
proc addLen32*(buf: var seq[byte], n: int, what: string) =
## Append an Int32 length field, rejecting payloads that overflow the wire's
## signed 32-bit length. Like `addCount16`, this turns the otherwise
## uncatchable `RangeDefect` (or a wrapped, often-negative length that desyncs
## the stream under `-d:danger`) into a catchable `ValueError` raised before
## the oversized payload is appended.
##
## `ValueError` is used here (and in `addCount16`) because the length is
## supplied directly by the caller; message-level length overflows detected
## after the message has been assembled raise `PgProtocolError` instead.
if n < 0:
raise newException(ValueError, what & " length " & $n & " is negative")
if n > maxInt32Len:
raise newException(
ValueError,
what & " length " & $n & " exceeds protocol maximum of " & $maxInt32Len,
)
buf.addInt32(int32(n))

proc addInt64*(buf: var seq[byte], val: int64) =
## Append a 64-bit integer in big-endian format to the buffer.
let oldLen = buf.len
buf.setLen(oldLen + 8)
Expand All @@ -344,25 +400,45 @@ proc addInt64*(buf: var seq[byte], val: int64) {.inline.} =

proc patchLen*(buf: var seq[byte], offset: int = 1) =
## Patch the length placeholder at `offset` with buf.len minus the tag byte.
## Raises `PgProtocolError` when the assembled message exceeds the Int32
## maximum; this is a protocol-level failure for an internally built message,
## distinct from the `ValueError` raised by `addLen32`/`addCount16` for
## caller-supplied field values.
if offset < 0 or offset + 3 >= buf.len:
raise newException(
PgProtocolError,
"patchLen: offset " & $offset & " out of range for buf.len " & $buf.len,
)
if buf.high > maxInt32Len:
raise newException(
PgProtocolError,
"patchLen: message length " & $buf.high & " exceeds Int32 maximum of " &
$maxInt32Len,
)
let length = int32(buf.high)
buf[offset] = byte((length shr 24) and 0xFF)
buf[offset + 1] = byte((length shr 16) and 0xFF)
buf[offset + 2] = byte((length shr 8) and 0xFF)
buf[offset + 3] = byte(length and 0xFF)

proc patchMsgLen*(buf: var seq[byte], msgStart: int) {.inline.} =
proc patchMsgLen*(buf: var seq[byte], msgStart: int) =
## Patch the length field of a message starting at `msgStart`.
## Length = total message size minus the type byte.
## Raises `PgProtocolError` when the assembled message exceeds the Int32
## maximum; this is a protocol-level failure for an internally built message,
## distinct from the `ValueError` raised by `addLen32`/`addCount16` for
## caller-supplied field values.
if msgStart < 0 or msgStart + 4 >= buf.len:
raise newException(
PgProtocolError,
"patchMsgLen: msgStart " & $msgStart & " out of range for buf.len " & $buf.len,
)
if buf.len - msgStart - 1 > maxInt32Len:
raise newException(
PgProtocolError,
"patchMsgLen: message length " & $(buf.len - msgStart - 1) &
" exceeds Int32 maximum of " & $maxInt32Len,
)
let length = int32(buf.len - msgStart - 1)
buf[msgStart + 1] = byte((length shr 24) and 0xFF)
buf[msgStart + 2] = byte((length shr 16) and 0xFF)
Expand Down Expand Up @@ -424,6 +500,9 @@ proc encodeStartup*(
user: string, database: string, extraParams: openArray[(string, string)] = []
): seq[byte] =
## Encode a StartupMessage (protocol v3.0) with user, database, and extra parameters.
## Raises `ValueError` for invalid caller-supplied values (e.g. embedded NUL
## bytes) and `PgProtocolError` if the assembled message exceeds the Int32
## maximum length.
result.addInt32(0) # length placeholder
result.addInt32(196608) # protocol version 3.0
result.addCString("user")
Expand All @@ -435,6 +514,12 @@ proc encodeStartup*(
result.addCString(k)
result.addCString(v)
result.add(0'u8) # terminator
if result.len > maxInt32Len:
raise newException(
PgProtocolError,
"encodeStartup: message length " & $result.len & " exceeds Int32 maximum of " &
$maxInt32Len,
)
let length = int32(result.len)
let encoded = encodeInt32(length)
result[0] = encoded[0]
Expand Down Expand Up @@ -467,7 +552,7 @@ proc encodeSASLInitialResponse*(mechanism: string, data: seq[byte]): seq[byte] =
result.add(byte('p'))
result.addInt32(0) # length placeholder
result.addCString(mechanism)
result.addInt32(int32(data.len))
result.addLen32(data.len, "SASLInitialResponse data")
result.add(data)
result.patchLen()

Expand All @@ -485,7 +570,7 @@ proc encodeQuery*(sql: string): seq[byte] =
result.addCString(sql)
result.patchLen()

proc addFixedMsg(buf: var seq[byte], msg: array[5, byte]) {.inline.} =
proc addFixedMsg(buf: var seq[byte], msg: array[5, byte]) =
let oldLen = buf.len
buf.setLen(oldLen + 5)
buf.writeBytesAt(oldLen, msg)
Expand All @@ -502,7 +587,7 @@ proc addParse*(
buf.addInt32(0) # length placeholder
buf.addCString(stmtName)
buf.addCString(sql)
buf.addInt16(int16(paramTypeOids.len))
buf.addCount16(paramTypeOids.len, "Parse parameter-type")
for oid in paramTypeOids:
buf.addInt32(oid)
buf.patchMsgLen(msgStart)
Expand All @@ -522,20 +607,20 @@ proc addBind*(
buf.addCString(portalName)
buf.addCString(stmtName)
# Parameter format codes
buf.addInt16(int16(paramFormats.len))
buf.addCount16(paramFormats.len, "Bind parameter-format")
for f in paramFormats:
buf.addInt16(f)
# Parameter values
buf.addInt16(int16(paramValues.len))
buf.addCount16(paramValues.len, "Bind parameter")
for v in paramValues:
if v.isNone:
buf.addInt32(-1) # NULL
else:
let data = v.get
buf.addInt32(int32(data.len))
buf.addLen32(data.len, "Bind parameter value")
buf.appendBytes(data)
# Result format codes
buf.addInt16(int16(resultFormats.len))
buf.addCount16(resultFormats.len, "Bind result-format")
for f in resultFormats:
buf.addInt16(f)
buf.patchMsgLen(msgStart)
Expand Down Expand Up @@ -564,10 +649,10 @@ proc addBindRaw*(
buf.addInt32(0) # length placeholder
buf.addCString(portalName)
buf.addCString(stmtName)
buf.addInt16(int16(paramFormats.len))
buf.addCount16(paramFormats.len, "Bind parameter-format")
for f in paramFormats:
buf.addInt16(f)
buf.addInt16(int16(paramRanges.len))
buf.addCount16(paramRanges.len, "Bind parameter")
for r in paramRanges:
if r.len < -1:
raise newException(ValueError, "addBindRaw: invalid range len " & $r.len)
Expand All @@ -587,7 +672,7 @@ proc addBindRaw*(
let oldLen = buf.len
buf.setLen(oldLen + r.len)
buf.writeBytesAt(oldLen, paramData.toOpenArray(r.off, r.off + r.len - 1))
buf.addInt16(int16(resultFormats.len))
buf.addCount16(resultFormats.len, "Bind result-format")
for f in resultFormats:
buf.addInt16(f)
buf.patchMsgLen(msgStart)
Expand Down Expand Up @@ -619,15 +704,15 @@ proc addClose*(buf: var seq[byte], kind: DescribeKind, name: string) =
buf.addCString(name)
buf.patchMsgLen(msgStart)

proc addSync*(buf: var seq[byte]) {.inline.} =
proc addSync*(buf: var seq[byte]) =
## Append a Sync message to the buffer.
buf.addFixedMsg(syncMsg)

proc addFlush*(buf: var seq[byte]) {.inline.} =
proc addFlush*(buf: var seq[byte]) =
## Append a Flush message to the buffer.
buf.addFixedMsg(flushMsg)

proc addCopyDone*(buf: var seq[byte]) {.inline.} =
proc addCopyDone*(buf: var seq[byte]) =
## Append a CopyDone message to the buffer.
buf.addFixedMsg(copyDoneMsg)

Expand Down Expand Up @@ -684,6 +769,17 @@ proc encodeCancelRequest*(pid: int32, secretKey: int32): seq[byte] =
proc encodeCopyData*(buf: var seq[byte], data: openArray[byte]) =
## Encode a CopyData message, appending to `buf`.
## Single setLen for header + payload to minimize bounds checks.
## The Int32 length field covers itself plus the payload, so reject payloads
## that would overflow it (a wrapped length desyncs the stream) before any
## allocation, matching the `addLen32` guard used by the other encoders.
## Like `addLen32`, the payload length comes from the caller, so an overflow
## raises `ValueError` rather than `PgProtocolError`.
if data.len > maxInt32Len - 4:
raise newException(
ValueError,
"CopyData payload length " & $data.len & " exceeds protocol maximum of " &
$(maxInt32Len - 4),
)
let msgLen = int32(4 + data.len)
let oldLen = buf.len
buf.setLen(oldLen + 5 + data.len)
Expand Down Expand Up @@ -1059,13 +1155,6 @@ proc parseDataRowInto*(body: openArray[byte], rd: RowData) =

# Streaming backend message parser

const DefaultMaxBackendMessageLen* = 1024 * 1024 * 1024
## Default upper bound on a single backend message (header + body), 1 GiB.
## Prevents a malicious or broken server from causing unbounded recv-buffer
## growth (and OOM) by advertising an int32-max length. PostgreSQL itself
## bounds individual values by `MaxAllocSize` (~1 GiB - 1), so legitimate
## traffic stays well below this cap.

proc parseBackendMessage*(
buf: openArray[byte],
consumed: var int,
Expand Down Expand Up @@ -1238,12 +1327,12 @@ proc addCopyFieldBool*(buf: var seq[byte], val: bool) =

proc addCopyFieldText*(buf: var seq[byte], val: openArray[byte]) =
## Append a raw byte field in binary COPY format.
buf.addInt32(int32(val.len))
buf.addLen32(val.len, "COPY field")
buf.appendBytes(val)

proc addCopyFieldString*(buf: var seq[byte], val: string) =
## Append a string field in binary COPY format.
buf.addInt32(int32(val.len))
buf.addLen32(val.len, "COPY field")
if val.len > 0:
buf.appendBytes(val.toOpenArrayByte(0, val.high))

Expand Down
14 changes: 7 additions & 7 deletions async_postgres/pg_types/encoding.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,7 @@ proc addParse*(
buf.addInt32(0) # length placeholder
buf.addCString(stmtName)
buf.addCString(sql)
buf.addInt16(int16(params.len))
buf.addCount16(params.len, "Parse parameter-type")
for p in params:
buf.addInt32(p.oid)
buf.patchMsgLen(msgStart)
Expand All @@ -1458,20 +1458,20 @@ proc addBind*(
buf.addCString(portalName)
buf.addCString(stmtName)
# Parameter format codes
buf.addInt16(int16(params.len))
buf.addCount16(params.len, "Bind parameter-format")
for p in params:
buf.addInt16(p.format)
# Parameter values
buf.addInt16(int16(params.len))
buf.addCount16(params.len, "Bind parameter")
for p in params:
if p.value.isNone:
buf.addInt32(-1) # NULL
else:
let data = p.value.get
buf.addInt32(int32(data.len))
buf.addLen32(data.len, "Bind parameter value")
buf.appendBytes(data)
# Result format codes
buf.addInt16(int16(resultFormats.len))
buf.addCount16(resultFormats.len, "Bind result-format")
for f in resultFormats:
buf.addInt16(f)
buf.patchMsgLen(msgStart)
Expand Down Expand Up @@ -1542,12 +1542,12 @@ proc writeParamValue*(buf: var seq[byte], v: bool) =
buf.add(if v: 1'u8 else: 0'u8)

proc writeParamValue*(buf: var seq[byte], v: string) =
buf.addInt32(int32(v.len))
buf.addLen32(v.len, "parameter value")
if v.len > 0:
buf.appendBytes(v.toOpenArrayByte(0, v.high))

proc writeParamValue*(buf: var seq[byte], v: seq[byte]) =
buf.addInt32(int32(v.len))
buf.addLen32(v.len, "parameter value")
buf.appendBytes(v)

proc writeParamValue*(buf: var seq[byte], v: PgNumeric) =
Expand Down
Loading