2222import numpy as np
2323import pandas as pd
2424import requests
25- from boto3 .exceptions import Boto3Error
2625from botocore .exceptions import ClientError
2726from lib .app .analytics .common import aggregate_annotations_as_df
2827from lib .app .analytics .common import consensus_plot
@@ -3437,7 +3436,9 @@ def execute(self):
34373436 )
34383437 if self ._project .project_type == constances .ProjectType .PIXEL .value :
34393438 mask_path = None
3440- png_path = self ._annotation_path .replace ("___pixel.json" , "___save.png" )
3439+ png_path = self ._annotation_path .replace (
3440+ "___pixel.json" , "___save.png"
3441+ )
34413442 if os .path .exists (png_path ) and not self ._mask :
34423443 mask_path = png_path
34433444 elif self ._mask :
@@ -4122,36 +4123,44 @@ def __init__(
41224123 self ._backend_service = backend_service_provider
41234124 self ._team_id = team_id
41244125
4126+ def validate_training_status (self ):
4127+ if self ._model .training_status not in [
4128+ constances .TrainingStatus .COMPLETED .value ,
4129+ constances .TrainingStatus .FAILED_AFTER_EVALUATION_WITH_SAVE_MODEL .value ,
4130+ ]:
4131+ raise AppException ("Unable to download." )
4132+
41254133 def execute (self ):
4126- metrics_name = os .path .basename (self ._model .path ).replace (".pth" , ".json" )
4127- mapper_path = self ._model .config_path .replace (
4128- os .path .basename (self ._model .config_path ), "classes_mapper.json"
4129- )
4130- metrics_path = self ._model .config_path .replace (
4131- os .path .basename (self ._model .config_path ), metrics_name
4132- )
4134+ if self .is_valid ():
4135+ metrics_name = os .path .basename (self ._model .path ).replace (".pth" , ".json" )
4136+ mapper_path = self ._model .config_path .replace (
4137+ os .path .basename (self ._model .config_path ), "classes_mapper.json"
4138+ )
4139+ metrics_path = self ._model .config_path .replace (
4140+ os .path .basename (self ._model .config_path ), metrics_name
4141+ )
41334142
4134- auth_response = self ._backend_service .get_ml_model_download_tokens (
4135- self ._team_id , self ._model .uuid
4136- )
4137- if not auth_response .ok :
4138- raise AppException (auth_response .error )
4139- s3_session = boto3 .Session (
4140- aws_access_key_id = auth_response .data .access_key ,
4141- aws_secret_access_key = auth_response .data .secret_key ,
4142- aws_session_token = auth_response .data .session_token ,
4143- region_name = auth_response .data .region ,
4144- )
4145- bucket = s3_session .resource ("s3" ).Bucket (auth_response .data .bucket )
4143+ auth_response = self ._backend_service .get_ml_model_download_tokens (
4144+ self ._team_id , self ._model .uuid
4145+ )
4146+ if not auth_response .ok :
4147+ raise AppException (auth_response .error )
4148+ s3_session = boto3 .Session (
4149+ aws_access_key_id = auth_response .data .access_key ,
4150+ aws_secret_access_key = auth_response .data .secret_key ,
4151+ aws_session_token = auth_response .data .session_token ,
4152+ region_name = auth_response .data .region ,
4153+ )
4154+ bucket = s3_session .resource ("s3" ).Bucket (auth_response .data .bucket )
41464155
4147- bucket .download_file (
4148- self ._model .config_path , os . path . join ( self . _download_path , "config.yaml" )
4149- )
4150- bucket . download_file (
4151- self . _model . path ,
4152- os . path . join ( self . _download_path , os . path . basename ( self ._model .path )) ,
4153- )
4154- if self . _model . is_global :
4156+ bucket .download_file (
4157+ self ._model .config_path ,
4158+ os . path . join ( self . _download_path , "config.yaml" ),
4159+ )
4160+ bucket . download_file (
4161+ self ._model .path ,
4162+ os . path . join ( self . _download_path , os . path . basename ( self . _model . path )),
4163+ )
41554164 try :
41564165 bucket .download_file (
41574166 metrics_path , os .path .join (self ._download_path , metrics_name )
@@ -4160,11 +4169,11 @@ def execute(self):
41604169 mapper_path ,
41614170 os .path .join (self ._download_path , "classes_mapper.json" ),
41624171 )
4163- except Boto3Error :
4164- self . _response . errors = AppException (
4172+ except ClientError :
4173+ logger . info (
41654174 "The specified model does not contain a classes_mapper and/or a metrics file."
41664175 )
4167- self ._response .data = self ._model
4176+ self ._response .data = self ._model
41684177 return self ._response
41694178
41704179
@@ -4531,13 +4540,13 @@ def execute(self):
45314540 success_images = [
45324541 img .name
45334542 for img in images_metadata
4534- if img .segmentation_status
4543+ if img .prediction_status
45354544 == constances .SegmentationStatus .COMPLETED .value
45364545 ]
45374546 failed_images = [
45384547 img .name
45394548 for img in images_metadata
4540- if img .segmentation_status
4549+ if img .prediction_status
45414550 == constances .SegmentationStatus .FAILED .value
45424551 ]
45434552
0 commit comments