8282 AquaDeploymentDetail ,
8383 ConfigValidationError ,
8484 CreateModelDeploymentDetails ,
85- UpdateModelGroupDeploymentDetails ,
85+ ModelDeploymentDetails ,
86+ UpdateModelDeploymentDetails ,
8687)
8788from ads .aqua .modeldeployment .model_group_config import ModelGroupConfig
8889from ads .aqua .shaperecommend .recommend import AquaShapeRecommend
@@ -399,14 +400,14 @@ def create(
399400
400401 def _validate_input_models (
401402 self ,
402- create_deployment_details : CreateModelDeploymentDetails ,
403+ deployment_details : ModelDeploymentDetails ,
403404 ):
404- """Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment."""
405+ """Validates the base models and associated fine tuned models from 'models' in create_deployment_details or update_deployment_details for stacked or multi model deployment."""
405406 # Collect all unique model IDs (including fine-tuned models)
406407 source_model_ids = list (
407408 {
408409 model_id
409- for model in create_deployment_details .models
410+ for model in deployment_details .models
410411 for model_id in model .all_model_ids ()
411412 }
412413 )
@@ -417,7 +418,7 @@ def _validate_input_models(
417418 source_models = self .get_multi_source (source_model_ids ) or {}
418419
419420 try :
420- create_deployment_details .validate_input_models (model_details = source_models )
421+ deployment_details .validate_input_models (model_details = source_models )
421422 except ConfigValidationError as err :
422423 raise AquaValueError (f"{ err } " ) from err
423424
@@ -1255,16 +1256,14 @@ def _get_container_type_key(
12551256 def update (
12561257 self ,
12571258 model_deployment_id : str ,
1258- update_model_deployment_details : Optional [
1259- UpdateModelGroupDeploymentDetails
1260- ] = None ,
1259+ update_model_deployment_details : Optional [UpdateModelDeploymentDetails ] = None ,
12611260 ** kwargs ,
12621261 ) -> AquaDeployment :
12631262 """Updates a AQUA model group deployment.
12641263
12651264 Args:
1266- update_model_deployment_details : UpdateModelGroupDeploymentDetails , optional
1267- An instance of UpdateModelGroupDeploymentDetails containing all optional
1265+ update_model_deployment_details : UpdateModelDeploymentDetails , optional
1266+ An instance of UpdateModelDeploymentDetails containing all optional
12681267 fields for updating a model deployment via Aqua.
12691268 kwargs:
12701269 display_name (str): The name of the model deployment.
@@ -1289,15 +1288,15 @@ def update(
12891288 """
12901289 if not update_model_deployment_details :
12911290 try :
1292- update_model_deployment_details = UpdateModelGroupDeploymentDetails (
1293- ** kwargs
1294- )
1291+ update_model_deployment_details = UpdateModelDeploymentDetails (** kwargs )
12951292 except ValidationError as ex :
12961293 custom_errors = build_pydantic_error_message (ex )
12971294 raise AquaValueError (
12981295 f"Invalid parameters for updating a model group deployment. Error details: { custom_errors } ."
12991296 ) from ex
13001297
1298+ self ._validate_input_models (update_model_deployment_details )
1299+
13011300 model_deployment = ModelDeployment .from_id (model_deployment_id )
13021301
13031302 infrastructure = model_deployment .infrastructure
@@ -1308,6 +1307,7 @@ def update(
13081307 "Invalid 'model_deployment_id'. Only model group deployment is supported to update."
13091308 )
13101309
1310+ # updates model group if fine tuned weights changed.
13111311 model = self ._update_model_group (
13121312 runtime .model_group_id , update_model_deployment_details
13131313 )
@@ -1324,10 +1324,6 @@ def update(
13241324 update_model_deployment_details .web_concurrency
13251325 or infrastructure .web_concurrency
13261326 )
1327- .with_private_endpoint_id (
1328- update_model_deployment_details .private_endpoint_id
1329- or infrastructure .private_endpoint_id
1330- )
13311327 )
13321328
13331329 if (
@@ -1358,6 +1354,7 @@ def update(
13581354 memory_in_gbs = update_model_deployment_details .memory_in_gbs ,
13591355 )
13601356
1357+ # applies ZDT as default type to update parameters if model group id hasn't been changed
13611358 update_type = ModelDeploymentUpdateType .ZDT
13621359 # applies LIVE update if model group id has been changed
13631360 if runtime .model_group_id != model .id :
@@ -1400,16 +1397,16 @@ def update(
14001397 def _update_model_group (
14011398 self ,
14021399 model_group_id : str ,
1403- update_model_deployment_details : UpdateModelGroupDeploymentDetails ,
1400+ update_model_deployment_details : UpdateModelDeploymentDetails ,
14041401 ) -> DataScienceModelGroup :
14051402 """Creates a new model group if fine tuned weights changed.
14061403
14071404 Parameters
14081405 ----------
14091406 model_group_id: str
14101407 The model group id.
1411- update_model_deployment_details: UpdateModelGroupDeploymentDetails
1412- An instance of UpdateModelGroupDeploymentDetails containing all optional
1408+ update_model_deployment_details: UpdateModelDeploymentDetails
1409+ An instance of UpdateModelDeploymentDetails containing all optional
14131410 fields for updating a model deployment via Aqua.
14141411
14151412 Returns
@@ -1462,6 +1459,10 @@ def _update_model_group(
14621459 .create ()
14631460 )
14641461
1462+ logger .info (
1463+ f"Model group of base model { target_base_model_id } has been updated: { model_group .id } ."
1464+ )
1465+
14651466 return model_group
14661467
14671468 @telemetry (entry_point = "plugin=deployment&action=list" , name = "aqua" )
0 commit comments