Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def sequence_number(self) -> Optional[int]:

:rtype: int or None
"""
if self._raw_amqp_message.annotations is None:
return None
return self._raw_amqp_message.annotations.get(PROP_SEQ_NUMBER, None)

@property
Expand All @@ -327,8 +329,8 @@ def offset(self) -> Optional[str]:
:rtype: str or None
"""
try:
return self._raw_amqp_message.annotations[PROP_OFFSET].decode("UTF-8")
except (KeyError, AttributeError):
return self._raw_amqp_message.annotations[PROP_OFFSET].decode("UTF-8") # type: ignore[index]
except (KeyError, AttributeError, TypeError):
return None

@property
Expand All @@ -337,7 +339,8 @@ def enqueued_time(self) -> Optional[datetime.datetime]:

:rtype: datetime.datetime or None
"""
timestamp = self._raw_amqp_message.annotations.get(PROP_TIMESTAMP, None)
annotations = self._raw_amqp_message.annotations or {}
timestamp = annotations.get(PROP_TIMESTAMP, None)
if timestamp:
return utc_from_timestamp(float(timestamp) / 1000)
return None
Expand All @@ -348,6 +351,8 @@ def partition_key(self) -> Optional[bytes]:

:rtype: bytes or None
"""
if self._raw_amqp_message.annotations is None:
return None
return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None)

@property
Expand All @@ -356,6 +361,8 @@ def properties(self) -> Dict[Union[str, bytes], Any]:

