Skip to content

Commit 129cd8f

Browse files
authored
[AQUA][STMD]Updated logic to deploy single ft model (#1269)
2 parents 049af26 + d48c70d commit 129cd8f

File tree

3 files changed

+51
-17
lines changed

3 files changed

+51
-17
lines changed

ads/aqua/modeldeployment/deployment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ def create(
242242
model = create_deployment_details.models[0]
243243
else:
244244
try:
245-
create_deployment_details.validate_base_model(model_id=model)
245+
model = create_deployment_details.validate_base_model(
246+
model_id=model
247+
)
246248
except ConfigValidationError as err:
247249
raise AquaValueError(f"{err}") from err
248250

ads/aqua/modeldeployment/entities.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from pydantic import BaseModel, Field, model_validator
99

1010
from ads.aqua import logger
11-
from ads.aqua.common.entities import AquaMultiModelRef
11+
from ads.aqua.common.entities import AquaMultiModelRef, LoraModuleSpec
1212
from ads.aqua.common.enums import Tags
1313
from ads.aqua.common.errors import AquaValueError
14+
from ads.aqua.common.utils import is_valid_ocid
1415
from ads.aqua.config.utils.serializer import Serializable
1516
from ads.aqua.constants import (
1617
AQUA_FINE_TUNE_MODEL_VERSION,
@@ -717,34 +718,65 @@ def validate_ft_model_v2(
717718
f"Invalid fine-tuned model ID '{base_model.id}': for fine tuned models like Phi4, the deployment is not supported. "
718719
)
719720

720-
def validate_base_model(self, model_id: str) -> None:
721+
def validate_base_model(self, model_id: str) -> Union[str, AquaMultiModelRef]:
721722
"""
722723
Validates the input base model for single model deployment configuration.
723724
724725
Validation Criteria:
725-
- Fine-tuned models are not supported in single model deployment.
726+
- Legacy fine-tuned models will be deployed as single model deployment.
727+
- Fine-tuned models v2 will be deployed as stacked deployment.
726728
727729
Parameters
728730
----------
729731
model_id : str
730732
The OCID of DataScienceModel instance.
731733
734+
Returns
735+
-------
736+
Union[str, AquaMultiModelRef]
737+
A string of model id or an instance of AquaMultiModelRef.
738+
732739
Raises
733740
------
734741
ConfigValidationError
735742
If any of the above conditions are violated.
736743
"""
737744
base_model = DataScienceModel.from_id(model_id)
738-
if Tags.AQUA_FINE_TUNED_MODEL_TAG in base_model.freeform_tags:
739-
logger.error(
740-
"Validation failed: Fine-tuned model ID '%s' is not supported for single-model deployment.",
741-
base_model.id,
742-
)
743-
raise ConfigValidationError(
744-
f"Invalid base model ID '{base_model.id}': "
745-
"single-model deployment does not support fine-tuned models. "
746-
f"Please deploy the fine-tuned model '{base_model.id}' as a stacked model deployment instead."
745+
freeform_tags = base_model.freeform_tags
746+
aqua_fine_tuned_model = freeform_tags.get(
747+
Tags.AQUA_FINE_TUNED_MODEL_TAG, UNKNOWN
748+
)
749+
if aqua_fine_tuned_model:
750+
fine_tuned_model_version = freeform_tags.get(
751+
Tags.AQUA_FINE_TUNE_MODEL_VERSION, UNKNOWN
747752
)
753+
# TODO: revisit to block deploying single fine tuned model after AQUA UI is integrated.
754+
if fine_tuned_model_version.lower() == AQUA_FINE_TUNE_MODEL_VERSION:
755+
# extracts base model id from tag 'aqua_fine_tuned_model' and builds AquaMultiModelRef instance for stacked deployment.
756+
logger.debug(
757+
f"Detected base model is fine-tuned model {AQUA_FINE_TUNE_MODEL_VERSION} and switched to stack deployment."
758+
)
759+
segments = aqua_fine_tuned_model.split("#")
760+
if not segments or not is_valid_ocid(segments[0]):
761+
logger.error(
762+
"Validation failed: Fine-tuned model ID '%s' is not supported for model deployment.",
763+
base_model.id,
764+
)
765+
raise ConfigValidationError(
766+
f"Invalid fine-tuned model ID '{base_model.id}': missing or invalid tag '{Tags.AQUA_FINE_TUNED_MODEL_TAG}' format. "
767+
f"Make sure tag '{Tags.AQUA_FINE_TUNED_MODEL_TAG}' is added with format <service_model_id>#<service_model_name>."
768+
)
769+
# reset the model_id and models in create_model_deployment_details for stack deployment
770+
self.model_id = None
771+
self.models = [
772+
AquaMultiModelRef(
773+
model_id=segments[0],
774+
fine_tune_weights=[LoraModuleSpec(model_id=base_model.id)],
775+
)
776+
]
777+
return self.models[0]
778+
779+
return model_id
748780

749781
class Config:
750782
extra = "allow"

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,7 +1539,7 @@ def test_create_deployment_for_foundation_model(
15391539

15401540
mock_validate_base_model.assert_called()
15411541
mock_create.assert_called_with(
1542-
model=TestDataset.MODEL_ID,
1542+
model=mock_validate_base_model.return_value,
15431543
compartment_id=TestDataset.USER_COMPARTMENT_ID,
15441544
project_id=TestDataset.USER_PROJECT_ID,
15451545
freeform_tags=freeform_tags,
@@ -1640,7 +1640,7 @@ def test_create_deployment_for_fine_tuned_model(
16401640

16411641
mock_validate_base_model.assert_called()
16421642
mock_create.assert_called_with(
1643-
model=TestDataset.MODEL_ID,
1643+
model=mock_validate_base_model.return_value,
16441644
compartment_id=TestDataset.USER_COMPARTMENT_ID,
16451645
project_id=TestDataset.USER_PROJECT_ID,
16461646
freeform_tags=None,
@@ -1741,7 +1741,7 @@ def test_create_deployment_for_gguf_model(
17411741

17421742
mock_validate_base_model.assert_called()
17431743
mock_create.assert_called_with(
1744-
model=TestDataset.MODEL_ID,
1744+
model=mock_validate_base_model.return_value,
17451745
compartment_id=TestDataset.USER_COMPARTMENT_ID,
17461746
project_id=TestDataset.USER_PROJECT_ID,
17471747
freeform_tags=None,
@@ -1849,7 +1849,7 @@ def test_create_deployment_for_tei_byoc_embedding_model(
18491849

18501850
mock_validate_base_model.assert_called()
18511851
mock_create.assert_called_with(
1852-
model=TestDataset.MODEL_ID,
1852+
model=mock_validate_base_model.return_value,
18531853
compartment_id=TestDataset.USER_COMPARTMENT_ID,
18541854
project_id=TestDataset.USER_PROJECT_ID,
18551855
freeform_tags=None,

0 commit comments

Comments
 (0)