Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions kafka/protocol/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,17 @@ def repr(self, list_of_items):


class UnsignedVarInt32(AbstractType):
@classmethod
def decode(cls, data):
value = VarInt32.decode(data)
return (value << 1) ^ (value >> 31)

@classmethod
def encode(cls, value):
return VarInt32.encode((value >> 1) ^ -(value & 1))


class VarInt32(AbstractType):
@classmethod
def decode(cls, data):
value, i = 0, 0
Expand All @@ -238,10 +249,12 @@ def decode(cls, data):
if i > 28:
raise ValueError('Invalid value {}'.format(value))
value |= b << i
return value
return (value >> 1) ^ -(value & 1)

@classmethod
def encode(cls, value):
# bring it in line with the java binary repr
value = (value << 1) ^ (value >> 31)
value &= 0xffffffff
ret = b''
while (value & 0xffffff80) != 0:
Expand All @@ -252,25 +265,12 @@ def encode(cls, value):
return ret


class VarInt32(AbstractType):
@classmethod
def decode(cls, data):
value = UnsignedVarInt32.decode(data)
return (value >> 1) ^ -(value & 1)

@classmethod
def encode(cls, value):
# bring it in line with the java binary repr
value &= 0xffffffff
return UnsignedVarInt32.encode((value << 1) ^ (value >> 31))


class VarInt64(AbstractType):
@classmethod
def decode(cls, data):
value, i = 0, 0
while True:
b = data.read(1)
b, = struct.unpack('B', data.read(1))
if not (b & 0x80):
break
value |= (b & 0x7f) << i
Expand All @@ -283,14 +283,14 @@ def decode(cls, data):
@classmethod
def encode(cls, value):
# bring it in line with the java binary repr
value = (value << 1) ^ (value >> 63)
value &= 0xffffffffffffffff
v = (value << 1) ^ (value >> 63)
ret = b''
while (v & 0xffffffffffffff80) != 0:
while (value & 0xffffffffffffff80) != 0:
b = (value & 0x7f) | 0x80
ret += struct.pack('B', b)
v >>= 7
ret += struct.pack('B', v)
value >>= 7
ret += struct.pack('B', value)
return ret


Expand Down
35 changes: 35 additions & 0 deletions test/protocol/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import struct

import pytest

from kafka.protocol.api import RequestHeader
from kafka.protocol.fetch import FetchRequest
from kafka.protocol.find_coordinator import FindCoordinatorRequest
from kafka.protocol.metadata import MetadataRequest


def test_encode_message_header():
    """Check the bytes RequestHeader.encode() produces for FindCoordinatorRequest[0].

    Expected layout (big-endian): int16 API key, int16 API version,
    int32 correlation id, int16 client-id length, then the client-id bytes.
    """
    client_id = 'client3'
    # A single '>hhih' pack yields the same bytes as packing each field on its
    # own: the '>' prefix selects standard sizes with no alignment padding.
    fixed_fields = struct.pack(
        '>hhih',
        10,              # API Key
        0,               # API Version
        4,               # Correlation Id
        len(client_id),  # Length of clientId
    )
    expected = fixed_fields + b'client3'  # ClientId

    request = FindCoordinatorRequest[0]('foo')
    header = RequestHeader(request, correlation_id=4, client_id=client_id)
    assert header.encode() == expected


def test_struct_unrecognized_kwargs():
    """Constructing a struct with an unknown keyword argument must raise ValueError."""
    # pytest.raises fails the test by itself when nothing is raised, so the
    # check no longer depends on a bare `assert False`, which `python -O`
    # strips (turning the original try/except into a test that always passes).
    with pytest.raises(ValueError):
        MetadataRequest[0](topicz='foo')


def test_struct_missing_kwargs():
    """A field that is not passed to the struct constructor is left as None."""
    request = FetchRequest[0](max_wait_time=100)
    omitted_field = request.min_bytes
    assert omitted_field is None
13 changes: 13 additions & 0 deletions test/protocol/test_bit_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import io

import pytest

from kafka.protocol.types import BitField


@pytest.mark.parametrize(('test_set',), [
    ({0, 1, 5, 10, 31},),
    (set(range(32)),),
])
def test_bit_field(test_set):
    """Encoding a set with BitField and decoding the result gives back the same set."""
    wire = BitField.encode(test_set)
    round_tripped = BitField.decode(io.BytesIO(wire))
    assert round_tripped == test_set
38 changes: 38 additions & 0 deletions test/protocol/test_compact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import io
import struct

import pytest

from kafka.protocol.types import CompactString, CompactArray, CompactBytes


