Skip to content

Commit e1caa39

Browse files
committed
Added protocol checks for loaded interface definitions.
1 parent f6e6f3f commit e1caa39

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

simple_rpc/simple_rpc.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from serial import serial_for_url
77
from serial.serialutil import SerialException
8-
from yaml import dump, load
8+
from yaml import FullLoader, dump, load
99

1010
from .extras import make_function
1111
from .io import read, read_byte_string, until, write
@@ -18,6 +18,19 @@
1818
_list_req = 0xff
1919

2020

21+
def _assert_protocol(protocol: str) -> None:
22+
if protocol != _protocol:
23+
raise ValueError('invalid protocol header')
24+
25+
26+
def _assert_version(version: tuple) -> None:
27+
if version[0] != _version[0] or version[1] > _version[1]:
28+
raise ValueError(
29+
'version mismatch (device: {}, client: {})'.format(
30+
'.'.join(map(str, version)),
31+
'.'.join(map(str, _version))))
32+
33+
2134
class _Interface(object):
2235
"""Generic simpleRPC interface."""
2336
def __init__(
@@ -34,12 +47,13 @@ def __init__(
3447

3548
self._connection = serial_for_url(
3649
device, do_not_open=True, baudrate=baudrate)
37-
self._load = load # TODO: Content checking.
50+
self._load = load
3851
self.device = {
39-
'version': (0, 0, 0),
4052
'endianness': '<',
53+
'methods': {},
54+
'protocol': '',
4155
'size_t': 'H',
42-
'methods': {}}
56+
'version': (0, 0, 0)}
4357

4458
if autoconnect:
4559
self.open()
@@ -98,15 +112,11 @@ def _get_methods(self: object) -> dict:
98112
"""
99113
self._select(_list_req)
100114

101-
if self._read_byte_string().decode() != _protocol:
102-
raise ValueError('missing protocol header')
115+
_assert_protocol(self._read_byte_string().decode())
116+
self.device['protocol'] = _protocol
103117

104118
version = tuple(self._read('B') for _ in range(3))
105-
if version[0] != _version[0] or version[1] > _version[1]:
106-
raise ValueError(
107-
'version mismatch (device: {}, client: {})'.format(
108-
'.'.join(map(str, version)),
109-
'.'.join(map(str, _version))))
119+
_assert_version(version)
110120
self.device['version'] = version
111121

112122
self.device['endianness'], self.device['size_t'] = (
@@ -178,23 +188,13 @@ def save(self: object, handle: TextIO) -> None:
178188
179189
:arg handle: Open file handle.
180190
"""
181-
dump(
182-
#{
183-
# 'version': self._version,
184-
# 'endianness': self._endianness,
185-
# 'size_t': self._size_t,
186-
# 'methods': self.methods
187-
#},
188-
self.device,
189-
handle, width=76, default_flow_style=False)
191+
dump(self.device, handle, width=76, default_flow_style=False)
190192

191193
def load(self: object) -> None:
192194
"""Load the interface definition from a file."""
193-
self.device = load(self._load)
194-
#self._version = definition['version']
195-
#self._endianness = definition['endianness']
196-
#self._size_t = definition['size_t']
197-
#self.methods = definition['methods']
195+
self.device = load(self._load, Loader=FullLoader)
196+
_assert_protocol(self.device['protocol'])
197+
_assert_version(self.device['version'])
198198

199199

200200
class SerialInterface(_Interface):

0 commit comments

Comments
 (0)