|
8 | 8 | from pydantic import BaseModel, Field, model_validator |
9 | 9 |
|
10 | 10 | from ads.aqua import logger |
11 | | -from ads.aqua.common.entities import AquaMultiModelRef |
| 11 | +from ads.aqua.common.entities import AquaMultiModelRef, LoraModuleSpec |
12 | 12 | from ads.aqua.common.enums import Tags |
13 | 13 | from ads.aqua.common.errors import AquaValueError |
| 14 | +from ads.aqua.common.utils import is_valid_ocid |
14 | 15 | from ads.aqua.config.utils.serializer import Serializable |
15 | 16 | from ads.aqua.constants import ( |
16 | 17 | AQUA_FINE_TUNE_MODEL_VERSION, |
@@ -717,34 +718,65 @@ def validate_ft_model_v2( |
717 | 718 | f"Invalid fine-tuned model ID '{base_model.id}': for fine tuned models like Phi4, the deployment is not supported. " |
718 | 719 | ) |
719 | 720 |
|
720 | | - def validate_base_model(self, model_id: str) -> None: |
| 721 | + def validate_base_model(self, model_id: str) -> Union[str, AquaMultiModelRef]: |
721 | 722 | """ |
722 | 723 | Validates the input base model for single model deployment configuration. |
723 | 724 |
|
724 | 725 | Validation Criteria: |
725 | | - - Fine-tuned models are not supported in single model deployment. |
| 726 | + - Legacy fine-tuned models will be deployed as single model deployment. |
| 727 | + - Fine-tuned models v2 will be deployed as stacked deployment. |
726 | 728 |
|
727 | 729 | Parameters |
728 | 730 | ---------- |
729 | 731 | model_id : str |
730 | 732 | The OCID of DataScienceModel instance. |
731 | 733 |
|
| 734 | + Returns |
| 735 | + ------- |
| 736 | + Union[str, AquaMultiModelRef] |
| 737 | + A string of model id or an instance of AquaMultiModelRef. |
| 738 | +
|
732 | 739 | Raises |
733 | 740 | ------ |
734 | 741 | ConfigValidationError |
735 | 742 | If any of the above conditions are violated. |
736 | 743 | """ |
737 | 744 | base_model = DataScienceModel.from_id(model_id) |
738 | | - if Tags.AQUA_FINE_TUNED_MODEL_TAG in base_model.freeform_tags: |
739 | | - logger.error( |
740 | | - "Validation failed: Fine-tuned model ID '%s' is not supported for single-model deployment.", |
741 | | - base_model.id, |
742 | | - ) |
743 | | - raise ConfigValidationError( |
744 | | - f"Invalid base model ID '{base_model.id}': " |
745 | | - "single-model deployment does not support fine-tuned models. " |
746 | | - f"Please deploy the fine-tuned model '{base_model.id}' as a stacked model deployment instead." |
| 745 | + freeform_tags = base_model.freeform_tags |
| 746 | + aqua_fine_tuned_model = freeform_tags.get( |
| 747 | + Tags.AQUA_FINE_TUNED_MODEL_TAG, UNKNOWN |
| 748 | + ) |
| 749 | + if aqua_fine_tuned_model: |
| 750 | + fine_tuned_model_version = freeform_tags.get( |
| 751 | + Tags.AQUA_FINE_TUNE_MODEL_VERSION, UNKNOWN |
747 | 752 | ) |
| 753 | + # TODO: revisit to block deploying single fine tuned model after AQUA UI is integrated. |
| 754 | + if fine_tuned_model_version.lower() == AQUA_FINE_TUNE_MODEL_VERSION: |
| 755 | + # extracts base model id from tag 'aqua_fine_tuned_model' and builds AquaMultiModelRef instance for stacked deployment. |
| 756 | + logger.debug( |
| 757 | + f"Detected base model is fine-tuned model {AQUA_FINE_TUNE_MODEL_VERSION} and switched to stack deployment." |
| 758 | + ) |
| 759 | + segments = aqua_fine_tuned_model.split("#") |
| 760 | + if not segments or not is_valid_ocid(segments[0]): |
| 761 | + logger.error( |
| 762 | + "Validation failed: Fine-tuned model ID '%s' is not supported for model deployment.", |
| 763 | + base_model.id, |
| 764 | + ) |
| 765 | + raise ConfigValidationError( |
| 766 | + f"Invalid fine-tuned model ID '{base_model.id}': missing or invalid tag '{Tags.AQUA_FINE_TUNED_MODEL_TAG}' format. " |
| 767 | + f"Make sure tag '{Tags.AQUA_FINE_TUNED_MODEL_TAG}' is added with format <service_model_id>#<service_model_name>." |
| 768 | + ) |
| 769 | + # reset the model_id and models in create_model_deployment_details for stack deployment |
| 770 | + self.model_id = None |
| 771 | + self.models = [ |
| 772 | + AquaMultiModelRef( |
| 773 | + model_id=segments[0], |
| 774 | + fine_tune_weights=[LoraModuleSpec(model_id=base_model.id)], |
| 775 | + ) |
| 776 | + ] |
| 777 | + return self.models[0] |
| 778 | + |
| 779 | + return model_id |
748 | 780 |
|
749 | 781 | class Config: |
750 | 782 | extra = "allow" |
|
0 commit comments