33"""
44
55import lzma
6+ from pathlib import Path
67import struct
78
9+ from abc import ABC , abstractmethod
10+
11+ from mcap_protobuf .decoder import DecoderFactory
12+ from mcap .reader import make_reader
13+
814from osi3 .osi_sensorview_pb2 import SensorView
915from osi3 .osi_sensorviewconfiguration_pb2 import SensorViewConfiguration
1016from osi3 .osi_groundtruth_pb2 import GroundTruth
3238
3339
3440class OSITrace :
35- """This class can import and decode OSI trace files."""
36-
41+ """This class can import and decode OSI single- and multi-channel trace files."""
42+
3743 @staticmethod
3844 def map_message_type (type_name ):
3945 """Map the type name to the protobuf message type."""
@@ -43,9 +49,120 @@ def map_message_type(type_name):
4349 def message_types ():
4450 """Message types that OSITrace supports."""
4551 return list (MESSAGES_TYPE .keys ())
52+
53+ def __init__ (self , path = None , type_name = "SensorView" , cache_messages = False , topic = None ):
54+ """
55+ Initializes the trace reader depending on the trace file format.
56+
57+ Args:
58+ path (str): The path to the trace file.
59+ type_name (str): The type name of the messages in the trace; check supported message types with `OSITrace.message_types()`.
60+ cache_messages (bool): Whether to cache messages in memory (only applies to single-channel traces).
61+ topic (str): The topic name for multi-channel traces (only applies to multi-channel traces); Using the first available topic if not specified.
62+ """
63+ self .reader = None
64+ self .path = None
65+
66+ if path is not None :
67+ self .reader = self ._init_reader (Path (path ), type_name , cache_messages , topic )
68+
69+ def _init_reader (self , path , type_name , cache_messages , topic ):
70+ if not path .exists ():
71+ raise FileNotFoundError ("File not found" )
72+
73+ if path .suffix .lower () == ".mcap" :
74+ reader = OSITraceMulti (path , topic )
75+ if reader .get_message_type () != type_name :
76+ raise ValueError (f"Channel message type '{ reader .get_message_type ()} ' does not match expected type '{ type_name } '" )
77+ return reader
78+ elif path .suffix .lower () in [".osi" , ".lzma" , ".xz" ]:
79+ return OSITraceSingle (str (path ), type_name , cache_messages )
80+ else :
81+ raise ValueError (f"Unsupported file format: '{ path .suffix } '" )
82+
83+ def from_file (self , path , type_name = "SensorView" , cache_messages = False , topic = None ):
84+ """
85+ Initializes the trace reader depending on the trace file format.
86+
87+ Args:
88+ path (str): The path to the trace file.
89+ type_name (str): The type name of the messages in the trace; check supported message types with `OSITrace.message_types()`.
90+ cache_messages (bool): Whether to cache messages in memory (only applies to single-channel traces).
91+ topic (str): The topic name for multi-channel traces (only applies to multi-channel traces); Using the first available topic if not specified.
92+ """
93+ self .reader = self ._init_reader (Path (path ), type_name , cache_messages , topic )
94+
95+ def restart (self , index = None ):
96+ """
97+ Restart the trace reader.
98+
99+ Note:
100+ Multi-channel traces don't support restarting from a specific index.
101+ """
102+ return self .reader .restart (index )
103+
104+ def __iter__ (self ):
105+ return self .reader .__iter__ ()
106+
107+ def close (self ):
108+ return self .reader .close ()
109+
110+ def retrieve_offsets (self , limit = None ):
111+ if isinstance (self .reader , OSITraceSingle ):
112+ return self .reader .retrieve_offsets (limit )
113+ raise NotImplementedError ("Offsets are only supported for single-channel traces." )
114+
115+ def retrieve_message (self , index = None , skip = False ):
116+ if isinstance (self .reader , OSITraceSingle ):
117+ return self .reader .retrieve_message (index , skip )
118+ raise NotImplementedError ("Index-based message retrieval is only supported for single-channel traces." )
119+
120+ def get_message_by_index (self , index ):
121+ if isinstance (self .reader , OSITraceSingle ):
122+ return self .reader .get_message_by_index (index )
123+ raise NotImplementedError ("Index-based message retrieval is only supported for single-channel traces." )
124+
125+ def get_messages_in_index_range (self , begin , end ):
126+ if isinstance (self .reader , OSITraceSingle ):
127+ return self .reader .get_messages_in_index_range (begin , end )
128+ raise NotImplementedError ("Index-based message retrieval is only supported for single-channel traces." )
129+
130+ def get_available_topics (self ):
131+ if isinstance (self .reader , OSITraceMulti ):
132+ return self .reader .get_available_topics ()
133+ raise NotImplementedError ("Getting available topics is only supported for multi-channel traces." )
134+
135+ def get_file_metadata (self ):
136+ if isinstance (self .reader , OSITraceMulti ):
137+ return self .reader .get_file_metadata ()
138+ raise NotImplementedError ("Getting file metadata is only supported for multi-channel traces." )
139+
140+ def get_channel_metadata (self ):
141+ if isinstance (self .reader , OSITraceMulti ):
142+ return self .reader .get_channel_metadata ()
143+ raise NotImplementedError ("Getting channel metadata is only supported for multi-channel traces." )
144+
145+
146+ class ReaderBase (ABC ):
147+ """Common interface for trace readers"""
148+ @abstractmethod
149+ def restart (self , index = None ):
150+ pass
151+
152+ @abstractmethod
153+ def __iter__ (self ):
154+ pass
155+
156+ @abstractmethod
157+ def close (self ):
158+ pass
46159
160+
161+ class OSITraceSingle (ReaderBase ):
162+ """OSI single-channel trace reader"""
163+
47164 def __init__ (self , path = None , type_name = "SensorView" , cache_messages = False ):
48- self .type = self .map_message_type (type_name )
165+ self .type = OSITrace .map_message_type (type_name )
49166 self .file = None
50167 self .current_index = None
51168 self .message_offsets = None
@@ -57,7 +174,7 @@ def __init__(self, path=None, type_name="SensorView", cache_messages=False):
57174
58175 def from_file (self , path , type_name = "SensorView" , cache_messages = False ):
59176 """Import a trace from a file"""
60- self .type = self .map_message_type (type_name )
177+ self .type = OSITrace .map_message_type (type_name )
61178
62179 if path .lower ().endswith ((".lzma" , ".xz" )):
63180 self .file = lzma .open (path , "rb" )
@@ -186,3 +303,60 @@ def close(self):
186303 self .read_complete = False
187304 self .read_limit = None
188305 self .type = None
306+
307+
308+ class OSITraceMulti (ReaderBase ):
309+ """OSI multi-channel trace reader"""
310+
311+ def __init__ (self , path , topic ):
312+ self .path = Path (path )
313+ self ._file = open (self .path , "rb" )
314+ self .mcap_reader = make_reader (self ._file , decoder_factories = [DecoderFactory ()])
315+ self ._summary = self .mcap_reader .get_summary ()
316+ available_topics = self .get_available_topics ()
317+ if topic == None :
318+ topic = available_topics [0 ]
319+ if topic not in available_topics :
320+ raise ValueError (f"The requested topic '{ topic } ' is not present in the trace file." )
321+ self .topic = topic
322+
323+ def restart (self , index = None ):
324+ if index != None :
325+ raise NotImplementedError ("Restarting from a given index is not supported for multi-channel traces." )
326+ if hasattr (self , "_iter" ):
327+ del self ._iter
328+
329+ def __iter__ (self ):
330+ """Stateful iterator over the channel's messages in log time order."""
331+ if not hasattr (self , "_iter" ):
332+ self ._iter = self .mcap_reader .iter_decoded_messages (topics = [self .topic ])
333+ for message in self ._iter :
334+ yield message .decoded_message
335+
336+ def close (self ):
337+ self ._file .close ()
338+
339+ def get_available_topics (self ):
340+ return [channel .topic for id , channel in self ._summary .channels .items ()]
341+
342+ def get_file_metadata (self ):
343+ metadata = []
344+ for metadata_entry in self .mcap_reader .iter_metadata ():
345+ metadata .append (metadata_entry )
346+ return metadata
347+
348+ def get_channel_metadata (self ):
349+ for id , channel in self ._summary .channels .items ():
350+ if channel .topic == self .topic :
351+ return channel .metadata
352+ return None
353+
354+ def get_message_type (self ):
355+ for channel_id , channel in self ._summary .channels .items ():
356+ if channel .topic == self .topic :
357+ schema = self ._summary .schemas [channel .schema_id ]
358+ if schema .name .startswith ("osi3." ):
359+ return schema .name [len ("osi3." ) :]
360+ else :
361+ raise ValueError (f"Schema '{ schema .name } ' is not an 'osi3.' schema." )
362+ return None
0 commit comments