8282 AquaDeploymentDetail ,
8383 ConfigValidationError ,
8484 CreateModelDeploymentDetails ,
85+ ModelDeploymentDetails ,
86+ UpdateModelDeploymentDetails ,
8587)
8688from ads .aqua .modeldeployment .model_group_config import ModelGroupConfig
8789from ads .aqua .shaperecommend .recommend import AquaShapeRecommend
110112 ModelDeploymentInfrastructure ,
111113 ModelDeploymentMode ,
112114)
115+ from ads .model .deployment .model_deployment import (
116+ ModelDeploymentUpdateType ,
117+ )
113118from ads .model .model_metadata import ModelCustomMetadata , ModelCustomMetadataItem
114119from ads .telemetry import telemetry
115120
@@ -397,14 +402,14 @@ def create(
397402
398403 def _validate_input_models (
399404 self ,
400- create_deployment_details : CreateModelDeploymentDetails ,
405+ deployment_details : ModelDeploymentDetails ,
401406 ):
402- """Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment."""
407+ """Validates the base models and associated fine tuned models from 'models' in create_deployment_details or update_deployment_details for stacked or multi model deployment."""
403408 # Collect all unique model IDs (including fine-tuned models)
404409 source_model_ids = list (
405410 {
406411 model_id
407- for model in create_deployment_details .models
412+ for model in deployment_details .models
408413 for model_id in model .all_model_ids ()
409414 }
410415 )
@@ -415,7 +420,7 @@ def _validate_input_models(
415420 source_models = self .get_multi_source (source_model_ids ) or {}
416421
417422 try :
418- create_deployment_details .validate_input_models (model_details = source_models )
423+ deployment_details .validate_input_models (model_details = source_models )
419424 except ConfigValidationError as err :
420425 raise AquaValueError (f"{ err } " ) from err
421426
@@ -1249,6 +1254,219 @@ def _get_container_type_key(
12491254
12501255 return container_type_key
12511256
@telemetry(entry_point="plugin=deployment&action=update", name="aqua")
def update(
    self,
    model_deployment_id: str,
    update_model_deployment_details: Optional[UpdateModelDeploymentDetails] = None,
    **kwargs,
) -> AquaDeployment:
    """Update an existing AQUA model group deployment.

    Only deployments whose runtime carries a model group id can be updated.
    Any field that is not supplied (or is falsy) keeps the value currently
    stored on the deployment.

    Parameters
    ----------
    model_deployment_id: str
        The OCID of the model deployment to update.
    update_model_deployment_details: UpdateModelDeploymentDetails, optional
        An instance of UpdateModelDeploymentDetails containing all optional
        fields for updating a model deployment via Aqua. When omitted, it is
        built from ``kwargs``.
    kwargs:
        display_name (str): The name of the model deployment.
        description (Optional[str]): The description of the deployment.
        models (Optional[List[AquaMultiModelRef]]): List of models for deployment.
        instance_count (int): Number of instances used for deployment.
        log_group_id (Optional[str]): OCI logging group ID for logs.
        access_log_id (Optional[str]): OCID for access logs.
        predict_log_id (Optional[str]): OCID for prediction logs.
        bandwidth_mbps (Optional[int]): Bandwidth limit on the load balancer in Mbps.
        web_concurrency (Optional[int]): Number of worker processes/threads for handling requests.
        memory_in_gbs (Optional[float]): Memory (in GB) for the selected shape.
        ocpus (Optional[float]): OCPU count for the selected shape.
        freeform_tags (Optional[Dict]): Freeform tags for model deployment.
        defined_tags (Optional[Dict]): Defined tags for model deployment.

    Returns
    -------
    AquaDeployment
        An Aqua deployment instance.

    Raises
    ------
    AquaValueError
        If ``kwargs`` fail validation, or if the deployment is not a model
        group deployment.
    """
    details = update_model_deployment_details
    if not details:
        try:
            details = UpdateModelDeploymentDetails(**kwargs)
        except ValidationError as ex:
            custom_errors = build_pydantic_error_message(ex)
            raise AquaValueError(
                f"Invalid parameters for updating a model group deployment. Error details: {custom_errors}."
            ) from ex

    deployment = ModelDeployment.from_id(model_deployment_id)
    infra = deployment.infrastructure
    runtime = deployment.runtime

    current_group_id = runtime.model_group_id
    if not current_group_id:
        raise AquaValueError(
            "Invalid 'model_deployment_id'. Only model group deployment is supported to update."
        )

    # A different model group comes back when 'models' was supplied.
    model_group = self._update_model_group(current_group_id, details)

    # Infrastructure: fall back to the current values for anything not supplied.
    bandwidth_mbps = details.bandwidth_mbps or infra.bandwidth_mbps
    replica = details.instance_count or infra.replica
    web_concurrency = details.web_concurrency or infra.web_concurrency
    (
        infra.with_bandwidth_mbps(bandwidth_mbps)
        .with_replica(replica)
        .with_web_concurrency(web_concurrency)
    )

    # Log settings are applied only when a log group id accompanies the log id.
    log_group_id = details.log_group_id
    if log_group_id:
        if details.access_log_id:
            infra.with_access_log(
                log_group_id=log_group_id,
                log_id=details.access_log_id,
            )
        if details.predict_log_id:
            infra.with_predict_log(
                log_group_id=log_group_id,
                log_id=details.predict_log_id,
            )

    # Shape config is applied only when both values are given and the shape
    # name ends with "Flex".
    if details.memory_in_gbs and details.ocpus:
        if infra.shape_name.endswith("Flex"):
            infra.with_shape_config_details(
                ocpus=details.ocpus,
                memory_in_gbs=details.memory_in_gbs,
            )

    # ZDT is used when the model group id is unchanged, LIVE when it changed.
    if current_group_id != model_group.id:
        runtime.with_model_group_id(model_group.id)
        update_type = ModelDeploymentUpdateType.LIVE
    else:
        update_type = ModelDeploymentUpdateType.ZDT

    freeform_tags = details.freeform_tags or deployment.freeform_tags or {}
    defined_tags = details.defined_tags or deployment.defined_tags or {}
    display_name = details.display_name or deployment.display_name
    description = details.description or deployment.description

    # Apply the collected values to the deployment definition.
    (
        deployment.with_display_name(display_name)
        .with_description(description)
        .with_freeform_tags(**freeform_tags)
        .with_defined_tags(**defined_tags)
        .with_infrastructure(infra)
        .with_runtime(runtime)
    )

    # Submitted with wait_for_completion=False.
    deployment.update(wait_for_completion=False, update_type=update_type)

    logger.info(f"Updating Aqua Model Deployment {deployment.id}.")

    return AquaDeployment.from_oci_model_deployment(
        deployment.dsc_model_deployment, self.region
    )
1396+
def _update_model_group(
    self,
    model_group_id: str,
    update_model_deployment_details: UpdateModelDeploymentDetails,
) -> DataScienceModelGroup:
    """Return the model group the deployment should point to after an update.

    When ``update_model_deployment_details.models`` is empty, the existing
    model group is returned unchanged. Otherwise a NEW model group is created
    from the supplied stacked model (base model plus its fine tuned weights),
    copying every other setting from the existing group. A new group is
    created whenever ``models`` is supplied; the supplied members are not
    compared against the existing ones.

    Parameters
    ----------
    model_group_id: str
        The model group id.
    update_model_deployment_details: UpdateModelDeploymentDetails
        An instance of UpdateModelDeploymentDetails containing all optional
        fields for updating a model deployment via Aqua.

    Returns
    -------
    DataScienceModelGroup
        The instance of DataScienceModelGroup.

    Raises
    ------
    AquaValueError
        If more than one model is supplied, if model validation fails, or if
        the supplied base model differs from the existing group's base model.
    """
    model_group = DataScienceModelGroup.from_id(model_group_id)

    models = update_model_deployment_details.models
    if not models:
        # Nothing to change on the model side; keep the current group.
        return model_group

    if len(models) != 1:
        raise AquaValueError(
            "Invalid 'models' provided. Only one base model is required for updating model stack deployment."
        )

    # validates input base and fine tune models
    self._validate_input_models(update_model_deployment_details)

    target_stacked_model = models[0]
    target_base_model_id = target_stacked_model.model_id
    if model_group.base_model_id != target_base_model_id:
        raise AquaValueError(
            "Invalid parameter 'models'. Base model id can't be changed for stacked model deployment."
        )

    # Member models: every fine tuned weight first, then the base model.
    # `fine_tune_weights` may be unset (None) when the caller supplies only the
    # base model; treat that as "no fine tuned weights" instead of failing with
    # a TypeError while iterating.
    member_models = [
        {
            "inference_key": fine_tune_weight.model_name,
            "model_id": fine_tune_weight.model_id,
        }
        for fine_tune_weight in (target_stacked_model.fine_tune_weights or [])
    ]
    member_models.append(
        {
            "inference_key": target_stacked_model.model_name,
            "model_id": target_base_model_id,
        }
    )

    # A new group is created rather than editing the existing one, since the
    # member models of a model group are immutable. Everything except the
    # member models is copied from the original group.
    model_group = (
        DataScienceModelGroup()
        .with_compartment_id(model_group.compartment_id)
        .with_project_id(model_group.project_id)
        .with_display_name(model_group.display_name)
        .with_description(model_group.description)
        .with_freeform_tags(**(model_group.freeform_tags or {}))
        .with_defined_tags(**(model_group.defined_tags or {}))
        .with_custom_metadata_list(model_group.custom_metadata_list)
        .with_base_model_id(target_base_model_id)
        .with_member_models(member_models)
        .create()
    )

    logger.info(
        f"Model group of base model {target_base_model_id} has been updated: {model_group.id}."
    )

    return model_group
1469+
12521470 @telemetry (entry_point = "plugin=deployment&action=list" , name = "aqua" )
12531471 def list (self , ** kwargs ) -> List ["AquaDeployment" ]:
12541472 """List Aqua model deployments in a given compartment and under certain project.
0 commit comments