Skip to content

Commit d19346e

Browse files
authored
Merge pull request #211 from superannotateai/ml_model_prediction
Ml model prediction
2 parents 6fd474c + 6ba0a0f commit d19346e

File tree

5 files changed

+50
-35
lines changed

5 files changed

+50
-35
lines changed

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ def clone_project(
270270
)
271271
if response.errors:
272272
raise AppException(response.errors)
273+
logger.info(
274+
f"Created project {project_name} (ID {response.data.uuid} ) with type { constances.ProjectType.get_name(response.data.project_type)}."
275+
)
273276
return ProjectSerializer(response.data).serialize()
274277

275278

src/superannotate/lib/core/usecases.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import numpy as np
2323
import pandas as pd
2424
import requests
25-
from boto3.exceptions import Boto3Error
2625
from botocore.exceptions import ClientError
2726
from lib.app.analytics.common import aggregate_annotations_as_df
2827
from lib.app.analytics.common import consensus_plot
@@ -3437,7 +3436,9 @@ def execute(self):
34373436
)
34383437
if self._project.project_type == constances.ProjectType.PIXEL.value:
34393438
mask_path = None
3440-
png_path = self._annotation_path.replace("___pixel.json", "___save.png")
3439+
png_path = self._annotation_path.replace(
3440+
"___pixel.json", "___save.png"
3441+
)
34413442
if os.path.exists(png_path) and not self._mask:
34423443
mask_path = png_path
34433444
elif self._mask:
@@ -4122,36 +4123,44 @@ def __init__(
41224123
self._backend_service = backend_service_provider
41234124
self._team_id = team_id
41244125

4126+
def validate_training_status(self):
4127+
if self._model.training_status not in [
4128+
constances.TrainingStatus.COMPLETED.value,
4129+
constances.TrainingStatus.FAILED_AFTER_EVALUATION_WITH_SAVE_MODEL.value,
4130+
]:
4131+
raise AppException("Unable to download.")
4132+
41254133
def execute(self):
4126-
metrics_name = os.path.basename(self._model.path).replace(".pth", ".json")
4127-
mapper_path = self._model.config_path.replace(
4128-
os.path.basename(self._model.config_path), "classes_mapper.json"
4129-
)
4130-
metrics_path = self._model.config_path.replace(
4131-
os.path.basename(self._model.config_path), metrics_name
4132-
)
4134+
if self.is_valid():
4135+
metrics_name = os.path.basename(self._model.path).replace(".pth", ".json")
4136+
mapper_path = self._model.config_path.replace(
4137+
os.path.basename(self._model.config_path), "classes_mapper.json"
4138+
)
4139+
metrics_path = self._model.config_path.replace(
4140+
os.path.basename(self._model.config_path), metrics_name
4141+
)
41334142

4134-
auth_response = self._backend_service.get_ml_model_download_tokens(
4135-
self._team_id, self._model.uuid
4136-
)
4137-
if not auth_response.ok:
4138-
raise AppException(auth_response.error)
4139-
s3_session = boto3.Session(
4140-
aws_access_key_id=auth_response.data.access_key,
4141-
aws_secret_access_key=auth_response.data.secret_key,
4142-
aws_session_token=auth_response.data.session_token,
4143-
region_name=auth_response.data.region,
4144-
)
4145-
bucket = s3_session.resource("s3").Bucket(auth_response.data.bucket)
4143+
auth_response = self._backend_service.get_ml_model_download_tokens(
4144+
self._team_id, self._model.uuid
4145+
)
4146+
if not auth_response.ok:
4147+
raise AppException(auth_response.error)
4148+
s3_session = boto3.Session(
4149+
aws_access_key_id=auth_response.data.access_key,
4150+
aws_secret_access_key=auth_response.data.secret_key,
4151+
aws_session_token=auth_response.data.session_token,
4152+
region_name=auth_response.data.region,
4153+
)
4154+
bucket = s3_session.resource("s3").Bucket(auth_response.data.bucket)
41464155

4147-
bucket.download_file(
4148-
self._model.config_path, os.path.join(self._download_path, "config.yaml")
4149-
)
4150-
bucket.download_file(
4151-
self._model.path,
4152-
os.path.join(self._download_path, os.path.basename(self._model.path)),
4153-
)
4154-
if self._model.is_global:
4156+
bucket.download_file(
4157+
self._model.config_path,
4158+
os.path.join(self._download_path, "config.yaml"),
4159+
)
4160+
bucket.download_file(
4161+
self._model.path,
4162+
os.path.join(self._download_path, os.path.basename(self._model.path)),
4163+
)
41554164
try:
41564165
bucket.download_file(
41574166
metrics_path, os.path.join(self._download_path, metrics_name)
@@ -4160,11 +4169,11 @@ def execute(self):
41604169
mapper_path,
41614170
os.path.join(self._download_path, "classes_mapper.json"),
41624171
)
4163-
except Boto3Error:
4164-
self._response.errors = AppException(
4172+
except ClientError:
4173+
logger.info(
41654174
"The specified model does not contain a classes_mapper and/or a metrics file."
41664175
)
4167-
self._response.data = self._model
4176+
self._response.data = self._model
41684177
return self._response
41694178

41704179

@@ -4531,13 +4540,13 @@ def execute(self):
45314540
success_images = [
45324541
img.name
45334542
for img in images_metadata
4534-
if img.segmentation_status
4543+
if img.prediction_status
45354544
== constances.SegmentationStatus.COMPLETED.value
45364545
]
45374546
failed_images = [
45384547
img.name
45394548
for img in images_metadata
4540-
if img.segmentation_status
4549+
if img.prediction_status
45414550
== constances.SegmentationStatus.FAILED.value
45424551
]
45434552

src/superannotate/lib/infrastructure/controller.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,8 @@ def download_ml_model(self, model_data: dict, download_path: str):
14021402
path=model_data["path"],
14031403
config_path=model_data["config_path"],
14041404
team_id=model_data["team_id"],
1405+
training_status=model_data["training_status"],
1406+
is_global=model_data["is_global"],
14051407
)
14061408
use_case = usecases.DownloadMLModelUseCase(
14071409
model=model,

src/superannotate/lib/infrastructure/repositories.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,4 +468,5 @@ def dict2entity(data: dict):
468468
path=data["path"],
469469
config_path=data["config_path"],
470470
is_global=data["is_global"],
471+
training_status=data["training_status"],
471472
)

src/superannotate/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "5.0.0b27"
1+
__version__ = "5.0.0b29"

0 commit comments

Comments
 (0)