def test_compact_data_structs():
    """Exercise encode/decode of CompactString, CompactArray and CompactBytes.

    For each type the test pins the encoding of None (a single 0x00 byte) and
    of the empty value (a single 0x01 byte), then round-trips a non-empty value.
    """
    null_byte = b'\x00'
    empty_byte = b'\x01'

    # --- CompactString ---
    compact_str = CompactString()
    wire = compact_str.encode(None)
    assert wire == null_byte
    assert compact_str.decode(io.BytesIO(wire)) is None
    assert compact_str.encode('') == empty_byte
    assert compact_str.decode(io.BytesIO(empty_byte)) == ''
    wire = compact_str.encode("foobarbaz")
    assert compact_str.decode(io.BytesIO(wire)) == "foobarbaz"

    # --- CompactArray of CompactString ---
    compact_arr = CompactArray(CompactString())
    assert compact_arr.encode(None) == null_byte
    assert compact_arr.decode(io.BytesIO(null_byte)) is None
    wire = compact_arr.encode([])
    assert wire == empty_byte
    assert compact_arr.decode(io.BytesIO(wire)) == []
    words = ["foo", "bar", "baz", "quux"]
    wire = compact_arr.encode(["foo", "bar", "baz", "quux"])
    assert compact_arr.decode(io.BytesIO(wire)) == words

    # --- CompactBytes (called on the class, not an instance) ---
    wire = CompactBytes.encode(None)
    assert wire == null_byte
    assert CompactBytes.decode(io.BytesIO(null_byte)) is None
    wire = CompactBytes.encode(b'')
    assert wire == empty_byte
    assert CompactBytes.decode(io.BytesIO(empty_byte)) == b''
    wire = CompactBytes.encode(b'foo')
    assert CompactBytes.decode(io.BytesIO(wire)) == b'foo'


68 changes: 68 additions & 0 deletions test/protocol/test_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pylint: skip-file
import io
import struct

import pytest

from kafka.protocol.fetch import FetchResponse
from kafka.protocol.types import Int16, Int32, Int64, String


def test_decode_fetch_response_partial():
    """Decode a FetchResponse[0] whose message sets end in a truncated message.

    The payload describes one topic ('foobar') with two partitions. Both
    partitions carry the same message-set bytes, so those are built once.
    """
    # NOTE(review): assuming IntNN.encode emits NN/8 bytes, the fields below
    # add up to 60 bytes after the size prefix, not the declared 52 -- confirm
    # the fixture is intentional.
    message_set = b''.join([
        Int32.encode(52),               # MessageSet size
        # first message
        Int64.encode(0),                # Msg Offset
        Int32.encode(18),               # Msg Size
        struct.pack('>i', 1474775406),  # CRC
        struct.pack('>bb', 0, 0),       # Magic, flags
        struct.pack('>i', 2),           # Length of key
        b'k1',                          # Key
        struct.pack('>i', 2),           # Length of value
        b'v1',                          # Value
        # second message
        Int64.encode(1),                # Msg Offset
        struct.pack('>i', 24),          # Msg Size (larger than remaining MsgSet size)
        struct.pack('>i', -16383415),   # CRC
        struct.pack('>bb', 0, 0),       # Magic, flags
        struct.pack('>i', 2),           # Length of key
        b'k2',                          # Key
        struct.pack('>i', 8),           # Length of value
        b'ar',                          # Value (truncated)
    ])

    encoded = b''.join([
        Int32.encode(1),                    # Num Topics (Array)
        String('utf-8').encode('foobar'),
        Int32.encode(2),                    # Num Partitions (Array)
        # first partition
        Int32.encode(0),                    # Partition id
        Int16.encode(0),                    # Error Code
        Int64.encode(1234),                 # Highwater offset
        message_set,
        # second partition
        Int32.encode(1),                    # Partition id
        Int16.encode(0),                    # Error Code
        Int64.encode(2345),                 # Highwater offset
        message_set,
    ])

    response = FetchResponse[0].decode(io.BytesIO(encoded))
    assert len(response.topics) == 1
    topic_name, topic_partitions = response.topics[0]
    assert topic_name == 'foobar'
    assert len(topic_partitions) == 2

    # Disabled checks on the decoded message set, kept from the original:
    #m1 = MessageSet.decode(
    #    partitions[0][3], bytes_to_read=len(partitions[0][3]))
    #assert len(m1) == 2
    #assert m1[1] == (None, None, PartialMessage())
109 changes: 109 additions & 0 deletions test/protocol/test_varint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import io
import struct

import pytest

from kafka.protocol.types import UnsignedVarInt32, VarInt32, VarInt64


