Skip to content

Commit e725eab

Browse files
Lash-LCopilot
andauthored
fix: handle random length bytes before version bytes (#656)
* fix: handle random length bytes before version bytes * fix: filter tests to be warnings only * chore: apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * chore: update roborock/protocol.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * chore: add debug to help us determine if buffer is source of problem * chore: only log if remaining --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent f32c304 commit e725eab

File tree

4 files changed

+126
-9
lines changed

4 files changed

+126
-9
lines changed

roborock/protocol.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,35 @@ class PrefixedStruct(Struct):
341341
def _parse(self, stream, context, path):
342342
subcon1 = Peek(Optional(Bytes(3)))
343343
peek_version = subcon1.parse_stream(stream, **context)
344-
if peek_version not in (b"1.0", b"A01", b"B01", b"L01"):
345-
subcon2 = Bytes(4)
346-
subcon2.parse_stream(stream, **context)
344+
345+
valid_versions = (b"1.0", b"A01", b"B01", b"L01")
346+
if peek_version not in valid_versions:
347+
# Current stream position does not start with a valid version.
348+
# Scan forward to find one.
349+
current_pos = stream_tell(stream, path)
350+
# Read remaining data to find a valid header
351+
data = stream.read()
352+
353+
start_index = -1
354+
# Find the earliest occurrence of any valid version in a single pass
355+
for i in range(len(data) - 2):
356+
if data[i : i + 3] in valid_versions:
357+
start_index = i
358+
break
359+
360+
if start_index != -1:
361+
# Found a valid version header at `start_index`.
362+
# Seek to that position (original_pos + index).
363+
if start_index != 4:
364+
# 4 is the typical/expected amount we prune off,
365+
# therefore, we only want a debug if we have a different length.
366+
_LOGGER.debug("Stripping %d bytes of invalid data from stream", start_index)
367+
stream_seek(stream, current_pos + start_index, 0, path)
368+
else:
369+
_LOGGER.debug("No valid version header found in stream, continuing anyways...")
370+
# Seek back to the original position to avoid parsing at EOF
371+
stream_seek(stream, current_pos, 0, path)
372+
347373
return super()._parse(stream, context, path)
348374

349375
def _build(self, obj, stream, context, path):
@@ -511,6 +537,8 @@ def decode(bytes_data: bytes) -> list[RoborockMessage]:
511537
parsed_messages, remaining = MessageParser.parse(
512538
buffer, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce
513539
)
540+
if remaining:
541+
_LOGGER.debug("Found %d extra bytes: %s", len(remaining), remaining)
514542
buffer = remaining
515543
return parsed_messages
516544

tests/devices/test_local_channel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ async def test_message_decode_error(local_channel: LocalChannel, caplog: pytest.
147147
local_channel._data_received(b"invalid_payload")
148148
await asyncio.sleep(0.01) # yield
149149

150-
assert len(caplog.records) == 1
151-
assert caplog.records[0].levelname == "WARNING"
152-
assert "Failed to decode message" in caplog.records[0].message
150+
warning_records = [record for record in caplog.records if record.levelname == "WARNING"]
151+
assert len(warning_records) == 1
152+
assert "Failed to decode message" in warning_records[0].message
153153

154154

155155
async def test_subscribe_callback(
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from roborock.protocol import create_local_decoder, create_local_encoder
2+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
3+
4+
TEST_LOCAL_KEY = "local_key"
5+
6+
7+
def test_decoder_clean_message():
8+
encoder = create_local_encoder(TEST_LOCAL_KEY)
9+
decoder = create_local_decoder(TEST_LOCAL_KEY)
10+
11+
msg = RoborockMessage(
12+
protocol=RoborockMessageProtocol.RPC_REQUEST,
13+
payload=b"test_payload",
14+
version=b"1.0",
15+
seq=1,
16+
random=123,
17+
)
18+
encoded = encoder(msg)
19+
20+
decoded = decoder(encoded)
21+
assert len(decoded) == 1
22+
assert decoded[0].payload == b"test_payload"
23+
24+
25+
def test_decoder_4byte_padding():
26+
"""Test existing behavior: 4 byte padding should be skipped."""
27+
encoder = create_local_encoder(TEST_LOCAL_KEY)
28+
decoder = create_local_decoder(TEST_LOCAL_KEY)
29+
30+
msg = RoborockMessage(
31+
protocol=RoborockMessageProtocol.RPC_REQUEST,
32+
payload=b"test_payload",
33+
version=b"1.0",
34+
)
35+
encoded = encoder(msg)
36+
37+
# Prepend 4 bytes of garbage
38+
garbage = b"\x00\x00\x05\xa1"
39+
data = garbage + encoded
40+
41+
decoded = decoder(data)
42+
assert len(decoded) == 1
43+
assert decoded[0].payload == b"test_payload"
44+
45+
46+
def test_decoder_variable_padding():
47+
"""Test variable length padding handling."""
48+
encoder = create_local_encoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456)
49+
decoder = create_local_decoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456)
50+
51+
msg = RoborockMessage(
52+
protocol=RoborockMessageProtocol.RPC_REQUEST,
53+
payload=b"test_payload",
54+
version=b"L01",
55+
)
56+
encoded = encoder(msg)
57+
58+
# Prepend 6 bytes of garbage
59+
garbage = b"\x00\x00\x05\xa1\xff\xff"
60+
data = garbage + encoded
61+
62+
decoded = decoder(data)
63+
assert len(decoded) == 1
64+
assert decoded[0].payload == b"test_payload"
65+
66+
67+
def test_decoder_split_padding_variable():
68+
"""Test variable padding split across chunks."""
69+
encoder = create_local_encoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456)
70+
decoder = create_local_decoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456)
71+
72+
msg = RoborockMessage(
73+
protocol=RoborockMessageProtocol.RPC_REQUEST,
74+
payload=b"test_payload",
75+
version=b"L01",
76+
)
77+
encoded = encoder(msg)
78+
79+
garbage = b"\x00\x00\x05\xa1\xff\xff" # 6 bytes
80+
81+
# Send garbage
82+
decoded1 = decoder(garbage)
83+
assert len(decoded1) == 0
84+
85+
# Send message
86+
decoded2 = decoder(encoded)
87+
88+
assert len(decoded2) == 1
89+
assert decoded2[0].payload == b"test_payload"

tests/devices/test_mqtt_channel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,9 @@ async def test_message_decode_error(
158158
mqtt_message_handler(b"invalid_payload")
159159
await asyncio.sleep(0.01) # yield
160160

161-
assert len(caplog.records) == 1
162-
assert caplog.records[0].levelname == "WARNING"
163-
assert "Failed to decode message" in caplog.records[0].message
161+
warning_records = [record for record in caplog.records if record.levelname == "WARNING"]
162+
assert len(warning_records) == 1
163+
assert "Failed to decode message" in warning_records[0].message
164164
unsub()
165165

166166

0 commit comments

Comments
 (0)