:rtype: dict[str, any] or dict[bytes, any]
"""
if self._raw_amqp_message.application_properties is None:
self._raw_amqp_message.application_properties = {}
return self._raw_amqp_message.application_properties

@properties.setter
Expand Down Expand Up @@ -402,7 +409,8 @@ def system_properties(self) -> Dict[bytes, Any]:
value = getattr(self._raw_amqp_message.properties, prop_name, None)
if value:
self._sys_properties[key] = value
self._sys_properties.update(self._raw_amqp_message.annotations)
if self._raw_amqp_message.annotations:
self._sys_properties.update(self._raw_amqp_message.annotations) # type: ignore[arg-type]
return self._sys_properties

@property
Expand Down Expand Up @@ -483,10 +491,10 @@ def content_type(self) -> Optional[str]:
return self._raw_amqp_message.properties.content_type

@content_type.setter
def content_type(self, value: str) -> None:
if not self._raw_amqp_message.properties:
self._raw_amqp_message.properties = AmqpMessageProperties()
self._raw_amqp_message.properties.content_type = value
def content_type(self, value: Optional[str]) -> None:
properties = self._raw_amqp_message.properties or AmqpMessageProperties()
properties.content_type = value
self._raw_amqp_message.properties = properties

@property
def correlation_id(self) -> Optional[str]:
Expand All @@ -503,10 +511,10 @@ def correlation_id(self) -> Optional[str]:
return self._raw_amqp_message.properties.correlation_id

@correlation_id.setter
def correlation_id(self, value: str) -> None:
if not self._raw_amqp_message.properties:
self._raw_amqp_message.properties = AmqpMessageProperties()
self._raw_amqp_message.properties.correlation_id = value
def correlation_id(self, value: Optional[str]) -> None:
properties = self._raw_amqp_message.properties or AmqpMessageProperties()
properties.correlation_id = value
self._raw_amqp_message.properties = properties

@property
def message_id(self) -> Optional[str]:
Expand All @@ -525,10 +533,10 @@ def message_id(self) -> Optional[str]:
return self._raw_amqp_message.properties.message_id

@message_id.setter
def message_id(self, value: str) -> None:
if not self._raw_amqp_message.properties:
self._raw_amqp_message.properties = AmqpMessageProperties()
self._raw_amqp_message.properties.message_id = value
def message_id(self, value: Optional[str]) -> None:
properties = self._raw_amqp_message.properties or AmqpMessageProperties()
properties.message_id = value
self._raw_amqp_message.properties = properties


class EventDataBatch:
Expand Down
68 changes: 63 additions & 5 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@

_LOGGER = logging.getLogger(__name__)
_HEADER_PREFIX = memoryview(b"AMQP")

# Maximum number of elements permitted in any AMQP compound type (list, array, map).
#
# The AMQP 1.0 wire format encodes compound element counts as either a 1-byte
# field (the *_small variants, naturally bounded at 255) or a 4-byte field
# (the *_large variants, wire-level maximum 0xFFFFFFFF). The large variants
# allocate a Python list/dict sized directly from this count, so without an
# upper bound a small frame can demand a multi-gigabyte allocation. This cap
# is applied at every large-variant decode site to keep allocation sizes
# proportional to the bytes actually delivered.
Comment thread
j7nw4r marked this conversation as resolved.
#
# The value mirrors MAX_AMQPVALUE_ITEM_COUNT (65536) in the C reference
# implementation, azure-uamqp-c, and the equivalent MAX_COMPOUND_COUNT bound
# applied across list/array/map decode sites in the Java AMQP codec. Real
# AMQP traffic does not approach this many elements in a single compound
# value, so the bound functions as a hard ceiling rather than a practical
# constraint on legitimate workloads.
_MAX_COMPOUND_COUNT = 65536
_COMPOSITES = {
35: "received",
36: "accepted",
Expand Down Expand Up @@ -222,6 +240,12 @@ def _decode_list_small(buffer: memoryview) -> Tuple[memoryview, List[Any]]:

def _decode_list_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]:
count = c_unsigned_long.unpack(buffer[4:8])[0]
# Validate the wire-supplied count before allocating `[None] * count`,
# which would otherwise scale linearly with an untrusted 32-bit value.
if count > _MAX_COMPOUND_COUNT:
raise ValueError(
f"AMQP list element count {count} exceeds maximum {_MAX_COMPOUND_COUNT}"
)
buffer = buffer[8:]
values = [None] * count
for i in range(count):
Expand All @@ -230,7 +254,12 @@ def _decode_list_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]:


def _decode_map_small(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]:
count = int(buffer[1] / 2)
raw_count = buffer[1]
if raw_count % 2 != 0:
raise ValueError(
f"AMQP map element count {raw_count} must be even (key/value pairs)"
)
count = raw_count // 2
buffer = buffer[2:]
values = {}
for _ in range(count):
Expand All @@ -241,7 +270,22 @@ def _decode_map_small(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]:


def _decode_map_large(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]:
count = int(c_unsigned_long.unpack(buffer[4:8])[0] / 2)
# Validate the raw on-wire count *before* halving it (the AMQP encoding
# stores total entries; pairs = entries / 2). Checking pre-halve keeps
# the comparison aligned with the bound used by _decode_list_large /
# _decode_array_large. Odd counts are rejected explicitly: silently
# flooring to (raw_count - 1) // 2 would leave a trailing key with no
# value, leaking bytes into the next decoder.
raw_count = c_unsigned_long.unpack(buffer[4:8])[0]
if raw_count > _MAX_COMPOUND_COUNT:
raise ValueError(
f"AMQP map element count {raw_count} exceeds maximum {_MAX_COMPOUND_COUNT}"
)
if raw_count % 2 != 0:
raise ValueError(
f"AMQP map element count {raw_count} must be even (key/value pairs)"
)
count = raw_count // 2
buffer = buffer[8:]
values = {}
for _ in range(count):
Expand All @@ -265,6 +309,13 @@ def _decode_array_small(buffer: memoryview) -> Tuple[memoryview, List[Any]]:

def _decode_array_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]:
count = c_unsigned_long.unpack(buffer[4:8])[0]
# Validate the wire-supplied count before allocating `[None] * count`.
# An Array32 frame's COUNT is read directly from the network and would
# otherwise drive a Python list allocation of arbitrary size.
if count > _MAX_COMPOUND_COUNT:
raise ValueError(
f"AMQP array element count {count} exceeds maximum {_MAX_COMPOUND_COUNT}"
)
if count:
subconstructor = buffer[8]
buffer = buffer[9:]
Expand Down Expand Up @@ -333,11 +384,18 @@ def decode_frame(data: memoryview) -> Tuple[int, List[Any]]:
frame_type = data[2]
compound_list_type = data[3]
if compound_list_type == 0xD0:
# list32 0xd0: data[4:8] is size, data[8:12] is count
count = c_signed_int.unpack(data[8:12])[0]
# list32 0xd0: data[4:8] is size, data[8:12] is count. The AMQP 1.0
# wire format defines COUNT as an unsigned 32-bit field; decoding it
# as a signed int and skipping the cap would let a malicious peer
# request a multi-gigabyte field-list allocation below.
count = c_unsigned_long.unpack(data[8:12])[0]
if count > _MAX_COMPOUND_COUNT:
raise ValueError(
f"AMQP frame field count {count} exceeds maximum {_MAX_COMPOUND_COUNT}"
)
buffer = data[12:]
else:
# list8 0xc0: data[4] is size, data[5] is count
# list8 0xc0: data[4] is size, data[5] is count (1 byte, bounded at 255).
count = data[5]
buffer = data[6:]
fields: List[Optional[memoryview]] = [None] * count
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import annotations
import datetime
from typing import List, Tuple, Union, TYPE_CHECKING, Optional, Any, Dict, Callable
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -209,7 +210,7 @@ def create_send_client(
idle_timeout: Optional[float],
network_trace: bool,
retry_policy: Any,
keep_alive_interval: int,
keep_alive_interval: Optional[int],
client_name: str,
link_properties: Optional[Dict[str, Any]],
properties: Optional[Dict[str, Any]],
Expand Down Expand Up @@ -270,7 +271,11 @@ def add_batch(

@staticmethod
@abstractmethod
def create_source(source: Union["uamqp_Source", "pyamqp_Source"], offset: int, selector: bytes):
def create_source(
source: Union["uamqp_Source", "pyamqp_Source"],
offset: Optional[Union[int, str, datetime.datetime]],
selector: bytes,
):
"""
Creates and returns the Source.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def create_send_client(
idle_timeout: Optional[float],
network_trace: bool,
retry_policy: Any,
keep_alive_interval: int,
keep_alive_interval: Optional[int],
client_name: str,
link_properties: Optional[Dict[str, Any]] = None,
properties: Optional[Dict[str, Any]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def create_send_client(
idle_timeout: Optional[float],
network_trace: bool,
retry_policy: Any,
keep_alive_interval: int,
keep_alive_interval: Optional[int],
client_name: str,
link_properties: Optional[Dict[str, Any]] = None,
properties: Optional[Dict[str, Any]] = None,
Expand Down
2 changes: 1 addition & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def set_event_partition_key(
raw_message.header.durable = True


def event_position_selector(value: Union[str, int, datetime.datetime], inclusive: bool = False) -> bytes:
def event_position_selector(value: Optional[Union[str, int, datetime.datetime]], inclusive: bool = False) -> bytes:
"""Creates a selector expression of the offset.
:param int or str or datetime.datetime value: The offset value to use for the offset.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import annotations
import datetime
from abc import ABC, abstractmethod
from typing import List, Tuple, Union, TYPE_CHECKING, Optional, Any, Dict, Callable
from typing_extensions import Literal
Expand Down Expand Up @@ -201,7 +202,7 @@ def create_send_client(
idle_timeout: Optional[float],
network_trace: bool,
retry_policy: Any,
keep_alive_interval: int,
keep_alive_interval: Optional[int],
client_name: str,
link_properties: Optional[Dict[str, Any]],
properties: Optional[Dict[str, Any]],
Expand Down Expand Up @@ -249,7 +250,11 @@ def set_message_partition_key(

@staticmethod
@abstractmethod
def create_source(source: str, offset: int, selector: bytes) -> Union["uamqp_Source", "pyamqp_Source"]:
def create_source(
source: str,
offset: Optional[Union[int, str, datetime.datetime]],
selector: bytes,
) -> Union["uamqp_Source", "pyamqp_Source"]:
"""
Creates and returns the Source.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def create_send_client(
idle_timeout: Optional[float],
network_trace: bool,
retry_policy: Any,
keep_alive_interval: int,
keep_alive_interval: Optional[int],
client_name: str,
link_properties: Optional[Dict[str, Any]],
properties: Optional[Dict[str, Any]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def create_send_client(
idle_timeout: Optional[float],
network_trace: bool,
retry_policy: Any,
keep_alive_interval: int,
keep_alive_interval: Optional[int],
client_name: str,
link_properties: Optional[Dict[str, Any]] = None,
properties: Optional[Dict[str, Any]] = None,
Expand Down
Loading