@@ -136,6 +136,30 @@ def get_image_data(self, diffgram_file):
136136 else :
137137 raise Exception ('Pytorch datasets only support images. Please provide only file_ids from images' )
138138
139+ def gen_global_attrs (self , instance_list ):
140+ res = []
141+ for inst in instance_list :
142+ if inst ['type' ] != 'global' :
143+ continue
144+ res .append (inst ['attribute_groups' ])
145+ return res
146+
147+ def gen_tag_instances (self , instance_list ):
148+ result = []
149+ for inst in instance_list :
150+ if inst ['type' ] != 'tag' :
151+ continue
152+ for k in list (inst .keys ()):
153+ val = inst [k ]
154+ if val is None :
155+ inst .pop (k )
156+ elm = {
157+ 'label' : inst ['label_file' ]['label' ]['name' ],
158+ 'label_file_id' : inst ['label_file' ]['id' ],
159+ }
160+ result .append (elm )
161+ return result
162+
139163 def get_file_instances (self , diffgram_file ):
140164 if diffgram_file .type not in ['image' , 'frame' ]:
141165 raise NotImplementedError ('File type "{}" is not supported yet' .format (diffgram_file ['type' ]))
@@ -147,6 +171,9 @@ def get_file_instances(self, diffgram_file):
147171 sample = {'image' : image , 'diffgram_file' : diffgram_file }
148172 has_boxes = False
149173 has_poly = False
174+ has_tags = False
175+ has_global = False
176+ sample ['raw_instance_list' ] = instance_list
150177 if 'box' in instance_types_in_file :
151178 has_boxes = True
152179 x_min_list , x_max_list , y_min_list , y_max_list = self .extract_bbox_values (instance_list , diffgram_file )
@@ -164,12 +191,19 @@ def get_file_instances(self, diffgram_file):
164191 has_poly = True
165192 mask_list = self .extract_masks_from_polygon (instance_list , diffgram_file )
166193 sample ['polygon_mask_list' ] = mask_list
194+ if 'tag' in instance_types_in_file :
195+ has_tags = True
196+ sample ['tags' ] = self .gen_tag_instances (instance_list )
197+ if 'global' in instance_types_in_file :
198+ has_global = True
199+ sample ['global_attributes' ] = self .gen_global_attrs (instance_list )
200+
167201 else :
168202 sample ['polygon_mask_list' ] = []
169203
170- if len (instance_types_in_file ) > 2 and has_boxes and has_boxes :
204+ if len (instance_types_in_file ) > 4 and has_poly and has_boxes and has_tags and has_global :
171205 raise NotImplementedError (
172- 'SDK only supports boxes and polygon types currently. If you want a new instance type to be supported please contact us!'
206+ 'SDK Streaming only supports boxes and polygon, tags and global attributes types currently. If you want a new instance type to be supported please contact us!'
173207 )
174208
175209 label_id_list , label_name_list = self .extract_labels (instance_list )
@@ -198,11 +232,13 @@ def extract_masks_from_polygon(self, instance_list, diffgram_file, empty_value =
198232 def extract_labels (self , instance_list , allowed_instance_types = None ):
199233 label_file_id_list = []
200234 label_names_list = []
201-
202235 for inst in instance_list :
236+ if inst ['type' ] == 'global' :
237+ continue
238+ if inst is None :
239+ continue
203240 if allowed_instance_types and inst ['type' ] in allowed_instance_types :
204241 continue
205-
206242 label_file_id_list .append (inst ['label_file' ]['id' ])
207243 label_names_list .append (inst ['label_file' ]['label' ]['name' ])
208244
0 commit comments