Skip to content

Commit 732219d

Browse files
committed
Updated pr.
1 parent f21244f commit 732219d

File tree

5 files changed

+408
-286
lines changed

5 files changed

+408
-286
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,15 @@ def post(self, *args, **kwargs): # noqa: ARG002
119119
if not input_data:
120120
raise HTTPError(400, Errors.NO_INPUT_DATA)
121121

122-
self.finish(AquaDeploymentApp().create(**input_data))
122+
model_deployment_id = input_data.pop("model_deployment_id", None)
123+
if model_deployment_id:
124+
self.finish(
125+
AquaDeploymentApp().update(
126+
model_deployment_id=model_deployment_id, **input_data
127+
)
128+
)
129+
else:
130+
self.finish(AquaDeploymentApp().create(**input_data))
123131

124132
def read(self, id):
125133
"""Read the information of an Aqua model deployment."""

ads/aqua/modeldeployment/deployment.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@
8282
AquaDeploymentDetail,
8383
ConfigValidationError,
8484
CreateModelDeploymentDetails,
85-
UpdateModelGroupDeploymentDetails,
85+
ModelDeploymentDetails,
86+
UpdateModelDeploymentDetails,
8687
)
8788
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
8889
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend
@@ -399,14 +400,14 @@ def create(
399400

400401
def _validate_input_models(
401402
self,
402-
create_deployment_details: CreateModelDeploymentDetails,
403+
deployment_details: ModelDeploymentDetails,
403404
):
404-
"""Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment."""
405+
"""Validates the base models and associated fine tuned models from 'models' in create_deployment_details or update_deployment_details for stacked or multi model deployment."""
405406
# Collect all unique model IDs (including fine-tuned models)
406407
source_model_ids = list(
407408
{
408409
model_id
409-
for model in create_deployment_details.models
410+
for model in deployment_details.models
410411
for model_id in model.all_model_ids()
411412
}
412413
)
@@ -417,7 +418,7 @@ def _validate_input_models(
417418
source_models = self.get_multi_source(source_model_ids) or {}
418419

419420
try:
420-
create_deployment_details.validate_input_models(model_details=source_models)
421+
deployment_details.validate_input_models(model_details=source_models)
421422
except ConfigValidationError as err:
422423
raise AquaValueError(f"{err}") from err
423424

@@ -1255,16 +1256,14 @@ def _get_container_type_key(
12551256
def update(
12561257
self,
12571258
model_deployment_id: str,
1258-
update_model_deployment_details: Optional[
1259-
UpdateModelGroupDeploymentDetails
1260-
] = None,
1259+
update_model_deployment_details: Optional[UpdateModelDeploymentDetails] = None,
12611260
**kwargs,
12621261
) -> AquaDeployment:
12631262
"""Updates a AQUA model group deployment.
12641263
12651264
Args:
1266-
update_model_deployment_details : UpdateModelGroupDeploymentDetails, optional
1267-
An instance of UpdateModelGroupDeploymentDetails containing all optional
1265+
update_model_deployment_details : UpdateModelDeploymentDetails, optional
1266+
An instance of UpdateModelDeploymentDetails containing all optional
12681267
fields for updating a model deployment via Aqua.
12691268
kwargs:
12701269
display_name (str): The name of the model deployment.
@@ -1289,15 +1288,15 @@ def update(
12891288
"""
12901289
if not update_model_deployment_details:
12911290
try:
1292-
update_model_deployment_details = UpdateModelGroupDeploymentDetails(
1293-
**kwargs
1294-
)
1291+
update_model_deployment_details = UpdateModelDeploymentDetails(**kwargs)
12951292
except ValidationError as ex:
12961293
custom_errors = build_pydantic_error_message(ex)
12971294
raise AquaValueError(
12981295
f"Invalid parameters for updating a model group deployment. Error details: {custom_errors}."
12991296
) from ex
13001297

1298+
self._validate_input_models(update_model_deployment_details)
1299+
13011300
model_deployment = ModelDeployment.from_id(model_deployment_id)
13021301

13031302
infrastructure = model_deployment.infrastructure
@@ -1308,6 +1307,7 @@ def update(
13081307
"Invalid 'model_deployment_id'. Only model group deployment is supported to update."
13091308
)
13101309

1310+
# updates model group if fine tuned weights changed.
13111311
model = self._update_model_group(
13121312
runtime.model_group_id, update_model_deployment_details
13131313
)
@@ -1324,10 +1324,6 @@ def update(
13241324
update_model_deployment_details.web_concurrency
13251325
or infrastructure.web_concurrency
13261326
)
1327-
.with_private_endpoint_id(
1328-
update_model_deployment_details.private_endpoint_id
1329-
or infrastructure.private_endpoint_id
1330-
)
13311327
)
13321328

13331329
if (
@@ -1358,6 +1354,7 @@ def update(
13581354
memory_in_gbs=update_model_deployment_details.memory_in_gbs,
13591355
)
13601356

1357+
# applies ZDT as default type to update parameters if model group id hasn't been changed
13611358
update_type = ModelDeploymentUpdateType.ZDT
13621359
# applies LIVE update if model group id has been changed
13631360
if runtime.model_group_id != model.id:
@@ -1400,16 +1397,16 @@ def update(
14001397
def _update_model_group(
14011398
self,
14021399
model_group_id: str,
1403-
update_model_deployment_details: UpdateModelGroupDeploymentDetails,
1400+
update_model_deployment_details: UpdateModelDeploymentDetails,
14041401
) -> DataScienceModelGroup:
14051402
"""Creates a new model group if fine tuned weights changed.
14061403
14071404
Parameters
14081405
----------
14091406
model_group_id: str
14101407
The model group id.
1411-
update_model_deployment_details: UpdateModelGroupDeploymentDetails
1412-
An instance of UpdateModelGroupDeploymentDetails containing all optional
1408+
update_model_deployment_details: UpdateModelDeploymentDetails
1409+
An instance of UpdateModelDeploymentDetails containing all optional
14131410
fields for updating a model deployment via Aqua.
14141411
14151412
Returns
@@ -1462,6 +1459,10 @@ def _update_model_group(
14621459
.create()
14631460
)
14641461

1462+
logger.info(
1463+
f"Model group of base model {target_base_model_id} has been updated: {model_group.id}."
1464+
)
1465+
14651466
return model_group
14661467

14671468
@telemetry(entry_point="plugin=deployment&action=list", name="aqua")

0 commit comments

Comments
 (0)