diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
deleted file mode 100644
index 03d1098c0..000000000
--- a/kafka/protocol/message.py
+++ /dev/null
@@ -1,214 +0,0 @@
-import io
-import time
-
-from kafka.codec import (has_gzip, has_snappy, has_lz4, has_zstd,
-                         gzip_decode, snappy_decode, zstd_decode,
-                         lz4_decode, lz4_decode_old_kafka)
-from kafka.protocol.frame import KafkaBytes
-from kafka.protocol.struct import Struct
-from kafka.protocol.types import (
-    Int8, Int32, Int64, Bytes, Schema, AbstractType
-)
-from kafka.util import crc32, WeakMethod
-
-
-class Message(Struct):
-    SCHEMAS = [
-        Schema(
-            ('crc', Int32),
-            ('magic', Int8),
-            ('attributes', Int8),
-            ('key', Bytes),
-            ('value', Bytes)),
-        Schema(
-            ('crc', Int32),
-            ('magic', Int8),
-            ('attributes', Int8),
-            ('timestamp', Int64),
-            ('key', Bytes),
-            ('value', Bytes)),
-    ]
-    SCHEMA = SCHEMAS[1]
-    CODEC_MASK = 0x07
-    CODEC_GZIP = 0x01
-    CODEC_SNAPPY = 0x02
-    CODEC_LZ4 = 0x03
-    CODEC_ZSTD = 0x04
-    TIMESTAMP_TYPE_MASK = 0x08
-    HEADER_SIZE = 22  # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2)
-
-    def __init__(self, value, key=None, magic=0, attributes=0, crc=0,
-                 timestamp=None):
-        assert value is None or isinstance(value, bytes), 'value must be bytes'
-        assert key is None or isinstance(key, bytes), 'key must be bytes'
-        assert magic > 0 or timestamp is None, 'timestamp not supported in v0'
-
-        # Default timestamp to now for v1 messages
-        if magic > 0 and timestamp is None:
-            timestamp = int(time.time() * 1000)
-        self.timestamp = timestamp
-        self.crc = crc
-        self._validated_crc = None
-        self.magic = magic
-        self.attributes = attributes
-        self.key = key
-        self.value = value
-        self.encode = WeakMethod(self._encode_self)
-
-    @property
-    def timestamp_type(self):
-        """0 for CreateTime; 1 for LogAppendTime; None if unsupported.
-
-        Value is determined by broker; produced messages should always set to 0
-        Requires Kafka >= 0.10 / message version >= 1
-        """
-        if self.magic == 0:
-            return None
-        elif self.attributes & self.TIMESTAMP_TYPE_MASK:
-            return 1
-        else:
-            return 0
-
-    def _encode_self(self, recalc_crc=True):
-        version = self.magic
-        if version == 1:
-            fields = (self.crc, self.magic, self.attributes, self.timestamp, self.key, self.value)
-        elif version == 0:
-            fields = (self.crc, self.magic, self.attributes, self.key, self.value)
-        else:
-            raise ValueError('Unrecognized message version: %s' % (version,))
-        message = Message.SCHEMAS[version].encode(fields)
-        if not recalc_crc:
-            return message
-        self.crc = crc32(message[4:])
-        crc_field = self.SCHEMAS[version].fields[0]
-        return crc_field.encode(self.crc) + message[4:]
-
-    @classmethod
-    def decode(cls, data):
-        _validated_crc = None
-        if isinstance(data, bytes):
-            _validated_crc = crc32(data[4:])
-            data = io.BytesIO(data)
-        # Partial decode required to determine message version
-        base_fields = cls.SCHEMAS[0].fields[0:3]
-        crc, magic, attributes = [field.decode(data) for field in base_fields]
-        remaining = cls.SCHEMAS[magic].fields[3:]
-        fields = [field.decode(data) for field in remaining]
-        if magic == 1:
-            timestamp = fields[0]
-        else:
-            timestamp = None
-        msg = cls(fields[-1], key=fields[-2],
-                  magic=magic, attributes=attributes, crc=crc,
-                  timestamp=timestamp)
-        msg._validated_crc = _validated_crc
-        return msg
-
-    def validate_crc(self):
-        if self._validated_crc is None:
-            raw_msg = self._encode_self(recalc_crc=False)
-            self._validated_crc = crc32(raw_msg[4:])
-        if self.crc == self._validated_crc:
-            return True
-        return False
-
-    def is_compressed(self):
-        return self.attributes & self.CODEC_MASK != 0
-
-    def decompress(self):
-        codec = self.attributes & self.CODEC_MASK
-        assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY, self.CODEC_LZ4, self.CODEC_ZSTD)
-        if codec == self.CODEC_GZIP:
-            assert has_gzip(), 'Gzip decompression unsupported'
-            raw_bytes = gzip_decode(self.value)
-        elif codec == self.CODEC_SNAPPY:
-            assert has_snappy(), 'Snappy decompression unsupported'
-            raw_bytes = snappy_decode(self.value)
-        elif codec == self.CODEC_LZ4:
-            assert has_lz4(), 'LZ4 decompression unsupported'
-            if self.magic == 0:
-                raw_bytes = lz4_decode_old_kafka(self.value)
-            else:
-                raw_bytes = lz4_decode(self.value)
-        elif codec == self.CODEC_ZSTD:
-            assert has_zstd(), "ZSTD decompression unsupported"
-            raw_bytes = zstd_decode(self.value)
-        else:
-            raise Exception('This should be impossible')
-
-        return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes))
-
-    def __hash__(self):
-        return hash(self._encode_self(recalc_crc=False))
-
-
-class PartialMessage(bytes):
-    def __repr__(self):
-        return 'PartialMessage(%s)' % (self,)
-
-
-class MessageSet(AbstractType):
-    ITEM = Schema(
-        ('offset', Int64),
-        ('message', Bytes)
-    )
-    HEADER_SIZE = 12  # offset + message_size
-
-    @classmethod
-    def encode(cls, items, prepend_size=True):
-        # RecordAccumulator encodes messagesets internally
-        if isinstance(items, (io.BytesIO, KafkaBytes)):
-            size = Int32.decode(items)
-            if prepend_size:
-                # rewind and return all the bytes
-                items.seek(items.tell() - 4)
-                size += 4
-            return items.read(size)
-
-        encoded_values = []
-        for (offset, message) in items:
-            encoded_values.append(Int64.encode(offset))
-            encoded_values.append(Bytes.encode(message))
-        encoded = b''.join(encoded_values)
-        if prepend_size:
-            return Bytes.encode(encoded)
-        else:
-            return encoded
-
-    @classmethod
-    def decode(cls, data, bytes_to_read=None):
-        """Compressed messages should pass in bytes_to_read (via message size)
-        otherwise, we decode from data as Int32
-        """
-        if isinstance(data, bytes):
-            data = io.BytesIO(data)
-        if bytes_to_read is None:
-            bytes_to_read = Int32.decode(data)
-
-        # if FetchRequest max_bytes is smaller than the available message set
-        # the server returns partial data for the final message
-        # So create an internal buffer to avoid over-reading
-        raw = io.BytesIO(data.read(bytes_to_read))
-
-        items = []
-        while bytes_to_read:
-            try:
-                offset = Int64.decode(raw)
-                msg_bytes = Bytes.decode(raw)
-                bytes_to_read -= 8 + 4 + len(msg_bytes)
-                items.append((offset, len(msg_bytes), Message.decode(msg_bytes)))
-            except ValueError:
-                # PartialMessage to signal that max_bytes may be too small
-                items.append((None, None, PartialMessage()))
-                break
-        return items
-
-    @classmethod
-    def repr(cls, messages):
-        if isinstance(messages, (KafkaBytes, io.BytesIO)):
-            offset = messages.tell()
-            decoded = cls.decode(messages)
-            messages.seek(offset)
-            messages = decoded
-        return str([cls.ITEM.repr(m) for m in messages])
diff --git a/kafka/util.py b/kafka/util.py
index 29482bce1..5c7dd927c 100644
--- a/kafka/util.py
+++ b/kafka/util.py
@@ -7,18 +7,6 @@
 from kafka.errors import KafkaTimeoutError
 
 
-MAX_INT = 2 ** 31
-TO_SIGNED = 2 ** 32
-
-def crc32(data):
-    crc = binascii.crc32(data)
-    # CRC is encoded as a signed int in kafka protocol
-    # so we'll convert the unsigned result to signed
-    if crc >= MAX_INT:
-        crc -= TO_SIGNED
-    return crc
-
-
 class Timer:
     __slots__ = ('_start_at', '_expire_at', '_timeout_ms', '_error_message')
 
diff --git a/test/test_protocol.py b/test/test_protocol.py
index 35ca938e1..45755c4c0 100644
--- a/test/test_protocol.py
+++ b/test/test_protocol.py
@@ -7,158 +7,10 @@
 from kafka.protocol.api import RequestHeader
 from kafka.protocol.fetch import FetchRequest, FetchResponse
 from kafka.protocol.find_coordinator import FindCoordinatorRequest
-from kafka.protocol.message import Message, MessageSet, PartialMessage
 from kafka.protocol.metadata import MetadataRequest
 from kafka.protocol.types import Int16, Int32, Int64, String, UnsignedVarInt32, CompactString, CompactArray, CompactBytes, BitField
 
 
-def test_create_message():
-    payload = b'test'
-    key = b'key'
-    msg = Message(payload, key=key)
-    assert msg.magic == 0
-    assert msg.attributes == 0
-    assert msg.key == key
-    assert msg.value == payload
-
-
-def test_encode_message_v0():
-    message = Message(b'test', key=b'key')
-    encoded = message.encode()
-    expect = b''.join([
-        struct.pack('>i', -1427009701), # CRC
-        struct.pack('>bb', 0, 0),       # Magic, flags
-        struct.pack('>i', 3),           # Length of key
-        b'key',                         # key
-        struct.pack('>i', 4),           # Length of value
-        b'test',                        # value
-    ])
-    assert encoded == expect
-
-
-def test_encode_message_v1():
-    message = Message(b'test', key=b'key', magic=1, timestamp=1234)
-    encoded = message.encode()
-    expect = b''.join([
-        struct.pack('>i', 1331087195),  # CRC
-        struct.pack('>bb', 1, 0),       # Magic, flags
-        struct.pack('>q', 1234),        # Timestamp
-        struct.pack('>i', 3),           # Length of key
-        b'key',                         # key
-        struct.pack('>i', 4),           # Length of value
-        b'test',                        # value
-    ])
-    assert encoded == expect
-
-
-def test_decode_message():
-    encoded = b''.join([
-        struct.pack('>i', -1427009701), # CRC
-        struct.pack('>bb', 0, 0),       # Magic, flags
-        struct.pack('>i', 3),           # Length of key
-        b'key',                         # key
-        struct.pack('>i', 4),           # Length of value
-        b'test',                        # value
-    ])
-    decoded_message = Message.decode(encoded)
-    msg = Message(b'test', key=b'key')
-    msg.encode()  # crc is recalculated during encoding
-    assert decoded_message == msg
-
-
-def test_decode_message_validate_crc():
-    encoded = b''.join([
-        struct.pack('>i', -1427009701), # CRC
-        struct.pack('>bb', 0, 0),       # Magic, flags
-        struct.pack('>i', 3),           # Length of key
-        b'key',                         # key
-        struct.pack('>i', 4),           # Length of value
-        b'test',                        # value
-    ])
-    decoded_message = Message.decode(encoded)
-    assert decoded_message.validate_crc() is True
-
-    encoded = b''.join([
-        struct.pack('>i', 1234),        # Incorrect CRC
-        struct.pack('>bb', 0, 0),       # Magic, flags
-        struct.pack('>i', 3),           # Length of key
-        b'key',                         # key
-        struct.pack('>i', 4),           # Length of value
-        b'test',                        # value
-    ])
-    decoded_message = Message.decode(encoded)
-    assert decoded_message.validate_crc() is False
-
-
-def test_encode_message_set():
-    messages = [
-        Message(b'v1', key=b'k1'),
-        Message(b'v2', key=b'k2')
-    ]
-    encoded = MessageSet.encode([(0, msg.encode())
-                                 for msg in messages])
-    expect = b''.join([
-        struct.pack('>q', 0),          # MsgSet Offset
-        struct.pack('>i', 18),         # Msg Size
-        struct.pack('>i', 1474775406), # CRC
-        struct.pack('>bb', 0, 0),      # Magic, flags
-        struct.pack('>i', 2),          # Length of key
-        b'k1',                         # Key
-        struct.pack('>i', 2),          # Length of value
-        b'v1',                         # Value
-
-        struct.pack('>q', 0),          # MsgSet Offset
-        struct.pack('>i', 18),         # Msg Size
-        struct.pack('>i', -16383415),  # CRC
-        struct.pack('>bb', 0, 0),      # Magic, flags
-        struct.pack('>i', 2),          # Length of key
-        b'k2',                         # Key
-        struct.pack('>i', 2),          # Length of value
-        b'v2',                         # Value
-    ])
-    expect = struct.pack('>i', len(expect)) + expect
-    assert encoded == expect
-
-
-def test_decode_message_set():
-    encoded = b''.join([
-        struct.pack('>q', 0),          # MsgSet Offset
-        struct.pack('>i', 18),         # Msg Size
-        struct.pack('>i', 1474775406), # CRC
-        struct.pack('>bb', 0, 0),      # Magic, flags
-        struct.pack('>i', 2),          # Length of key
-        b'k1',                         # Key
-        struct.pack('>i', 2),          # Length of value
-        b'v1',                         # Value
-
-        struct.pack('>q', 1),          # MsgSet Offset
-        struct.pack('>i', 18),         # Msg Size
-        struct.pack('>i', -16383415),  # CRC
-        struct.pack('>bb', 0, 0),      # Magic, flags
-        struct.pack('>i', 2),          # Length of key
-        b'k2',                         # Key
-        struct.pack('>i', 2),          # Length of value
-        b'v2',                         # Value
-    ])
-
-    msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded))
-    assert len(msgs) == 2
-    msg1, msg2 = msgs
-
-    returned_offset1, message1_size, decoded_message1 = msg1
-    returned_offset2, message2_size, decoded_message2 = msg2
-
-    assert returned_offset1 == 0
-    message1 = Message(b'v1', key=b'k1')
-    message1.encode()
-    assert decoded_message1 == message1
-
-    assert returned_offset2 == 1
-    message2 = Message(b'v2', key=b'k2')
-    message2.encode()
-    assert decoded_message2 == message2
-
-
 def test_encode_message_header():
     expect = b''.join([
         struct.pack('>h', 10),             # API Key
@@ -173,44 +25,6 @@
     assert header.encode() == expect
 
 
-def test_decode_message_set_partial():
-    encoded = b''.join([
-        struct.pack('>q', 0),          # Msg Offset
-        struct.pack('>i', 18),         # Msg Size
-        struct.pack('>i', 1474775406), # CRC
-        struct.pack('>bb', 0, 0),      # Magic, flags
-        struct.pack('>i', 2),          # Length of key
-        b'k1',                         # Key
-        struct.pack('>i', 2),          # Length of value
-        b'v1',                         # Value
-
-        struct.pack('>q', 1),          # Msg Offset
-        struct.pack('>i', 24),         # Msg Size (larger than remaining MsgSet size)
-        struct.pack('>i', -16383415),  # CRC
-        struct.pack('>bb', 0, 0),      # Magic, flags
-        struct.pack('>i', 2),          # Length of key
-        b'k2',                         # Key
-        struct.pack('>i', 8),          # Length of value
-        b'ar',                         # Value (truncated)
-    ])
-
-    msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded))
-    assert len(msgs) == 2
-    msg1, msg2 = msgs
-
-    returned_offset1, message1_size, decoded_message1 = msg1
-    returned_offset2, message2_size, decoded_message2 = msg2
-
-    assert returned_offset1 == 0
-    message1 = Message(b'v1', key=b'k1')
-    message1.encode()
-    assert decoded_message1 == message1
-
-    assert returned_offset2 is None
-    assert message2_size is None
-    assert decoded_message2 == PartialMessage()
-
-
 def test_decode_fetch_response_partial():
     encoded = b''.join([
         Int32.encode(1),                  # Num Topics (Array)
@@ -265,10 +79,10 @@
     assert topic == 'foobar'
     assert len(partitions) == 2
 
-    m1 = MessageSet.decode(
-        partitions[0][3], bytes_to_read=len(partitions[0][3]))
-    assert len(m1) == 2
-    assert m1[1] == (None, None, PartialMessage())
+    #m1 = MessageSet.decode(
+    #    partitions[0][3], bytes_to_read=len(partitions[0][3]))
+    #assert len(m1) == 2
+    #assert m1[1] == (None, None, PartialMessage())
 
 
 def test_struct_unrecognized_kwargs():
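
Note for context (not part of the patch): the removed kafka.util.crc32 helper wrapped the CRC-32 into the signed 32-bit range because the legacy message format encodes the CRC field as an Int32. The sketch below is an illustrative, standalone reconstruction of that behavior and of the v0 message framing exercised by the deleted test_encode_message_v0; the expected CRC value is taken from that test, and the helper name signed_crc32 is introduced here for illustration only.

    import binascii
    import struct

    def signed_crc32(data):
        # CRC-32 folded into the signed 32-bit range, mirroring the removed kafka.util.crc32
        crc = binascii.crc32(data)
        if crc >= 2 ** 31:
            crc -= 2 ** 32
        return crc

    # v0 message body: magic, attributes, key, value -- everything after the 4-byte CRC field
    body = b''.join([
        struct.pack('>bb', 0, 0),       # magic=0, attributes=0
        struct.pack('>i', 3), b'key',   # key length + key
        struct.pack('>i', 4), b'test',  # value length + value
    ])
    crc = signed_crc32(body)
    encoded = struct.pack('>i', crc) + body   # full v0 message, CRC first
    assert crc == -1427009701                 # value pinned by the deleted test_encode_message_v0

The signed fold is needed because binascii.crc32 returns an unsigned value, while packing the field with struct.pack('>i', ...) (the protocol's Int32) requires a value in the signed range.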