|
13 | 13 | from lib.core.entities import ImageEntity |
14 | 14 | from lib.core.entities import ProjectEntity |
15 | 15 | from lib.core.exceptions import AppException |
16 | | -from lib.core.exceptions import AppValidationException |
17 | 16 | from lib.core.helpers import convert_to_video_editor_json |
18 | 17 | from lib.core.helpers import fill_annotation_ids |
19 | 18 | from lib.core.helpers import map_annotation_classes_name |
@@ -177,64 +176,84 @@ def _upload_annotation( |
177 | 176 | # raise e |
178 | 177 | return path, False |
179 | 178 |
|
| 179 | + def get_bucket_to_upload(self, ids: List[int]): |
| 180 | + upload_data = self.get_annotation_upload_data(ids) |
| 181 | + if upload_data: |
| 182 | + session = boto3.Session( |
| 183 | + aws_access_key_id=upload_data.access_key, |
| 184 | + aws_secret_access_key=upload_data.secret_key, |
| 185 | + aws_session_token=upload_data.session_token, |
| 186 | + region_name=upload_data.region, |
| 187 | + ) |
| 188 | + resource = session.resource("s3") |
| 189 | + return resource.Bucket(upload_data.bucket) |
| 190 | + |
| 191 | + def _log_report(self): |
| 192 | + for key, values in self.reporter.custom_messages.items(): |
| 193 | + template = key + ": {}" |
| 194 | + if key == "missing_classes": |
| 195 | + template = "Could not find annotation classes matching existing classes on the platform: [{}]" |
| 196 | + elif key == "missing_attribute_groups": |
| 197 | + template = "Could not find attribute groups matching existing attribute groups on the platform: [{}]" |
| 198 | + elif key == "missing_attributes": |
| 199 | + template = "Could not find attributes matching existing attributes on the platform: [{}]" |
| 200 | + logger.warning( |
| 201 | + template.format(", ".join(values)) |
| 202 | + ) |
| 203 | + |
180 | 204 | def execute(self): |
181 | 205 | uploaded_annotations = [] |
182 | 206 | failed_annotations = [] |
183 | | - iterations_range = range( |
184 | | - 0, len(self.annotations_to_upload), self.AUTH_DATA_CHUNK_SIZE |
185 | | - ) |
186 | | - self.reporter.start_progress(iterations_range,description="Uploading Annotations") |
187 | | - for _ in iterations_range: |
188 | | - annotations_to_upload = self.annotations_to_upload[ |
189 | | - _ : _ + self.AUTH_DATA_CHUNK_SIZE # noqa: E203 |
190 | | - ] |
191 | | - upload_data = self.get_annotation_upload_data( |
192 | | - [int(image.id) for image in annotations_to_upload] |
| 207 | + if self.annotations_to_upload: |
| 208 | + iterations_range = range( |
| 209 | + 0, len(self.annotations_to_upload), self.AUTH_DATA_CHUNK_SIZE |
193 | 210 | ) |
194 | | - if upload_data: |
195 | | - session = boto3.Session( |
196 | | - aws_access_key_id=upload_data.access_key, |
197 | | - aws_secret_access_key=upload_data.secret_key, |
198 | | - aws_session_token=upload_data.session_token, |
199 | | - region_name=upload_data.region, |
| 211 | + self.reporter.start_progress(iterations_range, description="Uploading Annotations") |
| 212 | + for _ in iterations_range: |
| 213 | + annotations_to_upload = self.annotations_to_upload[ |
| 214 | + _ : _ + self.AUTH_DATA_CHUNK_SIZE # noqa: E203 |
| 215 | + ] |
| 216 | + upload_data = self.get_annotation_upload_data( |
| 217 | + [int(image.id) for image in annotations_to_upload] |
200 | 218 | ) |
201 | | - resource = session.resource("s3") |
202 | | - bucket = resource.Bucket(upload_data.bucket) |
203 | | - image_id_name_map = { |
204 | | - image.id: image for image in self.annotations_to_upload |
205 | | - } |
206 | | - # dummy progress |
207 | | - for _ in range(len(annotations_to_upload) - len(upload_data.images)): |
208 | | - self.reporter.update_progress() |
209 | | - with concurrent.futures.ThreadPoolExecutor( |
210 | | - max_workers=self.MAX_WORKERS |
211 | | - ) as executor: |
212 | | - results = [ |
213 | | - executor.submit( |
214 | | - self._upload_annotation, |
215 | | - image_id, |
216 | | - image_id_name_map[image_id].name, |
217 | | - upload_data, |
218 | | - image_id_name_map[image_id].path, |
219 | | - bucket, |
220 | | - ) |
221 | | - for image_id, image_data in upload_data.images.items() |
222 | | - ] |
223 | | - for future in concurrent.futures.as_completed(results): |
224 | | - annotation, uploaded = future.result() |
225 | | - if uploaded: |
226 | | - uploaded_annotations.append(annotation) |
227 | | - else: |
228 | | - failed_annotations.append(annotation) |
| 219 | + bucket = self.get_bucket_to_upload([int(image.id) for image in annotations_to_upload]) |
| 220 | + if bucket: |
| 221 | + image_id_name_map = { |
| 222 | + image.id: image for image in self.annotations_to_upload |
| 223 | + } |
| 224 | + # dummy progress |
| 225 | + for _ in range(len(annotations_to_upload) - len(upload_data.images)): |
229 | 226 | self.reporter.update_progress() |
| 227 | + with concurrent.futures.ThreadPoolExecutor( |
| 228 | + max_workers=self.MAX_WORKERS |
| 229 | + ) as executor: |
| 230 | + results = [ |
| 231 | + executor.submit( |
| 232 | + self._upload_annotation, |
| 233 | + image_id, |
| 234 | + image_id_name_map[image_id].name, |
| 235 | + upload_data, |
| 236 | + image_id_name_map[image_id].path, |
| 237 | + bucket, |
| 238 | + ) |
| 239 | + for image_id, image_data in upload_data.images.items() |
| 240 | + ] |
| 241 | + for future in concurrent.futures.as_completed(results): |
| 242 | + annotation, uploaded = future.result() |
| 243 | + if uploaded: |
| 244 | + uploaded_annotations.append(annotation) |
| 245 | + else: |
| 246 | + failed_annotations.append(annotation) |
| 247 | + self.reporter.update_progress() |
230 | 248 |
|
231 | | - self._response.data = ( |
232 | | - uploaded_annotations, |
233 | | - failed_annotations, |
234 | | - [annotation.path for annotation in self._missing_annotations], |
235 | | - ) |
236 | | - for message in self.reporter.messages: |
237 | | - logger.warning(message) |
| 249 | + self._response.data = ( |
| 250 | + uploaded_annotations, |
| 251 | + failed_annotations, |
| 252 | + [annotation.path for annotation in self._missing_annotations], |
| 253 | + ) |
| 254 | + self._log_report() |
| 255 | + else: |
| 256 | + self._response.errors = "Could not find annotations matching existing items on the platform." |
238 | 257 | return self._response |
239 | 258 |
|
240 | 259 |
|
@@ -401,5 +420,5 @@ def execute(self): |
401 | 420 | self._project.name, |
402 | 421 | ) |
403 | 422 | else: |
404 | | - self._response.errors = f"Invalid json" |
| 423 | + self._response.errors = "Invalid json" |
405 | 424 | return self._response |
0 commit comments