@pytest.mark.parametrize(('value', 'expected_encoded'), [
    (0, [0x00]),
    (-1, [0xFF, 0xFF, 0xFF, 0xFF, 0x0F]),
    (1, [0x01]),
    (63, [0x3F]),
    (-64, [0xC0, 0xFF, 0xFF, 0xFF, 0x0F]),
    (64, [0x40]),
    (8191, [0xFF, 0x3F]),
    (-8192, [0x80, 0xC0, 0xFF, 0xFF, 0x0F]),
    (8192, [0x80, 0x40]),
    (-8193, [0xFF, 0xBF, 0xFF, 0xFF, 0x0F]),
    (1048575, [0xFF, 0xFF, 0x3F]),
    (1048576, [0x80, 0x80, 0x40]),
    (2147483647, [0xFF, 0xFF, 0xFF, 0xFF, 0x07]),
    (-2147483648, [0x80, 0x80, 0x80, 0x80, 0x08]),
])
def test_unsigned_varint_serde(value, expected_encoded):
    """UnsignedVarInt32 encodes to the expected bytes and decodes back.

    Each input is masked to 32 bits first, so negative parameters are
    exercised as their unsigned 32-bit equivalents.
    """
    unsigned = value & 0xffffffff
    wire = UnsignedVarInt32.encode(unsigned)
    # expected_encoded is a list of byte values (0-255).
    assert wire == bytes(bytearray(expected_encoded))
    assert unsigned == UnsignedVarInt32.decode(io.BytesIO(wire))


@pytest.mark.parametrize(('value', 'expected_encoded'), [
    (0, [0x00]),
    (-1, [0x01]),
    (1, [0x02]),
    (63, [0x7E]),
    (-64, [0x7F]),
    (64, [0x80, 0x01]),
    (-65, [0x81, 0x01]),
    (8191, [0xFE, 0x7F]),
    (-8192, [0xFF, 0x7F]),
    (8192, [0x80, 0x80, 0x01]),
    (-8193, [0x81, 0x80, 0x01]),
    (1048575, [0xFE, 0xFF, 0x7F]),
    (-1048576, [0xFF, 0xFF, 0x7F]),
    (1048576, [0x80, 0x80, 0x80, 0x01]),
    (-1048577, [0x81, 0x80, 0x80, 0x01]),
    (134217727, [0xFE, 0xFF, 0xFF, 0x7F]),
    (-134217728, [0xFF, 0xFF, 0xFF, 0x7F]),
    (134217728, [0x80, 0x80, 0x80, 0x80, 0x01]),
    (-134217729, [0x81, 0x80, 0x80, 0x80, 0x01]),
    (2147483647, [0xFE, 0xFF, 0xFF, 0xFF, 0x0F]),
    (-2147483648, [0xFF, 0xFF, 0xFF, 0xFF, 0x0F]),
])
def test_signed_varint_serde(value, expected_encoded):
    """VarInt32 encodes each value to the expected bytes and decodes back to it."""
    wire = VarInt32.encode(value)
    # expected_encoded is a list of byte values (0-255).
    assert wire == bytes(bytearray(expected_encoded))
    assert value == VarInt32.decode(io.BytesIO(wire))


@pytest.mark.parametrize(('value', 'expected_encoded'), [
    (0, [0x00]),
    (-1, [0x01]),
    (1, [0x02]),
    (63, [0x7E]),
    (-64, [0x7F]),
    (64, [0x80, 0x01]),
    (-65, [0x81, 0x01]),
    (8191, [0xFE, 0x7F]),
    (-8192, [0xFF, 0x7F]),
    (8192, [0x80, 0x80, 0x01]),
    (-8193, [0x81, 0x80, 0x01]),
    (1048575, [0xFE, 0xFF, 0x7F]),
    (-1048576, [0xFF, 0xFF, 0x7F]),
    (1048576, [0x80, 0x80, 0x80, 0x01]),
    (-1048577, [0x81, 0x80, 0x80, 0x01]),
    (134217727, [0xFE, 0xFF, 0xFF, 0x7F]),
    (-134217728, [0xFF, 0xFF, 0xFF, 0x7F]),
    (134217728, [0x80, 0x80, 0x80, 0x80, 0x01]),
    (-134217729, [0x81, 0x80, 0x80, 0x80, 0x01]),
    (2147483647, [0xFE, 0xFF, 0xFF, 0xFF, 0x0F]),
    (-2147483648, [0xFF, 0xFF, 0xFF, 0xFF, 0x0F]),
    (17179869183, [0xFE, 0xFF, 0xFF, 0xFF, 0x7F]),
    (-17179869184, [0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
    (17179869184, [0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (-17179869185, [0x81, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (2199023255551, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
    (-2199023255552, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
    (2199023255552, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (-2199023255553, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (281474976710655, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
    (-281474976710656, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
    (281474976710656, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (-281474976710657, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (36028797018963967, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
    (-36028797018963968, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
    (36028797018963968, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (-36028797018963969, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (4611686018427387903, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
    (-4611686018427387904, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
    (4611686018427387904, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (-4611686018427387905, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
    (9223372036854775807, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]),
    (-9223372036854775808, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]),
])
def test_signed_varlong_serde(value, expected_encoded):
    """VarInt64 encodes each value to the expected bytes and decodes back to it."""
    wire = VarInt64.encode(value)
    # expected_encoded is a list of byte values (0-255).
    assert wire == bytes(bytearray(expected_encoded))
    assert value == VarInt64.decode(io.BytesIO(wire))
Loading
Loading