|
| 1 | +# coding=utf-8 |
| 2 | +""" |
| 3 | +讯飞 TTS 工厂类 Credential,根据 api_version 路由到具体 Credential |
| 4 | +""" |
| 5 | +from typing import Dict |
| 6 | + |
| 7 | +from django.utils.translation import gettext_lazy as _, gettext |
| 8 | + |
| 9 | +from common import forms |
| 10 | +from common.exception.app_exception import AppApiException |
| 11 | +from common.forms import BaseForm, TooltipLabel |
| 12 | +from models_provider.base_model_provider import BaseModelCredential, ValidCode |
| 13 | +from common.utils.logger import maxkb_logger |
| 14 | + |
| 15 | + |
| 16 | +class XunFeiDefaultTTSModelCredential(BaseForm, BaseModelCredential): |
| 17 | + """讯飞 TTS 工厂类 Credential,根据 api_version 参数路由到具体实现""" |
| 18 | + |
| 19 | + api_version = forms.SingleSelect( |
| 20 | + _("API Version"), required=True, |
| 21 | + text_field='label', |
| 22 | + value_field='value', |
| 23 | + default_value='online', |
| 24 | + option_list=[ |
| 25 | + {'label': _('Online TTS'), 'value': 'online'}, |
| 26 | + {'label': _('Super Humanoid TTS'), 'value': 'super_humanoid'} |
| 27 | + ]) |
| 28 | + |
| 29 | + spark_api_url = forms.TextInputField('API URL', required=True, |
| 30 | + default_value='wss://tts-api.xfyun.cn/v2/tts', |
| 31 | + relation_show_field_dict={"api_version": ["online"]}) |
| 32 | + spark_api_url_super = forms.TextInputField('API URL', required=True, |
| 33 | + default_value='wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6', |
| 34 | + relation_show_field_dict={"api_version": ["super_humanoid"]}) |
| 35 | + |
| 36 | + # vcn 选择放在 credential 中,根据 api_version 联动显示 |
| 37 | + vcn_online = forms.SingleSelect( |
| 38 | + TooltipLabel(_('Speaker'), _('Speaker selection for standard TTS service')), |
| 39 | + required=True, default_value='xiaoyan', |
| 40 | + text_field='value', |
| 41 | + value_field='value', |
| 42 | + option_list=[ |
| 43 | + {'text': _('iFlytek Xiaoyan'), 'value': 'xiaoyan'}, |
| 44 | + {'text': _('iFlytek Xujiu'), 'value': 'aisjiuxu'}, |
| 45 | + {'text': _('iFlytek Xiaoping'), 'value': 'aisxping'}, |
| 46 | + {'text': _('iFlytek Xiaojing'), 'value': 'aisjinger'}, |
| 47 | + {'text': _('iFlytek Xuxiaobao'), 'value': 'aisbabyxu'}, |
| 48 | + ], |
| 49 | + relation_show_field_dict={"api_version": ["online"]}) |
| 50 | + |
| 51 | + vcn_super = forms.SingleSelect( |
| 52 | + TooltipLabel(_('Speaker'), _('Speaker selection for super-humanoid TTS service')), |
| 53 | + required=True, default_value='x5_lingxiaoxuan_flow', |
| 54 | + text_field='value', |
| 55 | + value_field='value', |
| 56 | + option_list=[ |
| 57 | + {'text': _('Super-humanoid: Lingxiaoxuan Flow'), 'value': 'x5_lingxiaoxuan_flow'}, |
| 58 | + {'text': _('Super-humanoid: Lingyuyan Flow'), 'value': 'x5_lingyuyan_flow'}, |
| 59 | + {'text': _('Super-humanoid: Lingfeiyi Flow'), 'value': 'x5_lingfeiyi_flow'}, |
| 60 | + {'text': _('Super-humanoid: Lingxiaoyue Flow'), 'value': 'x5_lingxiaoyue_flow'}, |
| 61 | + {'text': _('Super-humanoid: Sun Dasheng Flow'), 'value': 'x5_sundasheng_flow'}, |
| 62 | + {'text': _('Super-humanoid: Lingyuzhao Flow'), 'value': 'x5_lingyuzhao_flow'}, |
| 63 | + {'text': _('Super-humanoid: Lingxiaotang Flow'), 'value': 'x5_lingxiaotang_flow'}, |
| 64 | + {'text': _('Super-humanoid: Lingxiaorong Flow'), 'value': 'x5_lingxiaorong_flow'}, |
| 65 | + {'text': _('Super-humanoid: Xinyun Flow'), 'value': 'x5_xinyun_flow'}, |
| 66 | + {'text': _('Super-humanoid: Grant (EN)'), 'value': 'x5_EnUs_Grant_flow'}, |
| 67 | + {'text': _('Super-humanoid: Lila (EN)'), 'value': 'x5_EnUs_Lila_flow'}, |
| 68 | + {'text': _('Super-humanoid: Lingwanwan Pro'), 'value': 'x6_lingwanwan_pro'}, |
| 69 | + {'text': _('Super-humanoid: Yiyi Pro'), 'value': 'x6_yiyi_pro'}, |
| 70 | + {'text': _('Super-humanoid: Huifangnv Pro'), 'value': 'x6_huifangnv_pro'}, |
| 71 | + {'text': _('Super-humanoid: Lingxiaoying Pro'), 'value': 'x6_lingxiaoying_pro'}, |
| 72 | + {'text': _('Super-humanoid: Lingfeibo Pro'), 'value': 'x6_lingfeibo_pro'}, |
| 73 | + {'text': _('Super-humanoid: Lingyuyan Pro'), 'value': 'x6_lingyuyan_pro'}, |
| 74 | + ], |
| 75 | + relation_show_field_dict={"api_version": ["super_humanoid"]}) |
| 76 | + |
| 77 | + spark_app_id = forms.TextInputField('APP ID', required=True) |
| 78 | + spark_api_key = forms.PasswordInputField("API Key", required=True) |
| 79 | + spark_api_secret = forms.PasswordInputField('API Secret', required=True) |
| 80 | + |
| 81 | + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, |
| 82 | + raise_exception=False): |
| 83 | + model_type_list = provider.get_model_type_list() |
| 84 | + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): |
| 85 | + raise AppApiException(ValidCode.valid_error.value, |
| 86 | + gettext('{model_type} Model type is not supported').format(model_type=model_type)) |
| 87 | + |
| 88 | + api_version = model_credential.get('api_version', 'online') |
| 89 | + if api_version == 'super_humanoid': |
| 90 | + required_keys = ['spark_api_url_super', 'spark_app_id', 'spark_api_key', 'spark_api_secret'] |
| 91 | + else: |
| 92 | + required_keys = ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret'] |
| 93 | + |
| 94 | + for key in required_keys: |
| 95 | + if key not in model_credential: |
| 96 | + if raise_exception: |
| 97 | + raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) |
| 98 | + else: |
| 99 | + return False |
| 100 | + try: |
| 101 | + model = provider.get_model(model_type, model_name, model_credential, **model_params) |
| 102 | + model.check_auth() |
| 103 | + except Exception as e: |
| 104 | + maxkb_logger.error(f'Exception: {e}', exc_info=True) |
| 105 | + if isinstance(e, AppApiException): |
| 106 | + raise e |
| 107 | + if raise_exception: |
| 108 | + raise AppApiException(ValidCode.valid_error.value, |
| 109 | + gettext( |
| 110 | + 'Verification failed, please check whether the parameters are correct: {error}').format( |
| 111 | + error=str(e))) |
| 112 | + else: |
| 113 | + return False |
| 114 | + return True |
| 115 | + |
| 116 | + def encryption_dict(self, model: Dict[str, object]): |
| 117 | + return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} |
| 118 | + |
| 119 | + def get_model_params_setting_form(self, model_name): |
| 120 | + # params 只包含通用参数,vcn 已在 credential 中 |
| 121 | + return XunFeiDefaultTTSModelParams() |
| 122 | + |
| 123 | + |
| 124 | +class XunFeiDefaultTTSModelParams(BaseForm): |
| 125 | + """工厂类的参数表单,只包含通用参数""" |
| 126 | + |
| 127 | + speed = forms.SliderField( |
| 128 | + TooltipLabel(_('speaking speed'), _('Speech speed, optional value: [0-100], default is 50')), |
| 129 | + required=True, default_value=50, |
| 130 | + _min=1, |
| 131 | + _max=100, |
| 132 | + _step=5, |
| 133 | + precision=1) |
0 commit comments