diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 0e4ebfee4a68..290a6268a556 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -318,6 +318,8 @@ def sequence_number(self) -> Optional[int]: :rtype: int or None """ + if self._raw_amqp_message.annotations is None: + return None return self._raw_amqp_message.annotations.get(PROP_SEQ_NUMBER, None) @property @@ -327,8 +329,8 @@ def offset(self) -> Optional[str]: :rtype: str or None """ try: - return self._raw_amqp_message.annotations[PROP_OFFSET].decode("UTF-8") - except (KeyError, AttributeError): + return self._raw_amqp_message.annotations[PROP_OFFSET].decode("UTF-8") # type: ignore[index] + except (KeyError, AttributeError, TypeError): return None @property @@ -337,7 +339,8 @@ def enqueued_time(self) -> Optional[datetime.datetime]: :rtype: datetime.datetime or None """ - timestamp = self._raw_amqp_message.annotations.get(PROP_TIMESTAMP, None) + annotations = self._raw_amqp_message.annotations or {} + timestamp = annotations.get(PROP_TIMESTAMP, None) if timestamp: return utc_from_timestamp(float(timestamp) / 1000) return None @@ -348,6 +351,8 @@ def partition_key(self) -> Optional[bytes]: :rtype: bytes or None """ + if self._raw_amqp_message.annotations is None: + return None return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) @property @@ -356,6 +361,8 @@ def properties(self) -> Dict[Union[str, bytes], Any]: :rtype: dict[str, any] or dict[bytes, any] """ + if self._raw_amqp_message.application_properties is None: + self._raw_amqp_message.application_properties = {} return self._raw_amqp_message.application_properties @properties.setter @@ -402,7 +409,8 @@ def system_properties(self) -> Dict[bytes, Any]: value = getattr(self._raw_amqp_message.properties, prop_name, None) if value: self._sys_properties[key] = value - 
self._sys_properties.update(self._raw_amqp_message.annotations) + if self._raw_amqp_message.annotations: + self._sys_properties.update(self._raw_amqp_message.annotations) # type: ignore[arg-type] return self._sys_properties @property @@ -483,10 +491,10 @@ def content_type(self) -> Optional[str]: return self._raw_amqp_message.properties.content_type @content_type.setter - def content_type(self, value: str) -> None: - if not self._raw_amqp_message.properties: - self._raw_amqp_message.properties = AmqpMessageProperties() - self._raw_amqp_message.properties.content_type = value + def content_type(self, value: Optional[str]) -> None: + properties = self._raw_amqp_message.properties or AmqpMessageProperties() + properties.content_type = value + self._raw_amqp_message.properties = properties @property def correlation_id(self) -> Optional[str]: @@ -503,10 +511,10 @@ def correlation_id(self) -> Optional[str]: return self._raw_amqp_message.properties.correlation_id @correlation_id.setter - def correlation_id(self, value: str) -> None: - if not self._raw_amqp_message.properties: - self._raw_amqp_message.properties = AmqpMessageProperties() - self._raw_amqp_message.properties.correlation_id = value + def correlation_id(self, value: Optional[str]) -> None: + properties = self._raw_amqp_message.properties or AmqpMessageProperties() + properties.correlation_id = value + self._raw_amqp_message.properties = properties @property def message_id(self) -> Optional[str]: @@ -525,10 +533,10 @@ def message_id(self) -> Optional[str]: return self._raw_amqp_message.properties.message_id @message_id.setter - def message_id(self, value: str) -> None: - if not self._raw_amqp_message.properties: - self._raw_amqp_message.properties = AmqpMessageProperties() - self._raw_amqp_message.properties.message_id = value + def message_id(self, value: Optional[str]) -> None: + properties = self._raw_amqp_message.properties or AmqpMessageProperties() + properties.message_id = value + 
self._raw_amqp_message.properties = properties class EventDataBatch: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_decode.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_decode.py index afe58f311a0c..c3c72afbb8c2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_decode.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_decode.py @@ -29,6 +29,24 @@ _LOGGER = logging.getLogger(__name__) _HEADER_PREFIX = memoryview(b"AMQP") + +# Maximum number of elements permitted in any AMQP compound type (list, array, map). +# +# The AMQP 1.0 wire format encodes compound element counts as either a 1-byte +# field (the *_small variants, naturally bounded at 255) or a 4-byte field +# (the *_large variants, wire-level maximum 0xFFFFFFFF). The large variants +# allocate a Python list/dict sized directly from this count, so without an +# upper bound a small frame can demand a multi-gigabyte allocation. This cap +# is applied at every large-variant decode site to keep allocation sizes +# bounded regardless of the count a peer declares. +# +# The value mirrors MAX_AMQPVALUE_ITEM_COUNT (65536) in the C reference +# implementation, azure-uamqp-c, and the equivalent MAX_COMPOUND_COUNT bound +# applied across list/array/map decode sites in the Java AMQP codec. Real +# AMQP traffic does not approach this many elements in a single compound +# value, so the bound functions as a hard ceiling rather than a practical +# constraint on legitimate workloads. +_MAX_COMPOUND_COUNT = 65536 _COMPOSITES = { 35: "received", 36: "accepted", @@ -222,6 +240,12 @@ def _decode_list_small(buffer: memoryview) -> Tuple[memoryview, List[Any]]: def _decode_list_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]: count = c_unsigned_long.unpack(buffer[4:8])[0] + # Validate the wire-supplied count before allocating `[None] * count`, + # which would otherwise scale linearly with an untrusted 32-bit value.
+ if count > _MAX_COMPOUND_COUNT: + raise ValueError( + f"AMQP list element count {count} exceeds maximum {_MAX_COMPOUND_COUNT}" + ) buffer = buffer[8:] values = [None] * count for i in range(count): @@ -230,7 +254,12 @@ def _decode_list_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]: def _decode_map_small(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]: - count = int(buffer[1] / 2) + raw_count = buffer[1] + if raw_count % 2 != 0: + raise ValueError( + f"AMQP map element count {raw_count} must be even (key/value pairs)" + ) + count = raw_count // 2 buffer = buffer[2:] values = {} for _ in range(count): @@ -241,7 +270,22 @@ def _decode_map_small(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]: def _decode_map_large(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]: - count = int(c_unsigned_long.unpack(buffer[4:8])[0] / 2) + # Validate the raw on-wire count *before* halving it (the AMQP encoding + # stores total entries; pairs = entries / 2). Checking pre-halve keeps + # the comparison aligned with the bound used by _decode_list_large / + # _decode_array_large. Odd counts are rejected explicitly: silently + # flooring to (raw_count - 1) // 2 would leave a trailing key with no + # value, leaking bytes into the next decoder. + raw_count = c_unsigned_long.unpack(buffer[4:8])[0] + if raw_count > _MAX_COMPOUND_COUNT: + raise ValueError( + f"AMQP map element count {raw_count} exceeds maximum {_MAX_COMPOUND_COUNT}" + ) + if raw_count % 2 != 0: + raise ValueError( + f"AMQP map element count {raw_count} must be even (key/value pairs)" + ) + count = raw_count // 2 buffer = buffer[8:] values = {} for _ in range(count): @@ -265,6 +309,13 @@ def _decode_array_small(buffer: memoryview) -> Tuple[memoryview, List[Any]]: def _decode_array_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]: count = c_unsigned_long.unpack(buffer[4:8])[0] + # Validate the wire-supplied count before allocating `[None] * count`. 
+ # An Array32 frame's COUNT is read directly from the network and would + # otherwise drive a Python list allocation of arbitrary size. + if count > _MAX_COMPOUND_COUNT: + raise ValueError( + f"AMQP array element count {count} exceeds maximum {_MAX_COMPOUND_COUNT}" + ) if count: subconstructor = buffer[8] buffer = buffer[9:] @@ -333,11 +384,18 @@ def decode_frame(data: memoryview) -> Tuple[int, List[Any]]: frame_type = data[2] compound_list_type = data[3] if compound_list_type == 0xD0: - # list32 0xd0: data[4:8] is size, data[8:12] is count - count = c_signed_int.unpack(data[8:12])[0] + # list32 0xd0: data[4:8] is size, data[8:12] is count. The AMQP 1.0 + # wire format defines COUNT as an unsigned 32-bit field; decoding it + # as a signed int and skipping the cap would let a malicious peer + # request a multi-gigabyte field-list allocation below. + count = c_unsigned_long.unpack(data[8:12])[0] + if count > _MAX_COMPOUND_COUNT: + raise ValueError( + f"AMQP frame field count {count} exceeds maximum {_MAX_COMPOUND_COUNT}" + ) buffer = data[12:] else: - # list8 0xc0: data[4] is size, data[5] is count + # list8 0xc0: data[4] is size, data[5] is count (1 byte, bounded at 255). count = data[5] buffer = data[6:] fields: List[Optional[memoryview]] = [None] * count diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index a297c111eef6..cfbc284b38cd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. 
# -------------------------------------------------------------------------------------------- from __future__ import annotations +import datetime from typing import List, Tuple, Union, TYPE_CHECKING, Optional, Any, Dict, Callable from abc import ABC, abstractmethod @@ -209,7 +210,7 @@ def create_send_client( idle_timeout: Optional[float], network_trace: bool, retry_policy: Any, - keep_alive_interval: int, + keep_alive_interval: Optional[int], client_name: str, link_properties: Optional[Dict[str, Any]], properties: Optional[Dict[str, Any]], @@ -270,7 +271,11 @@ def add_batch( @staticmethod @abstractmethod - def create_source(source: Union["uamqp_Source", "pyamqp_Source"], offset: int, selector: bytes): + def create_source( + source: Union["uamqp_Source", "pyamqp_Source"], + offset: Optional[Union[int, str, datetime.datetime]], + selector: bytes, + ): """ Creates and returns the Source. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py index d9b8c5211253..450e959122fa 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py @@ -304,7 +304,7 @@ def create_send_client( idle_timeout: Optional[float], network_trace: bool, retry_policy: Any, - keep_alive_interval: int, + keep_alive_interval: Optional[int], client_name: str, link_properties: Optional[Dict[str, Any]] = None, properties: Optional[Dict[str, Any]] = None, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index e04aa072686e..af915a02bc66 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -365,7 +365,7 @@ def create_send_client( idle_timeout: Optional[float], 
network_trace: bool, retry_policy: Any, - keep_alive_interval: int, + keep_alive_interval: Optional[int], client_name: str, link_properties: Optional[Dict[str, Any]] = None, properties: Optional[Dict[str, Any]] = None, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index e347ae2e3fc5..6711545dd264 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -120,7 +120,7 @@ def set_event_partition_key( raw_message.header.durable = True -def event_position_selector(value: Union[str, int, datetime.datetime], inclusive: bool = False) -> bytes: +def event_position_selector(value: Optional[Union[str, int, datetime.datetime]], inclusive: bool = False) -> bytes: """Creates a selector expression of the offset. :param int or str or datetime.datetime value: The offset value to use for the offset. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py index aaad055c7bbd..895c3abcd88e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. 
# -------------------------------------------------------------------------------------------- from __future__ import annotations +import datetime from abc import ABC, abstractmethod from typing import List, Tuple, Union, TYPE_CHECKING, Optional, Any, Dict, Callable from typing_extensions import Literal @@ -201,7 +202,7 @@ def create_send_client( idle_timeout: Optional[float], network_trace: bool, retry_policy: Any, - keep_alive_interval: int, + keep_alive_interval: Optional[int], client_name: str, link_properties: Optional[Dict[str, Any]], properties: Optional[Dict[str, Any]], @@ -249,7 +250,11 @@ def set_message_partition_key( @staticmethod @abstractmethod - def create_source(source: str, offset: int, selector: bytes) -> Union["uamqp_Source", "pyamqp_Source"]: + def create_source( + source: str, + offset: Optional[Union[int, str, datetime.datetime]], + selector: bytes, + ) -> Union["uamqp_Source", "pyamqp_Source"]: """ Creates and returns the Source. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py index 64c56e0b71d4..fd49d6851642 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py @@ -96,7 +96,7 @@ def create_send_client( idle_timeout: Optional[float], network_trace: bool, retry_policy: Any, - keep_alive_interval: int, + keep_alive_interval: Optional[int], client_name: str, link_properties: Optional[Dict[str, Any]], properties: Optional[Dict[str, Any]] = None, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py index c4fae70a0542..e6ef72d030a0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py +++ 
b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -119,7 +119,7 @@ def create_send_client( idle_timeout: Optional[float], network_trace: bool, retry_policy: Any, - keep_alive_interval: int, + keep_alive_interval: Optional[int], client_name: str, link_properties: Optional[Dict[str, Any]] = None, properties: Optional[Dict[str, Any]] = None, diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_decode_bounds.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_decode_bounds.py new file mode 100644 index 000000000000..30cb90582297 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_decode_bounds.py @@ -0,0 +1,100 @@ +import pytest + +from azure.eventhub._pyamqp._decode import ( + _decode_array_large, + _decode_list_large, + _decode_map_large, + _decode_map_small, + decode_frame, + _MAX_COMPOUND_COUNT, +) + + +def _header(count: int) -> bytes: + # 4 bytes size (unused by the decoder beyond slicing) + 4 bytes big-endian count + return b"\x00\x00\x00\x00" + count.to_bytes(4, "big") + + +HUGE_COUNT = 0x7FFFFFFF +JUST_OVER = _MAX_COMPOUND_COUNT + 1 + + +@pytest.mark.parametrize("count", [HUGE_COUNT, JUST_OVER]) +def test_decode_array_large_rejects_oversized_count(count): + buffer = memoryview(_header(count)) + with pytest.raises(ValueError, match="exceeds maximum"): + _decode_array_large(buffer) + + +@pytest.mark.parametrize("count", [HUGE_COUNT, JUST_OVER]) +def test_decode_list_large_rejects_oversized_count(count): + buffer = memoryview(_header(count)) + with pytest.raises(ValueError, match="exceeds maximum"): + _decode_list_large(buffer) + + +@pytest.mark.parametrize("count", [HUGE_COUNT, JUST_OVER]) +def test_decode_map_large_rejects_oversized_count(count): + buffer = memoryview(_header(count)) + with pytest.raises(ValueError, match="exceeds maximum"): + _decode_map_large(buffer) + + +def test_decode_array_large_accepts_boundary(): + # COUNT exactly at the cap with a 
null subconstructor (0x40). _decode_null + # consumes no bytes, so the result is a list of _MAX_COMPOUND_COUNT Nones. + buffer = memoryview(_header(_MAX_COMPOUND_COUNT) + b"\x40") + remaining, values = _decode_array_large(buffer) + assert len(values) == _MAX_COMPOUND_COUNT + assert all(v is None for v in values) + assert bytes(remaining) == b"" + + +def test_decode_list_large_accepts_boundary(): + # Each element carries its own constructor byte; _MAX_COMPOUND_COUNT nulls. + body = b"\x40" * _MAX_COMPOUND_COUNT + buffer = memoryview(_header(_MAX_COMPOUND_COUNT) + body) + remaining, values = _decode_list_large(buffer) + assert len(values) == _MAX_COMPOUND_COUNT + assert all(v is None for v in values) + assert bytes(remaining) == b"" + + +def test_decode_map_large_accepts_boundary(): + # COUNT counts entries (keys + values); pairs = count // 2. + body = b"\x40" * _MAX_COMPOUND_COUNT + buffer = memoryview(_header(_MAX_COMPOUND_COUNT) + body) + remaining, values = _decode_map_large(buffer) + # All keys collapse to None, so the dict has a single entry. + assert values == {None: None} + assert bytes(remaining) == b"" + + +def _frame_list32(count: int) -> bytes: + # decode_frame skips data[0:2] (described/ulong constructors), reads + # frame_type at data[2], the compound marker at data[3], size at + # data[4:8], and the COUNT under test at data[8:12]. + return b"\x00\x53\x00\xd0" + b"\x00\x00\x00\x00" + count.to_bytes(4, "big") + + +@pytest.mark.parametrize("count", [HUGE_COUNT, JUST_OVER]) +def test_decode_frame_rejects_oversized_list32_count(count): + buffer = memoryview(_frame_list32(count)) + with pytest.raises(ValueError, match="exceeds maximum"): + decode_frame(buffer) + + +def test_decode_map_large_rejects_odd_count(): + # An odd raw COUNT would silently floor to pairs = (count - 1) // 2 and + # leave a trailing key with no value, corrupting subsequent decoding. 
+ buffer = memoryview(_header(3)) + with pytest.raises(ValueError, match="must be even"): + _decode_map_large(buffer) + + +def test_decode_map_small_rejects_odd_count(): + # _decode_map_small reads the COUNT from buffer[1] (1 byte, 0-255). An + # odd value has the same trailing-key problem as the large variant. + buffer = memoryview(b"\x00\x03") + with pytest.raises(ValueError, match="must be even"): + _decode_map_small(buffer) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_decode.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_decode.py index afe58f311a0c..c3c72afbb8c2 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_decode.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_decode.py @@ -29,6 +29,24 @@ _LOGGER = logging.getLogger(__name__) _HEADER_PREFIX = memoryview(b"AMQP") + +# Maximum number of elements permitted in any AMQP compound type (list, array, map). +# +# The AMQP 1.0 wire format encodes compound element counts as either a 1-byte +# field (the *_small variants, naturally bounded at 255) or a 4-byte field +# (the *_large variants, wire-level maximum 0xFFFFFFFF). The large variants +# allocate a Python list/dict sized directly from this count, so without an +# upper bound a small frame can demand a multi-gigabyte allocation. This cap +# is applied at every large-variant decode site to keep allocation sizes +# bounded regardless of the count a peer declares. +# +# The value mirrors MAX_AMQPVALUE_ITEM_COUNT (65536) in the C reference +# implementation, azure-uamqp-c, and the equivalent MAX_COMPOUND_COUNT bound +# applied across list/array/map decode sites in the Java AMQP codec. Real +# AMQP traffic does not approach this many elements in a single compound +# value, so the bound functions as a hard ceiling rather than a practical +# constraint on legitimate workloads.
+_MAX_COMPOUND_COUNT = 65536 _COMPOSITES = { 35: "received", 36: "accepted", @@ -222,6 +240,12 @@ def _decode_list_small(buffer: memoryview) -> Tuple[memoryview, List[Any]]: def _decode_list_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]: count = c_unsigned_long.unpack(buffer[4:8])[0] + # Validate the wire-supplied count before allocating `[None] * count`, + # which would otherwise scale linearly with an untrusted 32-bit value. + if count > _MAX_COMPOUND_COUNT: + raise ValueError( + f"AMQP list element count {count} exceeds maximum {_MAX_COMPOUND_COUNT}" + ) buffer = buffer[8:] values = [None] * count for i in range(count): @@ -230,7 +254,12 @@ def _decode_list_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]: def _decode_map_small(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]: - count = int(buffer[1] / 2) + raw_count = buffer[1] + if raw_count % 2 != 0: + raise ValueError( + f"AMQP map element count {raw_count} must be even (key/value pairs)" + ) + count = raw_count // 2 buffer = buffer[2:] values = {} for _ in range(count): @@ -241,7 +270,22 @@ def _decode_map_small(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]: def _decode_map_large(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]: - count = int(c_unsigned_long.unpack(buffer[4:8])[0] / 2) + # Validate the raw on-wire count *before* halving it (the AMQP encoding + # stores total entries; pairs = entries / 2). Checking pre-halve keeps + # the comparison aligned with the bound used by _decode_list_large / + # _decode_array_large. Odd counts are rejected explicitly: silently + # flooring to (raw_count - 1) // 2 would leave a trailing key with no + # value, leaking bytes into the next decoder. 
+ raw_count = c_unsigned_long.unpack(buffer[4:8])[0] + if raw_count > _MAX_COMPOUND_COUNT: + raise ValueError( + f"AMQP map element count {raw_count} exceeds maximum {_MAX_COMPOUND_COUNT}" + ) + if raw_count % 2 != 0: + raise ValueError( + f"AMQP map element count {raw_count} must be even (key/value pairs)" + ) + count = raw_count // 2 buffer = buffer[8:] values = {} for _ in range(count): @@ -265,6 +309,13 @@ def _decode_array_small(buffer: memoryview) -> Tuple[memoryview, List[Any]]: def _decode_array_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]: count = c_unsigned_long.unpack(buffer[4:8])[0] + # Validate the wire-supplied count before allocating `[None] * count`. + # An Array32 frame's COUNT is read directly from the network and would + # otherwise drive a Python list allocation of arbitrary size. + if count > _MAX_COMPOUND_COUNT: + raise ValueError( + f"AMQP array element count {count} exceeds maximum {_MAX_COMPOUND_COUNT}" + ) if count: subconstructor = buffer[8] buffer = buffer[9:] @@ -333,11 +384,18 @@ def decode_frame(data: memoryview) -> Tuple[int, List[Any]]: frame_type = data[2] compound_list_type = data[3] if compound_list_type == 0xD0: - # list32 0xd0: data[4:8] is size, data[8:12] is count - count = c_signed_int.unpack(data[8:12])[0] + # list32 0xd0: data[4:8] is size, data[8:12] is count. The AMQP 1.0 + # wire format defines COUNT as an unsigned 32-bit field; decoding it + # as a signed int and skipping the cap would let a malicious peer + # request a multi-gigabyte field-list allocation below. + count = c_unsigned_long.unpack(data[8:12])[0] + if count > _MAX_COMPOUND_COUNT: + raise ValueError( + f"AMQP frame field count {count} exceeds maximum {_MAX_COMPOUND_COUNT}" + ) buffer = data[12:] else: - # list8 0xc0: data[4] is size, data[5] is count + # list8 0xc0: data[4] is size, data[5] is count (1 byte, bounded at 255). 
count = data[5] buffer = data[6:] fields: List[Optional[memoryview]] = [None] * count diff --git a/sdk/servicebus/azure-servicebus/tests/unittests/test_decode_bounds.py b/sdk/servicebus/azure-servicebus/tests/unittests/test_decode_bounds.py new file mode 100644 index 000000000000..d0892e871f3e --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/unittests/test_decode_bounds.py @@ -0,0 +1,100 @@ +import pytest + +from azure.servicebus._pyamqp._decode import ( + _decode_array_large, + _decode_list_large, + _decode_map_large, + _decode_map_small, + decode_frame, + _MAX_COMPOUND_COUNT, +) + + +def _header(count: int) -> bytes: + # 4 bytes size (unused by the decoder beyond slicing) + 4 bytes big-endian count + return b"\x00\x00\x00\x00" + count.to_bytes(4, "big") + + +HUGE_COUNT = 0x7FFFFFFF +JUST_OVER = _MAX_COMPOUND_COUNT + 1 + + +@pytest.mark.parametrize("count", [HUGE_COUNT, JUST_OVER]) +def test_decode_array_large_rejects_oversized_count(count): + buffer = memoryview(_header(count)) + with pytest.raises(ValueError, match="exceeds maximum"): + _decode_array_large(buffer) + + +@pytest.mark.parametrize("count", [HUGE_COUNT, JUST_OVER]) +def test_decode_list_large_rejects_oversized_count(count): + buffer = memoryview(_header(count)) + with pytest.raises(ValueError, match="exceeds maximum"): + _decode_list_large(buffer) + + +@pytest.mark.parametrize("count", [HUGE_COUNT, JUST_OVER]) +def test_decode_map_large_rejects_oversized_count(count): + buffer = memoryview(_header(count)) + with pytest.raises(ValueError, match="exceeds maximum"): + _decode_map_large(buffer) + + +def test_decode_array_large_accepts_boundary(): + # COUNT exactly at the cap with a null subconstructor (0x40). _decode_null + # consumes no bytes, so the result is a list of _MAX_COMPOUND_COUNT Nones. 
+ buffer = memoryview(_header(_MAX_COMPOUND_COUNT) + b"\x40") + remaining, values = _decode_array_large(buffer) + assert len(values) == _MAX_COMPOUND_COUNT + assert all(v is None for v in values) + assert bytes(remaining) == b"" + + +def test_decode_list_large_accepts_boundary(): + # Each element carries its own constructor byte; _MAX_COMPOUND_COUNT nulls. + body = b"\x40" * _MAX_COMPOUND_COUNT + buffer = memoryview(_header(_MAX_COMPOUND_COUNT) + body) + remaining, values = _decode_list_large(buffer) + assert len(values) == _MAX_COMPOUND_COUNT + assert all(v is None for v in values) + assert bytes(remaining) == b"" + + +def test_decode_map_large_accepts_boundary(): + # COUNT counts entries (keys + values); pairs = count // 2. + body = b"\x40" * _MAX_COMPOUND_COUNT + buffer = memoryview(_header(_MAX_COMPOUND_COUNT) + body) + remaining, values = _decode_map_large(buffer) + # All keys collapse to None, so the dict has a single entry. + assert values == {None: None} + assert bytes(remaining) == b"" + + +def _frame_list32(count: int) -> bytes: + # decode_frame skips data[0:2] (described/ulong constructors), reads + # frame_type at data[2], the compound marker at data[3], size at + # data[4:8], and the COUNT under test at data[8:12]. + return b"\x00\x53\x00\xd0" + b"\x00\x00\x00\x00" + count.to_bytes(4, "big") + + +@pytest.mark.parametrize("count", [HUGE_COUNT, JUST_OVER]) +def test_decode_frame_rejects_oversized_list32_count(count): + buffer = memoryview(_frame_list32(count)) + with pytest.raises(ValueError, match="exceeds maximum"): + decode_frame(buffer) + + +def test_decode_map_large_rejects_odd_count(): + # An odd raw COUNT would silently floor to pairs = (count - 1) // 2 and + # leave a trailing key with no value, corrupting subsequent decoding. 
+ buffer = memoryview(_header(3)) + with pytest.raises(ValueError, match="must be even"): + _decode_map_large(buffer) + + +def test_decode_map_small_rejects_odd_count(): + # _decode_map_small reads the COUNT from buffer[1] (1 byte, 0-255). An + # odd value has the same trailing-key problem as the large variant. + buffer = memoryview(b"\x00\x03") + with pytest.raises(ValueError, match="must be even"): + _decode_map_small(buffer)