Skip to content

Commit 1ec75b7

Browse files
Add mcap support for osi3trace
Signed-off-by: Thomas Sedlmayer <tsedlmayer@pmsfit.de>
1 parent d610c69 commit 1ec75b7

File tree

1 file changed

+178
-4
lines changed

1 file changed

+178
-4
lines changed

osi3trace/osi_trace.py

Lines changed: 178 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
"""
44

55
import lzma
6+
from pathlib import Path
67
import struct
78

9+
from abc import ABC, abstractmethod
10+
11+
from mcap_protobuf.decoder import DecoderFactory
12+
from mcap.reader import make_reader
13+
814
from osi3.osi_sensorview_pb2 import SensorView
915
from osi3.osi_sensorviewconfiguration_pb2 import SensorViewConfiguration
1016
from osi3.osi_groundtruth_pb2 import GroundTruth
@@ -32,8 +38,8 @@
3238

3339

3440
class OSITrace:
35-
"""This class can import and decode OSI trace files."""
36-
41+
"""This class can import and decode OSI single- and multi-channel trace files."""
42+
3743
@staticmethod
3844
def map_message_type(type_name):
3945
"""Map the type name to the protobuf message type."""
@@ -43,9 +49,120 @@ def map_message_type(type_name):
4349
def message_types():
4450
"""Message types that OSITrace supports."""
4551
return list(MESSAGES_TYPE.keys())
52+
53+
def __init__(self, path=None, type_name="SensorView", cache_messages=False, topic=None):
54+
"""
55+
Initializes the trace reader depending on the trace file format.
56+
57+
Args:
58+
path (str): The path to the trace file.
59+
type_name (str): The type name of the messages in the trace; check supported message types with `OSITrace.message_types()`.
60+
cache_messages (bool): Whether to cache messages in memory (only applies to single-channel traces).
61+
topic (str): The topic name for multi-channel traces (only applies to multi-channel traces); Using the first available topic if not specified.
62+
"""
63+
self.reader = None
64+
self.path = None
65+
66+
if path is not None:
67+
self.reader = self._init_reader(Path(path), type_name, cache_messages, topic)
68+
69+
def _init_reader(self, path, type_name, cache_messages, topic):
70+
if not path.exists():
71+
raise FileNotFoundError("File not found")
72+
73+
if path.suffix.lower() == ".mcap":
74+
reader = OSITraceMulti(path, topic)
75+
if reader.get_message_type() != type_name:
76+
raise ValueError(f"Channel message type '{reader.get_message_type()}' does not match expected type '{type_name}'")
77+
return reader
78+
elif path.suffix.lower() in [".osi", ".lzma", ".xz"]:
79+
return OSITraceSingle(str(path), type_name, cache_messages)
80+
else:
81+
raise ValueError(f"Unsupported file format: '{path.suffix}'")
82+
83+
def from_file(self, path, type_name="SensorView", cache_messages=False, topic=None):
84+
"""
85+
Initializes the trace reader depending on the trace file format.
86+
87+
Args:
88+
path (str): The path to the trace file.
89+
type_name (str): The type name of the messages in the trace; check supported message types with `OSITrace.message_types()`.
90+
cache_messages (bool): Whether to cache messages in memory (only applies to single-channel traces).
91+
topic (str): The topic name for multi-channel traces (only applies to multi-channel traces); Using the first available topic if not specified.
92+
"""
93+
self.reader = self._init_reader(Path(path), type_name, cache_messages, topic)
94+
95+
def restart(self, index=None):
96+
"""
97+
Restart the trace reader.
98+
99+
Note:
100+
Multi-channel traces don't support restarting from a specific index.
101+
"""
102+
return self.reader.restart(index)
103+
104+
def __iter__(self):
105+
return self.reader.__iter__()
106+
107+
def close(self):
108+
return self.reader.close()
109+
110+
def retrieve_offsets(self, limit=None):
111+
if isinstance(self.reader, OSITraceSingle):
112+
return self.reader.retrieve_offsets(limit)
113+
raise NotImplementedError("Offsets are only supported for single-channel traces.")
114+
115+
def retrieve_message(self, index=None, skip=False):
116+
if isinstance(self.reader, OSITraceSingle):
117+
return self.reader.retrieve_message(index, skip)
118+
raise NotImplementedError("Index-based message retrieval is only supported for single-channel traces.")
119+
120+
def get_message_by_index(self, index):
121+
if isinstance(self.reader, OSITraceSingle):
122+
return self.reader.get_message_by_index(index)
123+
raise NotImplementedError("Index-based message retrieval is only supported for single-channel traces.")
124+
125+
def get_messages_in_index_range(self, begin, end):
126+
if isinstance(self.reader, OSITraceSingle):
127+
return self.reader.get_messages_in_index_range(begin, end)
128+
raise NotImplementedError("Index-based message retrieval is only supported for single-channel traces.")
129+
130+
def get_available_topics(self):
131+
if isinstance(self.reader, OSITraceMulti):
132+
return self.reader.get_available_topics()
133+
raise NotImplementedError("Getting available topics is only supported for multi-channel traces.")
134+
135+
def get_file_metadata(self):
136+
if isinstance(self.reader, OSITraceMulti):
137+
return self.reader.get_file_metadata()
138+
raise NotImplementedError("Getting file metadata is only supported for multi-channel traces.")
139+
140+
def get_channel_metadata(self):
141+
if isinstance(self.reader, OSITraceMulti):
142+
return self.reader.get_channel_metadata()
143+
raise NotImplementedError("Getting channel metadata is only supported for multi-channel traces.")
144+
145+
146+
class ReaderBase(ABC):
147+
"""Common interface for trace readers"""
148+
@abstractmethod
149+
def restart(self, index=None):
150+
pass
151+
152+
@abstractmethod
153+
def __iter__(self):
154+
pass
155+
156+
@abstractmethod
157+
def close(self):
158+
pass
46159

160+
161+
class OSITraceSingle(ReaderBase):
162+
"""OSI single-channel trace reader"""
163+
47164
def __init__(self, path=None, type_name="SensorView", cache_messages=False):
48-
self.type = self.map_message_type(type_name)
165+
self.type = OSITrace.map_message_type(type_name)
49166
self.file = None
50167
self.current_index = None
51168
self.message_offsets = None
@@ -57,7 +174,7 @@ def __init__(self, path=None, type_name="SensorView", cache_messages=False):
57174

58175
def from_file(self, path, type_name="SensorView", cache_messages=False):
59176
"""Import a trace from a file"""
60-
self.type = self.map_message_type(type_name)
177+
self.type = OSITrace.map_message_type(type_name)
61178

62179
if path.lower().endswith((".lzma", ".xz")):
63180
self.file = lzma.open(path, "rb")
@@ -186,3 +303,60 @@ def close(self):
186303
self.read_complete = False
187304
self.read_limit = None
188305
self.type = None
306+
307+
308+
class OSITraceMulti(ReaderBase):
309+
"""OSI multi-channel trace reader"""
310+
311+
def __init__(self, path, topic):
312+
self.path = Path(path)
313+
self._file = open(self.path, "rb")
314+
self.mcap_reader = make_reader(self._file, decoder_factories=[DecoderFactory()])
315+
self._summary = self.mcap_reader.get_summary()
316+
available_topics = self.get_available_topics()
317+
if topic == None:
318+
topic = available_topics[0]
319+
if topic not in available_topics:
320+
raise ValueError(f"The requested topic '{topic}' is not present in the trace file.")
321+
self.topic = topic
322+
323+
def restart(self, index=None):
324+
if index != None:
325+
raise NotImplementedError("Restarting from a given index is not supported for multi-channel traces.")
326+
if hasattr(self, "_iter"):
327+
del self._iter
328+
329+
def __iter__(self):
330+
"""Stateful iterator over the channel's messages in log time order."""
331+
if not hasattr(self, "_iter"):
332+
self._iter = self.mcap_reader.iter_decoded_messages(topics=[self.topic])
333+
for message in self._iter:
334+
yield message.decoded_message
335+
336+
def close(self):
337+
self._file.close()
338+
339+
def get_available_topics(self):
340+
return [channel.topic for id, channel in self._summary.channels.items()]
341+
342+
def get_file_metadata(self):
343+
metadata = []
344+
for metadata_entry in self.mcap_reader.iter_metadata():
345+
metadata.append(metadata_entry)
346+
return metadata
347+
348+
def get_channel_metadata(self):
349+
for id, channel in self._summary.channels.items():
350+
if channel.topic == self.topic:
351+
return channel.metadata
352+
return None
353+
354+
def get_message_type(self):
355+
for channel_id, channel in self._summary.channels.items():
356+
if channel.topic == self.topic:
357+
schema = self._summary.schemas[channel.schema_id]
358+
if schema.name.startswith("osi3."):
359+
return schema.name[len("osi3.") :]
360+
else:
361+
raise ValueError(f"Schema '{schema.name}' is not an 'osi3.' schema.")
362+
return None

0 commit comments

Comments
 (0)