diff --git a/kafka/protocol/types.py b/kafka/protocol/types.py index 7889e06d5..b0811c59b 100644 --- a/kafka/protocol/types.py +++ b/kafka/protocol/types.py @@ -226,6 +226,17 @@ def repr(self, list_of_items): class UnsignedVarInt32(AbstractType): + @classmethod + def decode(cls, data): + value = VarInt32.decode(data) + return (value << 1) ^ (value >> 31) + + @classmethod + def encode(cls, value): + return VarInt32.encode((value >> 1) ^ -(value & 1)) + + +class VarInt32(AbstractType): @classmethod def decode(cls, data): value, i = 0, 0 @@ -238,10 +249,12 @@ def decode(cls, data): if i > 28: raise ValueError('Invalid value {}'.format(value)) value |= b << i - return value + return (value >> 1) ^ -(value & 1) @classmethod def encode(cls, value): + # bring it in line with the java binary repr + value = (value << 1) ^ (value >> 31) value &= 0xffffffff ret = b'' while (value & 0xffffff80) != 0: @@ -252,25 +265,12 @@ def encode(cls, value): return ret -class VarInt32(AbstractType): - @classmethod - def decode(cls, data): - value = UnsignedVarInt32.decode(data) - return (value >> 1) ^ -(value & 1) - - @classmethod - def encode(cls, value): - # bring it in line with the java binary repr - value &= 0xffffffff - return UnsignedVarInt32.encode((value << 1) ^ (value >> 31)) - - class VarInt64(AbstractType): @classmethod def decode(cls, data): value, i = 0, 0 while True: - b = data.read(1) + b, = struct.unpack('B', data.read(1)) if not (b & 0x80): break value |= (b & 0x7f) << i @@ -283,14 +283,14 @@ def decode(cls, data): @classmethod def encode(cls, value): # bring it in line with the java binary repr + value = (value << 1) ^ (value >> 63) value &= 0xffffffffffffffff - v = (value << 1) ^ (value >> 63) ret = b'' - while (v & 0xffffffffffffff80) != 0: + while (value & 0xffffffffffffff80) != 0: b = (value & 0x7f) | 0x80 ret += struct.pack('B', b) - v >>= 7 - ret += struct.pack('B', v) + value >>= 7 + ret += struct.pack('B', value) return ret diff --git a/test/protocol/test_api.py b/test/protocol/test_api.py new file mode 100644 index 000000000..4bb7273bd --- /dev/null +++ b/test/protocol/test_api.py @@ -0,0 +1,35 @@ +import struct + +import pytest + +from kafka.protocol.api import RequestHeader +from kafka.protocol.fetch import FetchRequest +from kafka.protocol.find_coordinator import FindCoordinatorRequest +from kafka.protocol.metadata import MetadataRequest + + +def test_encode_message_header(): + expect = b''.join([ + struct.pack('>h', 10), # API Key + struct.pack('>h', 0), # API Version + struct.pack('>i', 4), # Correlation Id + struct.pack('>h', len('client3')), # Length of clientId + b'client3', # ClientId + ]) + + req = FindCoordinatorRequest[0]('foo') + header = RequestHeader(req, correlation_id=4, client_id='client3') + assert header.encode() == expect + + +def test_struct_unrecognized_kwargs(): + try: + _mr = MetadataRequest[0](topicz='foo') + assert False, 'Structs should not allow unrecognized kwargs' + except ValueError: + pass + + +def test_struct_missing_kwargs(): + fr = FetchRequest[0](max_wait_time=100) + assert fr.min_bytes is None diff --git a/test/protocol/test_bit_field.py b/test/protocol/test_bit_field.py new file mode 100644 index 000000000..5db155241 --- /dev/null +++ b/test/protocol/test_bit_field.py @@ -0,0 +1,13 @@ +import io + +import pytest + +from kafka.protocol.types import BitField + + +@pytest.mark.parametrize(('test_set',), [ + (set([0, 1, 5, 10, 31]),), + (set(range(32)),), +]) +def test_bit_field(test_set): + assert BitField.decode(io.BytesIO(BitField.encode(test_set))) == test_set diff --git a/test/protocol/test_compact.py b/test/protocol/test_compact.py new file mode 100644 index 000000000..c5940aa70 --- /dev/null +++ b/test/protocol/test_compact.py @@ -0,0 +1,38 @@ +import io +import struct + +import pytest + +from kafka.protocol.types import CompactString, CompactArray, CompactBytes + + +def test_compact_data_structs(): + cs = CompactString() + encoded = cs.encode(None) + assert encoded == struct.pack('B', 0) + decoded = cs.decode(io.BytesIO(encoded)) + assert decoded is None + assert b'\x01' == cs.encode('') + assert '' == cs.decode(io.BytesIO(b'\x01')) + encoded = cs.encode("foobarbaz") + assert cs.decode(io.BytesIO(encoded)) == "foobarbaz" + + arr = CompactArray(CompactString()) + assert arr.encode(None) == b'\x00' + assert arr.decode(io.BytesIO(b'\x00')) is None + enc = arr.encode([]) + assert enc == b'\x01' + assert [] == arr.decode(io.BytesIO(enc)) + encoded = arr.encode(["foo", "bar", "baz", "quux"]) + assert arr.decode(io.BytesIO(encoded)) == ["foo", "bar", "baz", "quux"] + + enc = CompactBytes.encode(None) + assert enc == b'\x00' + assert CompactBytes.decode(io.BytesIO(b'\x00')) is None + enc = CompactBytes.encode(b'') + assert enc == b'\x01' + assert CompactBytes.decode(io.BytesIO(b'\x01')) == b'' + enc = CompactBytes.encode(b'foo') + assert CompactBytes.decode(io.BytesIO(enc)) == b'foo' + + diff --git a/test/protocol/test_fetch.py b/test/protocol/test_fetch.py new file mode 100644 index 000000000..993df9c89 --- /dev/null +++ b/test/protocol/test_fetch.py @@ -0,0 +1,68 @@ +#pylint: skip-file +import io +import struct + +import pytest + +from kafka.protocol.fetch import FetchResponse +from kafka.protocol.types import Int16, Int32, Int64, String + + +def test_decode_fetch_response_partial(): + encoded = b''.join([ + Int32.encode(1), # Num Topics (Array) + String('utf-8').encode('foobar'), + Int32.encode(2), # Num Partitions (Array) + Int32.encode(0), # Partition id + Int16.encode(0), # Error Code + Int64.encode(1234), # Highwater offset + Int32.encode(52), # MessageSet size + Int64.encode(0), # Msg Offset + Int32.encode(18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value + b'v1', # Value + + Int64.encode(1), # Msg Offset + struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 8), # Length of value + b'ar', # Value (truncated) + Int32.encode(1), + Int16.encode(0), + Int64.encode(2345), + Int32.encode(52), # MessageSet size + Int64.encode(0), # Msg Offset + Int32.encode(18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value + b'v1', # Value + + Int64.encode(1), # Msg Offset + struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 8), # Length of value + b'ar', # Value (truncated) + ]) + resp = FetchResponse[0].decode(io.BytesIO(encoded)) + assert len(resp.topics) == 1 + topic, partitions = resp.topics[0] + assert topic == 'foobar' + assert len(partitions) == 2 + + #m1 = MessageSet.decode( + # partitions[0][3], bytes_to_read=len(partitions[0][3])) + #assert len(m1) == 2 + #assert m1[1] == (None, None, PartialMessage()) diff --git a/test/protocol/test_varint.py b/test/protocol/test_varint.py new file mode 100644 index 000000000..826ad949c --- /dev/null +++ b/test/protocol/test_varint.py @@ -0,0 +1,109 @@ +import io +import struct + +import pytest + +from kafka.protocol.types import UnsignedVarInt32, VarInt32, VarInt64 + + +@pytest.mark.parametrize(('value','expected_encoded'), [ + (0, [0x00]), + (-1, [0xFF, 0xFF, 0xFF, 0xFF, 0x0F]), + (1, [0x01]), + (63, [0x3F]), + (-64, [0xC0, 0xFF, 0xFF, 0xFF, 0x0F]), + (64, [0x40]), + (8191, [0xFF, 0x3F]), + (-8192, [0x80, 0xC0, 0xFF, 0xFF, 0x0F]), + (8192, [0x80, 0x40]), + (-8193, [0xFF, 0xBF, 0xFF, 0xFF, 0x0F]), + (1048575, [0xFF, 0xFF, 0x3F]), + (1048576, [0x80, 0x80, 0x40]), + (2147483647, [0xFF, 0xFF, 0xFF, 0xFF, 0x07]), + (-2147483648, [0x80, 0x80, 0x80, 0x80, 0x08]), +]) +def test_unsigned_varint_serde(value, expected_encoded): + value &= 0xffffffff + encoded = UnsignedVarInt32.encode(value) + assert encoded == b''.join(struct.pack('>B', x) for x in expected_encoded) + assert value == UnsignedVarInt32.decode(io.BytesIO(encoded)) + + +@pytest.mark.parametrize(('value','expected_encoded'), [ + (0, [0x00]), + (-1, [0x01]), + (1, [0x02]), + (63, [0x7E]), + (-64, [0x7F]), + (64, [0x80, 0x01]), + (-65, [0x81, 0x01]), + (8191, [0xFE, 0x7F]), + (-8192, [0xFF, 0x7F]), + (8192, [0x80, 0x80, 0x01]), + (-8193, [0x81, 0x80, 0x01]), + (1048575, [0xFE, 0xFF, 0x7F]), + (-1048576, [0xFF, 0xFF, 0x7F]), + (1048576, [0x80, 0x80, 0x80, 0x01]), + (-1048577, [0x81, 0x80, 0x80, 0x01]), + (134217727, [0xFE, 0xFF, 0xFF, 0x7F]), + (-134217728, [0xFF, 0xFF, 0xFF, 0x7F]), + (134217728, [0x80, 0x80, 0x80, 0x80, 0x01]), + (-134217729, [0x81, 0x80, 0x80, 0x80, 0x01]), + (2147483647, [0xFE, 0xFF, 0xFF, 0xFF, 0x0F]), + (-2147483648, [0xFF, 0xFF, 0xFF, 0xFF, 0x0F]), +]) +def test_signed_varint_serde(value, expected_encoded): + encoded = VarInt32.encode(value) + assert encoded == b''.join(struct.pack('>B', x) for x in expected_encoded) + assert value == VarInt32.decode(io.BytesIO(encoded)) + + +@pytest.mark.parametrize(('value','expected_encoded'), [ + (0, [0x00]), + (-1, [0x01]), + (1, [0x02]), + (63, [0x7E]), + (-64, [0x7F]), + (64, [0x80, 0x01]), + (-65, [0x81, 0x01]), + (8191, [0xFE, 0x7F]), + (-8192, [0xFF, 0x7F]), + (8192, [0x80, 0x80, 0x01]), + (-8193, [0x81, 0x80, 0x01]), + (1048575, [0xFE, 0xFF, 0x7F]), + (-1048576, [0xFF, 0xFF, 0x7F]), + (1048576, [0x80, 0x80, 0x80, 0x01]), + (-1048577, [0x81, 0x80, 0x80, 0x01]), + (134217727, [0xFE, 0xFF, 0xFF, 0x7F]), + (-134217728, [0xFF, 0xFF, 0xFF, 0x7F]), + (134217728, [0x80, 0x80, 0x80, 0x80, 0x01]), + (-134217729, [0x81, 0x80, 0x80, 0x80, 0x01]), + (2147483647, [0xFE, 0xFF, 0xFF, 0xFF, 0x0F]), + (-2147483648, [0xFF, 0xFF, 0xFF, 0xFF, 0x0F]), + (17179869183, [0xFE, 0xFF, 0xFF, 0xFF, 0x7F]), + (-17179869184, [0xFF, 0xFF, 0xFF, 0xFF, 0x7F]), + (17179869184, [0x80, 0x80, 0x80, 0x80, 0x80, 0x01]), + (-17179869185, [0x81, 0x80, 0x80, 0x80, 0x80, 0x01]), + (2199023255551, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]), + (-2199023255552, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]), + (2199023255552, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]), + (-2199023255553, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]), + (281474976710655, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]), + (-281474976710656, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]), + (281474976710656, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]), + (-281474976710657, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 1]), + (36028797018963967, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]), + (-36028797018963968, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]), + (36028797018963968, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]), + (-36028797018963969, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]), + (4611686018427387903, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]), + (-4611686018427387904, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]), + (4611686018427387904, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]), + (-4611686018427387905, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]), + (9223372036854775807, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]), + (-9223372036854775808, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]), +]) +def test_signed_varlong_serde(value, expected_encoded): + encoded = VarInt64.encode(value) + assert encoded == b''.join(struct.pack('>B', x) for x in expected_encoded) + assert value == VarInt64.decode(io.BytesIO(encoded)) diff --git a/test/test_protocol.py b/test/test_protocol.py deleted file mode 100644 index 45755c4c0..000000000 --- a/test/test_protocol.py +++ /dev/null @@ -1,158 +0,0 @@ -#pylint: skip-file -import io -import struct - -import pytest - -from kafka.protocol.api import RequestHeader -from kafka.protocol.fetch import FetchRequest, FetchResponse -from kafka.protocol.find_coordinator import FindCoordinatorRequest -from kafka.protocol.metadata import MetadataRequest -from kafka.protocol.types import Int16, Int32, Int64, String, UnsignedVarInt32, CompactString, CompactArray, CompactBytes, BitField - - -def test_encode_message_header(): - expect = b''.join([ - struct.pack('>h', 10), # API Key - struct.pack('>h', 0), # API Version - struct.pack('>i', 4), # Correlation Id - struct.pack('>h', len('client3')), # Length of clientId - b'client3', # ClientId - ]) - - req = FindCoordinatorRequest[0]('foo') - header = RequestHeader(req, correlation_id=4, client_id='client3') - assert header.encode() == expect - - -def test_decode_fetch_response_partial(): - encoded = b''.join([ - Int32.encode(1), # Num Topics (Array) - String('utf-8').encode('foobar'), - Int32.encode(2), # Num Partitions (Array) - Int32.encode(0), # Partition id - Int16.encode(0), # Error Code - Int64.encode(1234), # Highwater offset - Int32.encode(52), # MessageSet size - Int64.encode(0), # Msg Offset - Int32.encode(18), # Msg Size - struct.pack('>i', 1474775406), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k1', # Key - struct.pack('>i', 2), # Length of value - b'v1', # Value - - Int64.encode(1), # Msg Offset - struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) - struct.pack('>i', -16383415), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k2', # Key - struct.pack('>i', 8), # Length of value - b'ar', # Value (truncated) - Int32.encode(1), - Int16.encode(0), - Int64.encode(2345), - Int32.encode(52), # MessageSet size - Int64.encode(0), # Msg Offset - Int32.encode(18), # Msg Size - struct.pack('>i', 1474775406), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k1', # Key - struct.pack('>i', 2), # Length of value - b'v1', # Value - - Int64.encode(1), # Msg Offset - struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) - struct.pack('>i', -16383415), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k2', # Key - struct.pack('>i', 8), # Length of value - b'ar', # Value (truncated) - ]) - resp = FetchResponse[0].decode(io.BytesIO(encoded)) - assert len(resp.topics) == 1 - topic, partitions = resp.topics[0] - assert topic == 'foobar' - assert len(partitions) == 2 - - #m1 = MessageSet.decode( - # partitions[0][3], bytes_to_read=len(partitions[0][3])) - #assert len(m1) == 2 - #assert m1[1] == (None, None, PartialMessage()) - - -def test_struct_unrecognized_kwargs(): - try: - _mr = MetadataRequest[0](topicz='foo') - assert False, 'Structs should not allow unrecognized kwargs' - except ValueError: - pass - - -def test_struct_missing_kwargs(): - fr = FetchRequest[0](max_wait_time=100) - assert fr.min_bytes is None - - -def test_unsigned_varint_serde(): - pairs = { - 0: [0], - -1: [0xff, 0xff, 0xff, 0xff, 0x0f], - 1: [1], - 63: [0x3f], - -64: [0xc0, 0xff, 0xff, 0xff, 0x0f], - 64: [0x40], - 8191: [0xff, 0x3f], - -8192: [0x80, 0xc0, 0xff, 0xff, 0x0f], - 8192: [0x80, 0x40], - -8193: [0xff, 0xbf, 0xff, 0xff, 0x0f], - 1048575: [0xff, 0xff, 0x3f], - - } - for value, expected_encoded in pairs.items(): - value &= 0xffffffff - encoded = UnsignedVarInt32.encode(value) - assert encoded == b''.join(struct.pack('>B', x) for x in expected_encoded) - assert value == UnsignedVarInt32.decode(io.BytesIO(encoded)) - - -def test_compact_data_structs(): - cs = CompactString() - encoded = cs.encode(None) - assert encoded == struct.pack('B', 0) - decoded = cs.decode(io.BytesIO(encoded)) - assert decoded is None - assert b'\x01' == cs.encode('') - assert '' == cs.decode(io.BytesIO(b'\x01')) - encoded = cs.encode("foobarbaz") - assert cs.decode(io.BytesIO(encoded)) == "foobarbaz" - - arr = CompactArray(CompactString()) - assert arr.encode(None) == b'\x00' - assert arr.decode(io.BytesIO(b'\x00')) is None - enc = arr.encode([]) - assert enc == b'\x01' - assert [] == arr.decode(io.BytesIO(enc)) - encoded = arr.encode(["foo", "bar", "baz", "quux"]) - assert arr.decode(io.BytesIO(encoded)) == ["foo", "bar", "baz", "quux"] - - enc = CompactBytes.encode(None) - assert enc == b'\x00' - assert CompactBytes.decode(io.BytesIO(b'\x00')) is None - enc = CompactBytes.encode(b'') - assert enc == b'\x01' - assert CompactBytes.decode(io.BytesIO(b'\x01')) == b'' - enc = CompactBytes.encode(b'foo') - assert CompactBytes.decode(io.BytesIO(enc)) == b'foo' - - -@pytest.mark.parametrize(('test_set',), [ - (set([0, 1, 5, 10, 31]),), - (set(range(32)),), -]) -def test_bit_field(test_set): - assert BitField.decode(io.BytesIO(BitField.encode(test_set))) == test_set