diff --git a/.gitignore b/.gitignore index b43a8de3d6..4744a14f31 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,4 @@ uploadfile *.json .vscode .cursor +deploy/ diff --git a/main/manager-api/src/main/resources/db/changelog/202506011728.sql b/main/manager-api/src/main/resources/db/changelog/202506011728.sql new file mode 100644 index 0000000000..44fe042aa1 --- /dev/null +++ b/main/manager-api/src/main/resources/db/changelog/202506011728.sql @@ -0,0 +1,17 @@ +-- 增加火山大模型网关ASR供应器和模型配置 +DELETE FROM `ai_model_provider` WHERE `id` = 'SYSTEM_ASR_VOLC_GW'; +INSERT INTO `ai_model_provider` (`id`, `model_type`, `provider_code`, `name`, `fields`, `sort`, `creator`, `create_date`, `updater`, `update_date`) VALUES +('SYSTEM_ASR_VOLC_GW', 'ASR', 'volcengine', '火山引擎边缘大模型网关', '[{"key":"api_key","label":"网关秘钥","type":"string"},{"key":"model_name","label":"模型名称","type":"string"},{"key":"host","label":"网关域名","type":"string"},{"key":"output_dir","label":"输出目录","type":"string"}]', 1, 1, NOW(), 1, NOW()); + +DELETE FROM `ai_model_config` WHERE `id` = 'ASR_VolceAIGateway'; +INSERT INTO `ai_model_config` VALUES ('ASR_VolceAIGateway', 'ASR', 'VolceAIGateway', '火山引擎边缘大模型网关', 0, 1, '{\"type\": \"volcengine\", \"api_key\": \"火山引擎边缘大模型网关的秘钥\", \"model_name\": \"bigmodel\", \"host\": \"ai-gateway.vei.volces.com\", \"output_dir\": \"tmp/\"}', NULL, NULL, 16, 1, NOW(), 1, NOW()); + +-- 火山大模型网关ASR模型配置说明文档 +UPDATE `ai_model_config` SET +`doc_link` = 'https://console.volcengine.com/vei/aigateway/', +`remark` = '火山引擎边缘大模型网关ASR配置说明: +1. 访问 https://console.volcengine.com/vei/aigateway/ +2. 创建网关访问密钥(个人用户申请时注明来自小智xiaozhi-esp32-server社区,并描述使用背景,可更快获得审批,并有机会获得更多token) +3. 搜索并勾选 Doubao-语音识别,如果需要使用LLM,一并勾选 Doubao-pro-32k-functioncall +4. 填入配置文件中' WHERE `id` = 'ASR_VolceAIGateway'; + diff --git a/main/manager-api/src/main/resources/db/changelog/202506031555.sql b/main/manager-api/src/main/resources/db/changelog/202506031555.sql new file mode 100644 index 0000000000..8c9702d25a --- /dev/null +++ b/main/manager-api/src/main/resources/db/changelog/202506031555.sql @@ -0,0 +1,115 @@ +-- 增加火山大模型网关ASR供应器 +DELETE FROM `ai_model_provider` WHERE `id` = 'SYSTEM_ASR_VOLC_GW'; +INSERT INTO `ai_model_provider` (`id`, `model_type`, `provider_code`, `name`, `fields`, `sort`, `creator`, `create_date`, `updater`, `update_date`) VALUES +('SYSTEM_ASR_VOLC_GW', 'ASR', 'volcengine', '火山引擎边缘大模型网关', '[{"key":"api_key","label":"网关秘钥","type":"string"},{"key":"model_name","label":"模型名称","type":"string"},{"key":"host","label":"网关域名","type":"string"},{"key":"output_dir","label":"输出目录","type":"string"}]', 1, 1, NOW(), 1, NOW()); +-- 增加火山大模型网关TTS供应器 +DELETE FROM `ai_model_provider` WHERE `id` = 'SYSTEM_TTS_VOLC_GW'; +INSERT INTO `ai_model_provider` (`id`, `model_type`, `provider_code`, `name`, `fields`, `sort`, `creator`, `create_date`, `updater`, `update_date`) VALUES +('SYSTEM_TTS_VOLC_GW', 'TTS', 'volcengine', '火山引擎边缘大模型网关', '[{"key":"api_key","label":"网关秘钥","type":"string"},{"key":"model_name","label":"模型名称","type":"string"},{"key":"host","label":"网关域名","type":"string"},{"key":"output_dir","label":"输出目录","type":"string"}]', 1, 1, NOW(), 1, NOW()); +-- 增加火山大模型网关LLM供应器 +DELETE FROM `ai_model_provider` WHERE `id` = 'SYSTEM_LLM_VOLC_GW'; +INSERT INTO `ai_model_provider` (`id`, `model_type`, `provider_code`, `name`, `fields`, `sort`, `creator`, `create_date`, `updater`, `update_date`) VALUES +('SYSTEM_LLM_VOLC_GW', 'LLM', 'volcengine', '火山引擎边缘大模型网关', 
'[{"key":"api_key","label":"网关秘钥","type":"string"},{"key":"model_name","label":"模型名称","type":"string"},{"key":"host","label":"网关域名","type":"string"},{"key":"output_dir","label":"输出目录","type":"string"}]', 1, 1, NOW(), 1, NOW()); +-- 增加火山大模型网关VLM供应器 +DELETE FROM `ai_model_provider` WHERE `id` = 'SYSTEM_VLLM_VOLC_GW'; +INSERT INTO `ai_model_provider` (`id`, `model_type`, `provider_code`, `name`, `fields`, `sort`, `creator`, `create_date`, `updater`, `update_date`) VALUES +('SYSTEM_VLLM_VOLC_GW', 'VLLM', 'volcengine', '火山引擎边缘大模型网关', '[{"key":"api_key","label":"网关秘钥","type":"string"},{"key":"model_name","label":"模型名称","type":"string"},{"key":"host","label":"网关域名","type":"string"}]', 1, 1, NOW(), 1, NOW()); + + +-- 增加火山大模型网关ASR模型配置 +DELETE FROM `ai_model_config` WHERE `id` = 'ASR_VolceAIGateway'; +INSERT INTO `ai_model_config` VALUES ('ASR_VolceAIGateway', 'ASR', 'VolceAIGateway', '火山引擎边缘大模型网关', 0, 1, '{\"type\": \"volcengine\", \"api_key\": \"火山引擎边缘大模型网关的秘钥\", \"model_name\": \"bigmodel\", \"host\": \"ai-gateway.vei.volces.com\", \"output_dir\": \"tmp/\"}', NULL, NULL, 16, 1, NOW(), 1, NOW()); +-- 增加火山大模型网关TTS模型配置 +DELETE FROM `ai_model_config` WHERE `id` = 'TTS_VolcesAiGatewayTTS'; +DELETE FROM `ai_model_config` WHERE `id` = 'TTS_VolceAIGateway'; +INSERT INTO `ai_model_config` VALUES ('TTS_VolceAIGateway', 'TTS', 'VolceAIGateway', '火山引擎边缘大模型网关', 0, 1, '{\"type\": \"volcengine\", \"api_key\": \"火山引擎边缘大模型网关的秘钥\", \"model_name\": \"doubao-tts\", \"host\": \"ai-gateway.vei.volces.com\",\"voice\": \"zh_male_shaonianzixin_moon_bigtts\", \"speed\": 1, \"output_dir\": \"tmp/\"}', NULL, NULL, 16, 1, NOW(), 1, NOW()); +-- 增加火山大模型网关LLM模型配置 +DELETE FROM `ai_model_config` WHERE `id` = 'LLM_VolcesAiGatewayLLM'; +DELETE FROM `ai_model_config` WHERE `id` = 'LLM_VolceAIGateway'; +INSERT INTO `ai_model_config` VALUES ('LLM_VolceAIGateway', 'LLM', 'VolceAIGateway', '火山引擎边缘大模型网关', 0, 1, '{\"type\": \"volcengine\", \"api_key\": \"火山引擎边缘大模型网关的秘钥\", \"model_name\": \"doubao-pro-32k-functioncall\", \"host\": \"ai-gateway.vei.volces.com\"}', NULL, NULL, 16, 1, NOW(), 1, NOW()); +-- 增加火山大模型网关VLLM模型配置 +DELETE FROM `ai_model_config` WHERE `id` = 'VLLM_VolceAIGateway'; +INSERT INTO `ai_model_config` VALUES ('VLLM_VolceAIGateway', 'VLLM', 'VolceAIGateway', '火山引擎边缘大模型网关', 0, 1, '{\"type\": \"volcengine\", \"api_key\": \"火山引擎边缘大模型网关的秘钥\", \"model_name\": \"doubao-1.5-vision-pro-32k\", \"host\": \"ai-gateway.vei.volces.com\"}', NULL, NULL, 16, 1, NOW(), 1, NOW()); + + +-- 火山大模型网关ASR模型配置说明文档 +UPDATE `ai_model_config` SET +`doc_link` = 'https://console.volcengine.com/vei/aigateway/', +`remark` = '火山引擎边缘大模型网关ASR配置说明: +1. 访问 https://console.volcengine.com/vei/aigateway/ +2. 创建网关访问密钥(个人用户申请时注明来自小智xiaozhi-esp32-server社区,并描述使用背景,可更快获得审批,并有机会获得更多token) +3. 搜索并勾选 Doubao-语音识别,网关支持一个api_key访问ASR,LLM,TTS,VLLM模型,满足智能体使用,推荐同时开通“Doubao-语音识别”、“Doubao-语音合成”、“Doubao-pro-32k-functioncall”、“Doubao-1.5-vision-pro”全量模型 +4. 填入配置文件中' WHERE `id` = 'ASR_VolceAIGateway'; +-- 火山大模型网关TTS模型配置说明文档 +UPDATE `ai_model_config` SET +`doc_link` = 'https://console.volcengine.com/vei/aigateway/', +`remark` = '火山引擎边缘大模型网关TTS配置说明: +1. 访问 https://console.volcengine.com/vei/aigateway/ +2. 创建网关访问密钥(个人用户申请时注明来自小智xiaozhi-esp32-server社区,并描述使用背景,可更快获得审批,并有机会获得更多token) +3. 搜索并勾选 Doubao-语音合成,网关支持一个api_key访问ASR,LLM,TTS,VLLM模型,满足智能体使用,推荐同时开通“Doubao-语音识别”、“Doubao-语音合成”、“Doubao-pro-32k-functioncall”、“Doubao-1.5-vision-pro”全量模型 +4. 
填入配置文件中' WHERE `id` = 'TTS_VolceAIGateway'; +-- 火山大模型网关LLM模型配置说明文档 +UPDATE `ai_model_config` SET +`doc_link` = 'https://console.volcengine.com/vei/aigateway/', +`remark` = '火山引擎边缘大模型网关LLM配置说明: +1. 访问 https://console.volcengine.com/vei/aigateway/ +2. 创建网关访问密钥(个人用户申请时注明来自小智xiaozhi-esp32-server社区,并描述使用背景,可更快获得审批,并有机会获得更多token) +3. 搜索并勾选 Doubao-pro-32k-functioncall,网关支持一个api_key访问ASR,LLM,TTS,VLLM模型,满足智能体使用,推荐同时开通“Doubao-语音识别”、“Doubao-语音合成”、“Doubao-pro-32k-functioncall”、“Doubao-1.5-vision-pro”全量模型 +4. 填入配置文件中' WHERE `id` = 'LLM_VolceAIGateway'; +-- 火山大模型网关VLLM模型配置说明文档 +UPDATE `ai_model_config` SET +`doc_link` = 'https://console.volcengine.com/vei/aigateway/', +`remark` = '火山引擎边缘大模型网关VLM配置说明: +1. 访问 https://console.volcengine.com/vei/aigateway/ +2. 创建网关访问密钥(个人用户申请时注明来自小智xiaozhi-esp32-server社区,并描述使用背景,可更快获得审批,并有机会获得更多token) +3. 搜索并勾选 Doubao-1.5-vision-pro,网关支持一个api_key访问ASR,LLM,TTS,VLLM模型,满足智能体使用,推荐同时开通“Doubao-语音识别”、“Doubao-语音合成”、“Doubao-pro-32k-functioncall”、“Doubao-1.5-vision-pro”全量模型 +4. 填入配置文件中' WHERE `id` = 'VLLM_VolceAIGateway'; + + +-- 添加火山引擎边缘大模型网关语音合成音色 +DELETE FROM `ai_tts_voice` WHERE `tts_model_id` = 'TTS_VolceAIGateway'; +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0001', 'TTS_VolceAIGateway', '灿灿/Shiny', 'zh_female_cancan_mars_bigtts', '中文、美式英语', NULL, NULL, 1, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0002', 'TTS_VolceAIGateway', '清新女声', 'zh_female_qingxinnvsheng_mars_bigtts', '中文', NULL, NULL, 2, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0003', 'TTS_VolceAIGateway', '爽快思思/Skye', 'zh_female_shuangkuaisisi_moon_bigtts', '中文、美式英语', NULL, NULL, 3, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0004', 'TTS_VolceAIGateway', '温暖阿虎/Alvin', 'zh_male_wennuanahu_moon_bigtts', '中文、美式英语', NULL, NULL, 4, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0005', 'TTS_VolceAIGateway', '少年梓辛/Brayan', 'zh_male_shaonianzixin_moon_bigtts', '中文、美式英语', NULL, NULL, 5, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0006', 'TTS_VolceAIGateway', '知性女声', 'zh_female_zhixingnvsheng_mars_bigtts', '中文', NULL, NULL, 6, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0007', 'TTS_VolceAIGateway', '清爽男大', 'zh_male_qingshuangnanda_mars_bigtts', '中文', NULL, NULL, 7, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0008', 'TTS_VolceAIGateway', '邻家女孩', 'zh_female_linjianvhai_moon_bigtts', '中文', NULL, NULL, 8, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0009', 'TTS_VolceAIGateway', '渊博小叔', 'zh_male_yuanboxiaoshu_moon_bigtts', '中文', NULL, NULL, 9, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0010', 'TTS_VolceAIGateway', '阳光青年', 'zh_male_yangguangqingnian_moon_bigtts', '中文', NULL, NULL, 10, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0011', 'TTS_VolceAIGateway', '甜美小源', 'zh_female_tianmeixiaoyuan_moon_bigtts', '中文', NULL, NULL, 11, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0012', 'TTS_VolceAIGateway', '清澈梓梓', 'zh_female_qingchezizi_moon_bigtts', '中文', NULL, NULL, 12, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0013', 'TTS_VolceAIGateway', '解说小明', 'zh_male_jieshuoxiaoming_moon_bigtts', '中文', NULL, NULL, 13, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0014', 
'TTS_VolceAIGateway', '开朗姐姐', 'zh_female_kailangjiejie_moon_bigtts', '中文', NULL, NULL, 14, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0015', 'TTS_VolceAIGateway', '邻家男孩', 'zh_male_linjiananhai_moon_bigtts', '中文', NULL, NULL, 15, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0016', 'TTS_VolceAIGateway', '甜美悦悦', 'zh_female_tianmeiyueyue_moon_bigtts', '中文', NULL, NULL, 16, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0017', 'TTS_VolceAIGateway', '心灵鸡汤', 'zh_female_xinlingjitang_moon_bigtts', '中文', NULL, NULL, 17, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0018', 'TTS_VolceAIGateway', '知性温婉', 'ICL_zh_female_zhixingwenwan_tob', '中文', NULL, NULL, 18, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0019', 'TTS_VolceAIGateway', '暖心体贴', 'ICL_zh_male_nuanxintitie_tob', '中文', NULL, NULL, 19, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0020', 'TTS_VolceAIGateway', '温柔文雅', 'ICL_zh_female_wenrouwenya_tob', '中文', NULL, NULL, 20, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0021', 'TTS_VolceAIGateway', '开朗轻快', 'ICL_zh_male_kailangqingkuai_tob', '中文', NULL, NULL, 21, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0022', 'TTS_VolceAIGateway', '活泼爽朗', 'ICL_zh_male_huoposhuanglang_tob', '中文', NULL, NULL, 22, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0023', 'TTS_VolceAIGateway', '率真小伙', 'ICL_zh_male_shuaizhenxiaohuo_tob', '中文', NULL, NULL, 23, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0024', 'TTS_VolceAIGateway', '温柔小哥', 'zh_male_wenrouxiaoge_mars_bigtts', '中文', NULL, NULL, 24, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0025', 'TTS_VolceAIGateway', 'Smith', 'en_male_smith_mars_bigtts', '英式英语', NULL, NULL, 25, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0026', 'TTS_VolceAIGateway', 'Anna', 'en_female_anna_mars_bigtts', '英式英语', NULL, NULL, 26, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0027', 'TTS_VolceAIGateway', 'Adam', 'en_male_adam_mars_bigtts', '美式英语', NULL, NULL, 27, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0028', 'TTS_VolceAIGateway', 'Sarah', 'en_female_sarah_mars_bigtts', '澳洲英语', NULL, NULL, 28, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0029', 'TTS_VolceAIGateway', 'Dryw', 'en_male_dryw_mars_bigtts', '澳洲英语', NULL, NULL, 29, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0030', 'TTS_VolceAIGateway', 'かずね(和音)', 'multi_male_jingqiangkanye_moon_bigtts', '日语、西语', NULL, NULL, 30, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0031', 'TTS_VolceAIGateway', 'はるこ(晴子)', 'multi_female_shuangkuaisisi_moon_bigtts', '日语、西语', NULL, NULL, 31, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0032', 'TTS_VolceAIGateway', 'ひろし(広志)', 'multi_male_wanqudashu_moon_bigtts', '日语、西语', NULL, NULL, 32, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0033', 'TTS_VolceAIGateway', 'あけみ(朱美)', 'multi_female_gaolengyujie_moon_bigtts', '日语', NULL, NULL, 33, NULL, NULL, NULL, NULL); +INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0034', 
'TTS_VolceAIGateway', 'Amanda', 'en_female_amanda_mars_bigtts', '美式英语', NULL, NULL, 34, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0035', 'TTS_VolceAIGateway', 'Jackson', 'en_male_jackson_mars_bigtts', '美式英语', NULL, NULL, 35, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0036', 'TTS_VolceAIGateway', '京腔侃爷/Harmony', 'zh_male_jingqiangkanye_moon_bigtts', '中文-北京口音、英文', NULL, NULL, 36, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0037', 'TTS_VolceAIGateway', '湾湾小何', 'zh_female_wanwanxiaohe_moon_bigtts', '中文-台湾口音', NULL, NULL, 37, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0038', 'TTS_VolceAIGateway', '湾区大叔', 'zh_female_wanqudashu_moon_bigtts', '中文-广东口音', NULL, NULL, 38, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0039', 'TTS_VolceAIGateway', '呆萌川妹', 'zh_female_daimengchuanmei_moon_bigtts', '中文-四川口音', NULL, NULL, 39, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0040', 'TTS_VolceAIGateway', '广州德哥', 'zh_male_guozhoudege_moon_bigtts', '中文-广东口音', NULL, NULL, 40, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0041', 'TTS_VolceAIGateway', '北京小爷', 'zh_male_beijingxiaoye_moon_bigtts', '中文-北京口音', NULL, NULL, 41, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0042', 'TTS_VolceAIGateway', '浩宇小哥', 'zh_male_haoyuxiaoge_moon_bigtts', '中文-青岛口音', NULL, NULL, 42, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0043', 'TTS_VolceAIGateway', '广西远舟', 'zh_male_guangxiyuanzhou_moon_bigtts', '中文-广西口音', NULL, NULL, 43, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0044', 'TTS_VolceAIGateway', '妹坨洁儿', 'zh_female_meituojieer_moon_bigtts', '中文-长沙口音', NULL, NULL, 44, NULL, NULL, NULL, NULL);
+INSERT INTO `ai_tts_voice` VALUES ('TTS_VolceAIGateway_0045', 'TTS_VolceAIGateway', '豫州子轩', 'zh_male_yuzhouzixuan_moon_bigtts', '中文-河南口音', NULL, NULL, 45, NULL, NULL, NULL, NULL);
\ No newline at end of file
diff --git a/main/manager-api/src/main/resources/db/changelog/202506301830.sql b/main/manager-api/src/main/resources/db/changelog/202506301830.sql
new file mode 100644
index 0000000000..1984208a56
--- /dev/null
+++ b/main/manager-api/src/main/resources/db/changelog/202506301830.sql
@@ -0,0 +1,21 @@
+-- 增加火山大模型网关VAD供应器
+DELETE FROM `ai_model_provider` WHERE `id` = 'SYSTEM_VAD_VOLC_GW';
+INSERT INTO `ai_model_provider` (`id`, `model_type`, `provider_code`, `name`, `fields`, `sort`, `creator`, `create_date`, `updater`, `update_date`) VALUES
+('SYSTEM_VAD_VOLC_GW', 'VAD', 'volcengine', '火山引擎边缘大模型网关', '[{"key":"api_key","label":"网关秘钥","type":"string"},{"key":"model_name","label":"模型名称","type":"string"},{"key":"host","label":"网关域名","type":"string"},{"key":"semantic_only","label":"仅使用语义判停","type":"boolean"},{"key":"threshold","label":"音量检测阈值","type":"number"},{"key":"min_silence_duration_ms","label":"最小静音时长","type":"number"},{"key":"max_silence_duration_ms","label":"最大静音时长","type":"number"}]', 1, 1, NOW(), 1, NOW());
+
+
+
+-- 增加火山大模型网关VAD模型配置
+DELETE FROM `ai_model_config` WHERE `id` = 'VAD_VolceAIGateway';
+INSERT INTO `ai_model_config` VALUES ('VAD_VolceAIGateway', 'VAD', 'VolceAIGateway', '火山引擎边缘大模型网关', 0, 1, '{\"type\": \"volcengine\", \"api_key\": \"火山引擎边缘大模型网关的秘钥\", \"model_name\": \"semantic-integrity-recognition\", \"host\": \"ai-gateway.vei.volces.com\", \"semantic_only\": 
false,\"threshold\": 0.5, \"min_silence_duration_ms\": 700, \"max_silence_duration_ms\": 3000}', NULL, NULL, 16, 1, NOW(), 1, NOW()); + +-- 火山大模型网关VAD模型配置说明文档 +UPDATE `ai_model_config` SET +`doc_link` = 'https://console.volcengine.com/vei/aigateway/', +`remark` = '火山引擎边缘大模型网关VAD配置说明: +1. 访问 https://console.volcengine.com/vei/aigateway/ +2. 创建网关访问密钥(个人用户申请时注明来自小智xiaozhi-esp32-server社区,并描述使用背景,可更快获得审批,并有机会获得更多token, VAD模型需要oncall发起开白) +3. 勾选Semantic-Integrity-Recognition,网关支持一个api_key访问ASR,LLM,TTS,VLLM模型,满足智能体使用,推荐同时开通“Doubao-语音识别”、“Doubao-语音合成”、“Doubao-pro-32k-functioncall”、“Doubao-1.5-vision-pro”全量模型 +4. 填入配置文件中' WHERE `id` = 'VAD_VolceAIGateway'; + + diff --git a/main/manager-api/src/main/resources/db/changelog/db.changelog-master.yaml b/main/manager-api/src/main/resources/db/changelog/db.changelog-master.yaml index 2a520d1646..40ea970690 100755 --- a/main/manager-api/src/main/resources/db/changelog/db.changelog-master.yaml +++ b/main/manager-api/src/main/resources/db/changelog/db.changelog-master.yaml @@ -177,6 +177,20 @@ databaseChangeLog: - sqlFile: encoding: utf8 path: classpath:db/changelog/202506010920.sql + - changeSet: + id: 202506011728 + author: xh + changes: + - sqlFile: + encoding: utf8 + path: classpath:db/changelog/202506011728.sql + - changeSet: + id: 202506031555 + author: xh + changes: + - sqlFile: + encoding: utf8 + path: classpath:db/changelog/202506031555.sql - changeSet: id: 202506031639 author: hrz @@ -260,4 +274,4 @@ databaseChangeLog: changes: - sqlFile: encoding: utf8 - path: classpath:db/changelog/202507081646.sql \ No newline at end of file + path: classpath:db/changelog/202507081646.sql diff --git a/main/xiaozhi-server/config.yaml b/main/xiaozhi-server/config.yaml index 2ae718962e..8d1997cfb8 100644 --- a/main/xiaozhi-server/config.yaml +++ b/main/xiaozhi-server/config.yaml @@ -270,6 +270,12 @@ ASR: is_ssl: true api_key: none output_dir: tmp/ + VolceAIGateway: + type: volcengine + host: ai-gateway.vei.volces.com + model_name: bigmodel + api_key: 你的api_key + output_dir: tmp/ SherpaASR: type: sherpa_onnx_local model_dir: models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17 @@ -377,6 +383,11 @@ LLM: base_url: https://ark.cn-beijing.volces.com/api/v3 model_name: doubao-1-5-pro-32k-250115 api_key: 你的doubao web key + VolceAIGateway: + type: volcengine + host: ai-gateway.vei.volces.com + model_name: doubao-pro-32k-functioncall + api_key: 你的api_key DeepSeekLLM: # 定义LLM API类型 type: openai @@ -437,7 +448,7 @@ LLM: # 开通后,进入这里获取密钥:https://console.volcengine.com/vei/aigateway/tokens-list base_url: https://ai-gateway.vei.volces.com/v1 model_name: doubao-pro-32k-functioncall - api_key: 你的网关访问密钥 + api_key: 你的api_key LMStudioLLM: # 定义LLM API类型 type: openai @@ -482,6 +493,11 @@ VLLM: model_name: glm-4v-flash # 智谱AI的视觉模型 url: https://open.bigmodel.cn/api/paas/v4/ api_key: 你的api_key + VolceAIGateway: + type: volcengine + host: ai-gateway.vei.volces.com + model_name: doubao-1.5-vision-lite + api_key: 你的api_key QwenVLVLLM: type: openai model_name: qwen2.5-vl-3b-instruct @@ -513,6 +529,14 @@ TTS: speed_ratio: 1.0 volume_ratio: 1.0 pitch_ratio: 1.0 + VolceAIGateway: + type: volcengine + host: ai-gateway.vei.volces.com + model_name: doubao-tts + output_dir: tmp/ + api_key: none + voice: zh_male_shaonianzixin_moon_bigtts + speed: 1 #火山tts,支持双向流式tts HuoshanDoubleStreamTTS: type: huoshan_double_stream diff --git a/main/xiaozhi-server/core/connection.py b/main/xiaozhi-server/core/connection.py index ce68e5b24d..dcf11458e5 100644 --- a/main/xiaozhi-server/core/connection.py +++ 
b/main/xiaozhi-server/core/connection.py @@ -12,8 +12,6 @@ import websockets from core.utils.util import ( extract_json_from_string, - check_vad_update, - check_asr_update, filter_sensitive_info, ) from typing import Dict, Any @@ -81,6 +79,7 @@ def __init__( self.max_output_size = 0 self.chat_history_conf = 0 self.audio_format = "opus" + self.just_woken_up = False # 客户端状态相关 self.client_abort = False @@ -456,22 +455,23 @@ def _initialize_private_config(self): self.logger.bind(tag=TAG).error(f"获取差异化配置失败: {e}") private_config = {} - init_llm, init_tts, init_memory, init_intent = ( + init_llm, init_tts, init_memory, init_intent, init_vad, init_asr = ( False, False, False, False, + False, + False, ) - init_vad = check_vad_update(self.common_config, private_config) - init_asr = check_asr_update(self.common_config, private_config) - - if init_vad: + if private_config.get("VAD", None) is not None: + init_vad = True self.config["VAD"] = private_config["VAD"] self.config["selected_module"]["VAD"] = private_config["selected_module"][ "VAD" ] - if init_asr: + if private_config.get("ASR", None) is not None: + init_asr = True self.config["ASR"] = private_config["ASR"] self.config["selected_module"]["ASR"] = private_config["selected_module"][ "ASR" diff --git a/main/xiaozhi-server/core/handle/intentHandler.py b/main/xiaozhi-server/core/handle/intentHandler.py index ccff094c07..f924b75017 100644 --- a/main/xiaozhi-server/core/handle/intentHandler.py +++ b/main/xiaozhi-server/core/handle/intentHandler.py @@ -50,7 +50,7 @@ async def check_direct_exit(conn, text): async def analyze_intent_with_llm(conn, text): """使用LLM分析用户意图""" if not hasattr(conn, "intent") or not conn.intent: - conn.logger.bind(tag=TAG).warning("意图识别服务未初始化") + conn.logger.bind(tag=TAG).error("意图识别服务未初始化") return None # 对话历史记录 diff --git a/main/xiaozhi-server/core/handle/receiveAudioHandle.py b/main/xiaozhi-server/core/handle/receiveAudioHandle.py index e6f3963273..97c87e36af 100644 --- a/main/xiaozhi-server/core/handle/receiveAudioHandle.py +++ b/main/xiaozhi-server/core/handle/receiveAudioHandle.py @@ -14,6 +14,8 @@ async def handleAudioMessage(conn, audio): # 当前片段是否有人说话 have_voice = conn.vad.is_vad(conn, audio) + if have_voice: + conn.logger.bind(tag=TAG).info(f"收到音频数据,len: {len(audio)}, wake_up: {conn.just_woken_up}") # 如果设备刚刚被唤醒,短暂忽略VAD检测 if have_voice and hasattr(conn, "just_woken_up") and conn.just_woken_up: have_voice = False @@ -25,6 +27,7 @@ async def handleAudioMessage(conn, audio): if have_voice: if conn.client_is_speaking: + conn.logger.bind(tag=TAG).info("对话过程被客户端打断") await handleAbortMessage(conn) # 设备长时间空闲检测,用于say goodbye await no_voice_close_connect(conn, have_voice) @@ -76,6 +79,7 @@ async def startToChat(conn, text): await max_out_size(conn) return if conn.client_is_speaking: + conn.logger.bind(tag=TAG).info("对话过程被客户端打断") await handleAbortMessage(conn) # 首先进行意图分析,使用实际文本内容 @@ -155,7 +159,7 @@ async def check_bind_device(conn): continue conn.tts.tts_audio_queue.put((SentenceType.LAST, [], None)) else: - text = f"没有找到该设备的版本信息,请正确配置 OTA地址,然后重新编译固件。" + text = "没有找到该设备的版本信息,请正确配置 OTA地址,然后重新编译固件。" await send_stt_message(conn, text) music_path = "config/assets/bind_not_found.wav" opus_packets, _ = audio_to_data(music_path) diff --git a/main/xiaozhi-server/core/handle/sendAudioHandle.py b/main/xiaozhi-server/core/handle/sendAudioHandle.py index 486e6d900f..59a2e614fd 100644 --- a/main/xiaozhi-server/core/handle/sendAudioHandle.py +++ b/main/xiaozhi-server/core/handle/sendAudioHandle.py @@ -34,7 +34,8 @@ async def 
sendAudioMessage(conn, sentenceType, audios, text): # 发送句子开始消息 - conn.logger.bind(tag=TAG).info(f"发送音频消息: {sentenceType}, {text}") + audio_len = len(audios) + conn.logger.bind(tag=TAG).info(f"发送音频消息: {sentenceType}, {audio_len}, {text}") if text is not None: emotion = analyze_emotion(text) emoji = emoji_map.get(emotion, "🙂") # 默认使用笑脸 diff --git a/main/xiaozhi-server/core/providers/asr/base.py b/main/xiaozhi-server/core/providers/asr/base.py index ef9fa01e7e..4018394c33 100644 --- a/main/xiaozhi-server/core/providers/asr/base.py +++ b/main/xiaozhi-server/core/providers/asr/base.py @@ -227,7 +227,7 @@ def _pcm_to_wav(self, pcm_data: bytes) -> bytes: logger.bind(tag=TAG).error(f"WAV转换失败: {e}") return b"" - def stop_ws_connection(self): + async def stop_ws_connection(self): pass def save_audio_to_file(self, pcm_data: List[bytes], session_id: str) -> str: @@ -250,6 +250,15 @@ async def speech_to_text( ) -> Tuple[Optional[str], Optional[str]]: """将语音数据转换为文本""" pass + + def is_eou(self, conn, text) -> bool: + """判断是否为结束语句""" + if text is None or len(text) == 0: + return False + is_eou = conn.vad.is_eou(conn, text) + if is_eou: + logger.bind(tag=TAG).info(f"检测到结束语句 {text}") + return is_eou @staticmethod def decode_opus(opus_data: List[bytes]) -> List[bytes]: diff --git a/main/xiaozhi-server/core/providers/asr/doubao_stream.py b/main/xiaozhi-server/core/providers/asr/doubao_stream.py index 31704b9694..518eaed8c9 100644 --- a/main/xiaozhi-server/core/providers/asr/doubao_stream.py +++ b/main/xiaozhi-server/core/providers/asr/doubao_stream.py @@ -209,9 +209,9 @@ async def _forward_asr_results(self, conn): self.asr_ws = None self.is_processing = False - def stop_ws_connection(self): + async def stop_ws_connection(self): if self.asr_ws: - asyncio.create_task(self.asr_ws.close()) + await self.asr_ws.close() self.asr_ws = None self.is_processing = False diff --git a/main/xiaozhi-server/core/providers/asr/volcengine.py b/main/xiaozhi-server/core/providers/asr/volcengine.py new file mode 100644 index 0000000000..209343c9bc --- /dev/null +++ b/main/xiaozhi-server/core/providers/asr/volcengine.py @@ -0,0 +1,386 @@ +""" +This module provides the ASR (Automatic Speech Recognition) provider for the Volcengine service. +It implements a streaming ASR client that connects to the Volcengine ASR service via WebSocket. +""" + +import asyncio +import base64 +import json +import os +import uuid +from typing import List, Optional, Tuple + +import opuslib_next +import websockets + +from config.logger import setup_logging +from core.handle.receiveAudioHandle import startToChat +from core.handle.reportHandle import enqueue_asr_report +from core.providers.asr.base import ASRProviderBase +from core.providers.asr.dto.dto import InterfaceType +from core.utils.util import remove_punctuation_and_length + +TAG = __name__ +logger = setup_logging() + + +class ASRProvider(ASRProviderBase): + """ + Implements a streaming ASR provider for Volcengine, inheriting from ASRProviderBase. + + This class manages a WebSocket connection to the Volcengine ASR service for real-time + speech-to-text transcription. It handles audio data processing, session management, + and result forwarding. + """ + + def __init__(self, config: dict, delete_audio_file: bool): + """ + Initializes the ASRProvider for Volcengine. + + Args: + config (dict): A dictionary containing configuration parameters such as + api_key, model_name, output_dir, and host. + delete_audio_file (bool): Flag to determine if audio files should be deleted + after processing. 
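+
+        Example (illustrative only; the keys mirror the ASR.VolceAIGateway entry
+        added to config.yaml in this change, and the api_key value is a placeholder):
+
+            config = {
+                "api_key": "你的api_key",
+                "model_name": "bigmodel",
+                "host": "ai-gateway.vei.volces.com",
+                "output_dir": "tmp/",
+            }
+            asr = ASRProvider(config, delete_audio_file=True)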
+ """ + super().__init__() + self.interface_type = InterfaceType.STREAM + self.api_key = config.get("api_key") + self.model_name = config.get("model_name") + self.output_dir = config.get("output_dir", "tmp/") + self.host = config.get("host", "ai-gateway.vei.volces.com") + self.delete_audio_file = delete_audio_file + self.ws_url = f"wss://{self.host}/v1/realtime?model={self.model_name}" + self.success_code = 1000 + self.seg_duration = 15000 + + # Ensure the output directory exists. + os.makedirs(self.output_dir, exist_ok=True) + + # State variables for streaming processing. + self.ws = None + self.conn = None + # 客户端发送任务结束修改状态 + self.session_started = False + # 客户端接收任务结束修改状态 + self.is_processing: bool = False + self.forward_task: Optional[asyncio.Task] = None + self.text: str = "" # Stores the currently recognized text. + self.decoder = opuslib_next.Decoder(16000, 1) # Opus decoder if input is opus. + self.current_session_id: Optional[str] = None + self.audio_buffer = bytearray() # Buffer for audio data. + + async def open_audio_channels(self, conn): + """ + Opens audio channels and initializes the session. + + Args: + conn: The connection object, which includes session details. + """ + await super().open_audio_channels(conn) + await self._ensure_connection() + self.conn = conn + + async def receive_audio(self, conn, audio: bytes, audio_have_voice: bool): + """ + Receives and processes incoming audio data. + + This method buffers audio, detects voice activity, and initiates the ASR session + when voice is detected. It sends audio chunks to the ASR service for processing. + + Args: + conn: The connection object. + audio (bytes): The raw audio data chunk. + audio_have_voice (bool): Flag indicating if the current audio chunk contains voice. + """ + # Buffer audio; discard old audio if there's no voice. + conn.asr_audio.append(audio) + conn.asr_audio = conn.asr_audio[-10:] + + # Start a new ASR session if voice is detected and not already processing. + if audio_have_voice and not self.is_processing: + try: + self.is_processing = True + await self.start_session() + pcm_frame = self.decode_opus(conn.asr_audio) + await self._send_audio_chunk(b"".join(pcm_frame)) + conn.asr_audio.clear() + except Exception as e: + logger.bind(tag=TAG).error( + f"Failed to establish ASR connection: {e}", exc_info=True + ) + await self.stop_ws_connection() + return + + # Send the current audio data if the session is active. + if self.ws and self.is_processing: + try: + logger.bind(tag=TAG).debug( + f"Sending audio data, size: {len(audio)} for session: {self.current_session_id}" + ) + pcm_frame = self.decode_opus(conn.asr_audio) + await self._send_audio_chunk(b"".join(pcm_frame)) + conn.asr_audio.clear() + except Exception as e: + logger.bind(tag=TAG).error( + f"Error sending audio data: {e}", exc_info=True + ) + await self.stop_ws_connection() + # Finish the session if end-of-utterance is detected. + if self.ws and self.session_started and self.is_eou(conn, self.text): + logger.bind(tag=TAG).info(f"Finishing session: {self.current_session_id}") + await self.finish_session() + + async def _send_audio_chunk(self, pcm_data: bytes): + """ + Sends a chunk of PCM audio data to the ASR service. + + The audio data is Base64-encoded and sent as a JSON message over the WebSocket. + + Args: + pcm_data (bytes): The PCM audio data to send. + """ + if not self.ws or not self.is_processing: + return + + # The Volcengine streaming ASR service expects Base64-encoded PCM data. 
+        base64_audio = base64.b64encode(pcm_data).decode("utf-8")
+        audio_event = {"audio": base64_audio, "type": "input_audio_buffer.append"}
+        await self.ws.send(json.dumps(audio_event))
+        logger.bind(tag=TAG).debug(f"Sent audio chunk, size: {len(pcm_data)}")
+
+    async def _forward_asr_results(self):
+        """
+        Listens for and processes incoming messages from the ASR service.
+
+        This method runs in a background task, continuously receiving ASR results,
+        updating the recognized text, and handling final transcripts.
+        """
+        try:
+            logger.bind(tag=TAG).debug(
+                f"ASR forwarder started for session: {self.current_session_id}"
+            )
+            while self.ws and not self.conn.stop_event.is_set() and self.is_processing:
+                try:
+                    message = await self.ws.recv()
+                    event = json.loads(message)
+                    logger.bind(tag=TAG).debug(
+                        f"Received ASR result for session {self.current_session_id}: {event}"
+                    )
+
+                    # Parse the response from the Volcengine streaming ASR service.
+                    message_type = event.get("type")
+                    if (
+                        message_type
+                        == "conversation.item.input_audio_transcription.result"
+                    ):
+                        transcript_segment = event.get("transcript", "")
+                        is_final = event.get("is_final", False)
+                        self.text = transcript_segment  # Keep the latest intermediate result.
+                        if is_final:
+                            logger.bind(tag=TAG).info(f"Final ASR result: {self.text}")
+                            self.conn.reset_vad_states()
+                            await self.handle_voice_stop(self.conn, None)
+                    elif (
+                        message_type
+                        == "conversation.item.input_audio_transcription.completed"
+                    ):
+                        final_transcript = event.get("transcript", self.text)
+                        logger.bind(tag=TAG).info(
+                            f"ASR transcription completed: {final_transcript}"
+                        )
+                        self.text = final_transcript  # Ensure final result is used.
+                        self.conn.reset_vad_states()
+                        await self.handle_voice_stop(self.conn, None)
+                        self.text = ""  # Reset for next utterance.
+                        break  # End the receiving task.
+                    elif message_type == "error":
+                        error_msg = event.get("error", {})
+                        logger.bind(tag=TAG).error(f"ASR service error: {error_msg}")
+                        break
+
+                except websockets.ConnectionClosed:
+                    await self.stop_ws_connection()
+                    logger.bind(tag=TAG).error("ASR WebSocket connection closed.")
+                    break
+                except json.JSONDecodeError:
+                    logger.bind(tag=TAG).error(
+                        f"Failed to decode JSON from ASR: {message}"
+                    )
+                except Exception as e:
+                    logger.bind(tag=TAG).error(
+                        f"Error processing ASR result: {e}", exc_info=True
+                    )
+                    break
+        finally:
+            logger.bind(tag=TAG).debug(
+                f"ASR forwarder task finished for session: {self.current_session_id}"
+            )
+            self.is_processing = False
+
+    async def speech_to_text(
+        self, opus_data: List[bytes], session_id: str, audio_format="opus"
+    ) -> Tuple[Optional[str], Optional[str]]:
+        """
+        In streaming mode, this method returns the currently recognized text.
+        The final result is handled by the `_forward_asr_results` method.
+
+        Args:
+            opus_data (List[bytes]): List of Opus audio data chunks.
+            session_id (str): The ID of the current session.
+            audio_format (str): The format of the audio data.
+
+        Returns:
+            A tuple containing the recognized text and None.
+        """
+        return self.text, None
+
+    async def start_session(self):
+        """
+        Starts a new ASR transcription session.
+
+        This involves ensuring a WebSocket connection is active, sending a session
+        start request, and creating a task to listen for ASR results. The session
+        ID is generated internally (a new UUID per session), so no argument is
+        required.
+
+        Raises:
+            Exception: If the session fails to start.
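+
+        Example of the intended call order (illustrative; these are the methods
+        defined on this provider):
+
+            await provider.start_session()    # sends transcription_session.update
+            # ...stream PCM via _send_audio_chunk() while the user speaks...
+            await provider.finish_session()   # commits the audio buffer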
+ """ + self.current_session_id = uuid.uuid4().hex + logger.bind(tag=TAG).info(f"Starting session {self.current_session_id}") + try: + await self._ensure_connection() + # Create the request message to start streaming recognition. + config = { + "input_audio_format": "pcm", + "input_audio_codec": "raw", + "input_audio_sample_rate": 16000, + "input_audio_bits": 16, + "input_audio_channel": 1, + "input_audio_transcription": {"model": self.model_name}, + "session_id": self.current_session_id, + } + event = {"type": "transcription_session.update", "session": config} + await self.ws.send(json.dumps(event)) + self.session_started = True + logger.bind(tag=TAG).debug(f"Session start request sent: {event}") + + # Start the task to listen for results. + if self.forward_task is None: + self.forward_task = asyncio.create_task(self._forward_asr_results()) + except Exception as e: + logger.bind(tag=TAG).error(f"Failed to start session: {e}") + if self.forward_task: + self.forward_task.cancel() + try: + await self.forward_task + except asyncio.CancelledError: + pass + self.forward_task = None + await self.stop_ws_connection() + raise + + async def finish_session(self): + """ + Finishes the current ASR session. + + Sends a commit message to the service to finalize the transcription and waits + for the result forwarding task to complete. + + Args: + session_id (str): The ID of the session to finish. + """ + logger.bind(tag=TAG).info(f"Stopping session {self.current_session_id}") + try: + self.audio_buffer.clear() + done_payload = {"type": "input_audio_buffer.commit"} + await self.ws.send(json.dumps(done_payload)) + self.session_started = False + logger.bind(tag=TAG).debug( + f"Session finish: {done_payload} for session: {self.current_session_id}" + ) + except Exception as e: + await self.stop_ws_connection() + logger.bind(tag=TAG).error(f"Failed to close session: {e}") + + # Wait for the forwarding task to complete. + if self.forward_task: + try: + await self.forward_task + logger.bind(tag=TAG).debug("Forwarding task has completed.") + except Exception as e: + logger.bind(tag=TAG).error(f"Error waiting for forwarding task: {e}") + finally: + self.forward_task = None + self.current_session_id = None + + async def handle_voice_stop(self, conn, asr_audio_task): + """ + Handles the event when voice activity stops. + + Retrieves the final recognized text and initiates the chat process. + + Args: + conn: The connection object. + asr_audio_task: The audio task associated with the ASR. + """ + raw_text, _ = await self.speech_to_text( + asr_audio_task, conn.session_id, conn.audio_format + ) + conn.logger.bind(tag=TAG).info(f"Recognized text: {raw_text}") + text_len, _ = remove_punctuation_and_length(raw_text) + if text_len > 0: + await startToChat(conn, raw_text) + enqueue_asr_report(conn, raw_text, asr_audio_task) + + async def _ensure_connection(self): + """ + Ensures that the WebSocket connection to the ASR service is active. + + If the connection is down, it attempts to reconnect. + + Raises: + Exception: If the connection cannot be established. 
+ """ + # 检查连接是否存在且处于 open 状态 + # websockets 库的自动 ping/pong 机制会处理连接健康检查 + if self.ws: + logger.bind(tag=TAG).debug("WebSocket connection is active.") + return + + # 如果连接不存在或已关闭,则重新连接 + try: + logger.bind(tag=TAG).info(f"Connecting to {self.ws_url}") + headers = {"Authorization": f"Bearer {self.api_key}"} + # 使用内置的 ping/pong 机制来维持连接和检查健康状况 + # 每 60 秒发送一次 ping,等待 30 秒超时 + self.ws = await websockets.connect( + self.ws_url, + additional_headers=headers, + ping_interval=60, # Increased from 20 + ping_timeout=30, # Increased from 10 + close_timeout=10, # Added for graceful close + ) + logger.bind(tag=TAG).info("WebSocket connection established.") + except Exception as e: + logger.bind(tag=TAG).error(f"Failed to connect to WebSocket: {e}") + self.ws = None + raise + + async def stop_ws_connection(self): + """ + Stops the WebSocket connection gracefully. + """ + logger.bind(tag=TAG).info("Stopping ASR WebSocket connection...") + if self.ws: + try: + await self.ws.close() + logger.bind(tag=TAG).info( + "ASR WebSocket connection closed successfully." + ) + except websockets.WebSocketException as e: + logger.bind(tag=TAG).error(f"Error closing ASR WebSocket: {e}") + finally: + self.ws = None diff --git a/main/xiaozhi-server/core/providers/llm/volcengine/volcengine.py b/main/xiaozhi-server/core/providers/llm/volcengine/volcengine.py new file mode 100644 index 0000000000..1df24df17c --- /dev/null +++ b/main/xiaozhi-server/core/providers/llm/volcengine/volcengine.py @@ -0,0 +1,112 @@ +""" +此模块实现了基于火山引擎 OpenAI 接口的大语言模型服务。 +定义了 LLMProvider 类,继承自 LLMProviderBase, +提供了初始化配置和生成响应的功能,支持普通对话响应和带函数调用的对话响应。 +""" +import openai +from openai.types import CompletionUsage +from config.logger import setup_logging +from core.utils.util import check_model_key +from core.providers.llm.base import LLMProviderBase + +TAG = __name__ +logger = setup_logging() + + +class LLMProvider(LLMProviderBase): + """ + LLMProvider 类用于实现基于火山引擎 OpenAI 接口的大语言模型服务。 + 该类继承自 LLMProviderBase,提供了初始化配置和生成响应的功能, + 支持普通对话响应和带函数调用的对话响应。 + """ + def __init__(self, config): + self.api_key = config.get("api_key") + self.model_name = config.get("model_name") + self.host = config.get("host") + if self.host is None: + self.host = "ai-gateway.vei.volces.com" + + self.base_url = f"https://{self.host}/v1" + + + param_defaults = { + "max_tokens": (500, int), + "temperature": (0.7, lambda x: round(float(x), 1)), + "top_p": (1.0, lambda x: round(float(x), 1)), + "frequency_penalty": (0, lambda x: round(float(x), 1)) + } + + for param, (default, converter) in param_defaults.items(): + value = config.get(param) + try: + setattr(self, param, converter(value) if value not in (None, "") else default) + except (ValueError, TypeError): + setattr(self, param, default) + + logger.debug( + f"意图识别参数初始化: {self.temperature}, {self.max_tokens}, {self.top_p}, {self.frequency_penalty}") + + check_model_key("LLM", self.api_key) + logger.bind(tag=TAG).info(f"LLM client paramters: {self.api_key} {self.base_url}") + self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) + + def response(self, session_id, dialogue, **kwargs): + try: + responses = self.client.chat.completions.create( + model=self.model_name, + messages=dialogue, + stream=True, + max_tokens=kwargs.get("max_tokens", self.max_tokens), + temperature=kwargs.get("temperature", self.temperature), + top_p=kwargs.get("top_p", self.top_p), + frequency_penalty=kwargs.get("frequency_penalty", self.frequency_penalty), + ) + + is_active = True + for chunk in responses: + try: + # 检查是否存在有效的choice且content不为空 + 
+                    delta = (
+                        chunk.choices[0].delta
+                        if getattr(chunk, "choices", None)
+                        else None
+                    )
+                    content = delta.content if hasattr(delta, "content") else ""
+                except IndexError:
+                    content = ""
+                if content:
+                    # 处理<think>标签跨多个chunk的情况
+                    if "<think>" in content:
+                        is_active = False
+                        content = content.split("<think>")[0]
+                    if "</think>" in content:
+                        is_active = True
+                        content = content.split("</think>")[-1]
+                    if is_active:
+                        yield content
+
+        except Exception as e:
+            logger.bind(tag=TAG).error(f"Error in response generation: {e}")
+
+    def response_with_functions(self, session_id, dialogue, functions=None):
+        try:
+            stream = self.client.chat.completions.create(
+                model=self.model_name, messages=dialogue, stream=True, tools=functions
+            )
+
+            for chunk in stream:
+                # 检查是否存在有效的choice且content不为空
+                if getattr(chunk, "choices", None):
+                    yield chunk.choices[0].delta.content, chunk.choices[0].delta.tool_calls
+                # 存在 CompletionUsage 消息时,生成 Token 消耗 log
+                elif isinstance(getattr(chunk, 'usage', None), CompletionUsage):
+                    usage_info = getattr(chunk, 'usage', None)
+                    logger.bind(tag=TAG).info(
+                        f"Token 消耗:输入 {getattr(usage_info, 'prompt_tokens', '未知')},"
+                        f"输出 {getattr(usage_info, 'completion_tokens', '未知')},"
+                        f"共计 {getattr(usage_info, 'total_tokens', '未知')}"
+                    )
+
+        except Exception as e:
+            logger.bind(tag=TAG).error(f"Error in function call streaming: {e}")
+            yield f"【LLM服务响应异常: {e}】", None
diff --git a/main/xiaozhi-server/core/providers/tts/volcengine.py b/main/xiaozhi-server/core/providers/tts/volcengine.py
new file mode 100644
index 0000000000..4447163682
--- /dev/null
+++ b/main/xiaozhi-server/core/providers/tts/volcengine.py
@@ -0,0 +1,404 @@
+"""
+This module implements Text-to-Speech (TTS) functionality based on the Volcengine service.
+
+It supports both bidirectional streaming TTS via WebSocket and one-time TTS requests
+via HTTP interface.
+"""
+import asyncio
+import base64
+import io
+import json
+import queue
+import uuid
+
+import openai
+import pydub
+import websockets
+
+from config.logger import setup_logging
+from core.handle.abortHandle import handleAbortMessage
+from core.providers.tts.base import TTSProviderBase
+from core.providers.tts.dto.dto import ContentType, InterfaceType, SentenceType
+from core.utils import opus_encoder_utils
+from core.utils.tts import MarkdownCleaner
+from core.utils.util import check_model_key
+
+TAG = __name__
+logger = setup_logging()
+
+
+class TTSProvider(TTSProviderBase):
+    """
+    Implements the TTS provider for Volcengine, inheriting from TTSProviderBase.
+
+    This class supports both dual-stream TTS via WebSocket and single-request TTS
+    via HTTP, providing real-time and non-real-time speech synthesis capabilities.
+    """
+
+    def __init__(self, config, delete_audio_file):
+        """
+        Initializes the TTSProvider for Volcengine.
+
+        Args:
+            config (dict): A dictionary containing the configuration for the TTS provider.
+            delete_audio_file (bool): Whether to delete the generated audio file after playback.
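+
+        Example (illustrative only; the keys mirror the TTS.VolceAIGateway entry
+        added to config.yaml in this change):
+
+            config = {
+                "api_key": "你的api_key",
+                "model_name": "doubao-tts",
+                "host": "ai-gateway.vei.volces.com",
+                "voice": "zh_male_shaonianzixin_moon_bigtts",
+                "speed": 1,
+                "output_dir": "tmp/",
+            }
+            tts = TTSProvider(config, delete_audio_file=True)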
+ """ + super().__init__(config, delete_audio_file) + self.interface_type = InterfaceType.DUAL_STREAM + self.api_key = config.get("api_key") + self.model_name = config.get("model_name") + self.host = config.get("host") + if self.host is None: + self.host = "ai-gateway.vei.volces.com" + self.delete_audio_file = delete_audio_file + self.ws_url = f"wss://{self.host}/v1/realtime?model={self.model_name}" + self.base_url = f"https://{self.host}/v1" + if config.get("private_voice"): + self.voice = config.get("private_voice") + else: + self.voice = config.get("voice", "alloy") + self.audio_file_type = config.get("format", "wav") # 流式接口通常使用 pcm + self.sample_rate = config.get("sample_rate", 16000) + self.opus_encoder = opus_encoder_utils.OpusEncoderUtils( + sample_rate=16000, channels=1, frame_size_ms=60 + ) + # 处理空字符串的情况 + speed = config.get("speed", "1.0") + self.speed = float(speed) if speed else 1.0 + self.ws = None + self._monitor_task = None + check_model_key("TTS", self.api_key) + self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) + + + async def open_audio_channels(self, conn): + """ + Opens the audio channels to prepare for TTS. + + This method ensures that a WebSocket connection to the Volcengine service is established. + + Args: + conn: The connection object for managing session state. + """ + try: + await super().open_audio_channels(conn) + await self._ensure_connection() + except Exception as e: + logger.bind(tag=TAG).error(f"Failed to open audio channels: {str(e)}") + self.ws = None + raise + + + async def _receive_audio(self): + """Listens for and processes WebSocket responses from the Volcengine TTS service.""" + opus_datas_cache = [] + is_first_sentence = True + first_sentence_segment_count = 0 # 添加计数器 + try: + while not self.conn.stop_event.is_set(): + message = await self.ws.recv() + data = json.loads(message) + event_type = data.get("type") + logger.bind(tag=TAG).debug(f"Received data: {data}") + # 检查客户端是否中止 + if self.conn.client_abort: + logger.bind(tag=TAG).info("收到打断信息,终止监听TTS响应") + break + if event_type == "tts_session.updated": + logger.bind(tag=TAG).info(f"Session: {self.conn.sentence_id}, 完成会话初始化") + opus_datas_cache = [] + is_first_sentence = True + first_sentence_segment_count = 0 + elif event_type == "response.audio.delta": + pcm_data = base64.b64decode(data.get("delta")) + opus_data_list = self.opus_encoder.encode_pcm_to_opus(pcm_data, False) + if len(opus_data_list) == 0: + continue + if is_first_sentence: + logger.bind(tag=TAG).debug(f"Session: {self.conn.sentence_id}, Received first audio data") + is_first_sentence = False + self.tts_audio_queue.put( + (SentenceType.FIRST, opus_data_list, None) + ) + else: + logger.bind(tag=TAG).debug("Received delta audio data") + self.tts_audio_queue.put( + (SentenceType.MIDDLE, opus_data_list, None) + ) + first_sentence_segment_count += 1 + elif event_type == "response.audio_subtitle.delta": + logger.bind(tag=TAG).debug(f"Session: {self.conn.sentence_id},Received subtitles delata data: {data}") + subtitles = data.get("subtitles").get("text") + if subtitles: + self.tts_audio_queue.put( + (SentenceType.MIDDLE, [], subtitles) + ) + elif event_type == "response.audio.done": + logger.bind(tag=TAG).debug(f"Session: {self.conn.sentence_id},完成语音生成.") + if self.tts_audio_queue: + self.tts_audio_queue.put( + (SentenceType.LAST, opus_datas_cache, None) + ) + self._process_before_stop_play_files() + break # End of stream for this request + elif event_type == "error": + logger.bind(tag=TAG).error(f"Received error from 
server: {data}") + break + except websockets.exceptions.ConnectionClosed as e: + await self.stop_ws_connection() + logger.bind(tag=TAG).warning(f"WebSocket connection closed: {e}") + except Exception as e: + logger.bind(tag=TAG).error(f"Error receiving audio: {e}") + finally: + logger.bind(tag=TAG).debug(f"Session: {self.conn.sentence_id}, 退出音频接收任务") + + def tts_text_priority_thread(self): + """ + Runs in a separate thread to process the text queue for synthesis. + + It starts, sends data to, and ends the TTS session based on the message type + (FIRST, MIDDLE, LAST). + """ + while not self.conn.stop_event.is_set(): + try: + message = self.tts_text_queue.get(timeout=1) + logger.bind(tag=TAG).debug( + f"收到TTS任务|{message.sentence_type.name} | {message.content_type.name} | 会话ID: {self.conn.sentence_id}" + ) + if message.sentence_type == SentenceType.FIRST: + self.conn.client_abort = False + + if self.conn.client_abort: + logger.bind(tag=TAG).info("收到打断信息,终止TTS文本处理线程") + continue + if message.sentence_type == SentenceType.FIRST: + if not getattr(self.conn, "sentence_id", None): + self.conn.sentence_id = uuid.uuid4().hex + logger.bind(tag=TAG).info(f"自动生成新的会话ID: {self.conn.sentence_id}") + logger.bind(tag=TAG).debug("开始启动TTS会话...") + future = asyncio.run_coroutine_threadsafe( + self.start_session(self.conn.sentence_id), + loop=self.conn.loop, + ) + future.result() + self.tts_stop_request = False + self.processed_chars = 0 + self.tts_text_buff = [] + self.is_first_sentence = True + self.tts_audio_first_sentence = True + self.before_stop_play_files.clear() + logger.bind(tag=TAG).debug("TTS会话启动成功") + elif ContentType.TEXT == message.content_type: + self.tts_text_buff.append(message.content_detail) + segment_text = self._get_segment_text() + if segment_text: + logger.bind(tag=TAG).info( + f"session: {self.conn.sentence_id} 发送TTS文本: {segment_text}" + ) + future = asyncio.run_coroutine_threadsafe( + self._send_text(segment_text), + loop=self.conn.loop, + ) + future.result() + elif ContentType.FILE == message.content_type: + logger.bind(tag=TAG).info( + f"添加音频文件到待播放列表: {message.content_file}" + ) + self.before_stop_play_files.append( + (message.content_file, message.content_detail) + ) + + if message.sentence_type == SentenceType.LAST: + logger.bind(tag=TAG).info(f"session: {self.conn.sentence_id} 结束TTS会话") + future = asyncio.run_coroutine_threadsafe( + self.finish_session(self.conn.sentence_id), + loop=self.conn.loop, + ) + future.result() + except queue.Empty: + continue + except Exception as e: + logger.bind(tag=TAG).error( + f"处理TTS文本失败: {str(e)}, 类型: {type(e).__name__}, 堆栈: {e.__traceback__}" + + ) + continue + + async def start_session(self, session_id): + """ + Starts a TTS session. + + Args: + session_id (str): A unique session ID. 
+ """ + logger.bind(tag=TAG).info(f"开始会话 {session_id}") + try: + # 建立新连接 + await self._ensure_connection() + # 发送会话启动请求 + session_update_payload = { + "event_id": str(uuid.uuid4()), + "type": "tts_session.update", + "session": { + "voice": self.voice, + "output_audio_format": self.audio_file_type, + "output_audio_sample_rate": self.sample_rate, + "text_to_speech": { + "model": self.model_name + } + } + } + await self.ws.send(json.dumps(session_update_payload)) + logger.bind(tag=TAG).debug(f"会话启动请求已发送, Send Event: {session_update_payload}") + # 启动监听任务 + if self._monitor_task is None: + self._monitor_task = asyncio.create_task(self._receive_audio()) + except Exception as e: + logger.bind(tag=TAG).error(f"启动会话失败: {str(e)}") + # 确保清理资源 + if hasattr(self, "_monitor_task"): + try: + self._monitor_task.cancel() + await self._monitor_task + except Exception: + pass + self._monitor_task = None + await self.stop_ws_connection() + raise + + async def finish_session(self, session_id): + """ + Finishes a TTS session. + + Args: + session_id (str): The unique session ID. + """ + try: + done_payload = { + "type": "input_text.done" + } + await self.ws.send(json.dumps(done_payload)) + logger.bind(tag=TAG).debug(f"会话结束请求已发送,Send Event: {done_payload}") + + except Exception as e: + logger.bind(tag=TAG).error(f"关闭会话失败: {str(e)}") + await self.stop_ws_connection() + + # 等待监听任务完成 + if hasattr(self, "_monitor_task"): + try: + await self._monitor_task + logger.bind(tag=TAG).debug("退出monitor task") + except Exception as e: + logger.bind(tag=TAG).error( + f"等待监听任务完成时发生错误: {str(e)}" + ) + finally: + self._monitor_task = None + + async def _ensure_connection(self): + """ + Ensures that the WebSocket connection to the ASR service is active. + + If the connection is down, it attempts to reconnect. + + Raises: + Exception: If the connection cannot be established. 
+ """ + # 检查连接是否存在且处于 open 状态 + # websockets 库的自动 ping/pong 机制会处理连接健康检查 + if self.ws: + logger.bind(tag=TAG).info("WebSocket connection is active.") + return + + # 如果连接不存在或已关闭,则重新连接 + try: + logger.bind(tag=TAG).info(f"Connecting to {self.ws_url}") + headers = {"Authorization": f"Bearer {self.api_key}"} + # 使用内置的 ping/pong 机制来维持连接和检查健康状况 + # 每 60 秒发送一次 ping,等待 30 秒超时 + self.ws = await websockets.connect( + self.ws_url, + additional_headers=headers, + ping_interval=60, + ping_timeout=30, + close_timeout=10 + ) + logger.bind(tag=TAG).info("WebSocket connection established.") + except Exception as e: + logger.bind(tag=TAG).error(f"Failed to connect to WebSocket: {e}") + self.ws = None + raise + + + + + async def stop_ws_connection(self): + """Safely closes the WebSocket connection.""" + # 关闭WebSocket连接 + if self.ws: + try: + await self.ws.close() + logger.bind(tag=TAG).info("WebSocket connection closed.") + except Exception as e: + logger.bind(tag=TAG).error(f"Error closing WebSocket: {e}") + finally: + self.ws = None + + async def _send_text(self, text): + """Sends a chunk of text to the TTS service for synthesis.""" + try: + # 建立新连接 + if self.ws is None: + logger.bind(tag=TAG).error("WebSocket连接不存在,终止发送文本") + await handleAbortMessage(self.conn) + return + + # 过滤Markdown + filtered_text = MarkdownCleaner.clean_markdown(text) + + # 发送文本 + if len(filtered_text) > 0: + text_append_payload = { + "event_id": str(uuid.uuid4()), + "type": "input_text.append", + "delta": filtered_text + } + await self.ws.send(json.dumps(text_append_payload)) + logger.bind(tag=TAG).debug(f"发送文本, Send Event: {text_append_payload}") + return + except Exception as e: + logger.bind(tag=TAG).error(f"发送TTS文本失败: {str(e)}") + await self.stop_ws_connection() + raise + + async def text_to_speak(self, text, output_file): + """ + Converts text to speech via an HTTP POST request. + + Args: + text (str): The text to be converted. + output_file (str): The path to save the audio file. + + Returns: + bytes: The byte stream of the generated WAV audio data. 
+ """ + logger.bind(tag=TAG).info(f"采用http方式发送文本: {text}") + response = self.client.audio.speech.create( + model = self.model_name, + voice = self.voice, + input = text + ) + # 其他格式用pydub + audio = pydub.AudioSegment.from_file( + io.BytesIO(response.content), format="mp3", parameters=["-nostdin"] + ) + wav_buffer = io.BytesIO() + audio.export(wav_buffer, format="wav") + wav_bytes = wav_buffer.getvalue() + output_file = "/tmp/a.wav" + if output_file: + with open(output_file, "wb") as audio_file: + audio_file.write(wav_bytes) + return wav_bytes + \ No newline at end of file diff --git a/main/xiaozhi-server/core/providers/vad/base.py b/main/xiaozhi-server/core/providers/vad/base.py index 1d8d4c8dde..27e7d7245f 100644 --- a/main/xiaozhi-server/core/providers/vad/base.py +++ b/main/xiaozhi-server/core/providers/vad/base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional class VADProviderBase(ABC): @@ -7,3 +6,13 @@ class VADProviderBase(ABC): def is_vad(self, conn, data) -> bool: """检测音频数据中的语音活动""" pass + + @abstractmethod + def is_eou(self, conn, text) -> bool: + """End of Utterance(话语结束检测),是基于语义理解的自动判断用户发言是否结束的技术, True 表示结束,False 表示未结束""" + pass + + @abstractmethod + def get_silence_duration(self, conn) -> int: + """返回语音静音时长,单位ms""" + pass diff --git a/main/xiaozhi-server/core/providers/vad/silero.py b/main/xiaozhi-server/core/providers/vad/silero.py index b516d8fbd2..41c86785ee 100644 --- a/main/xiaozhi-server/core/providers/vad/silero.py +++ b/main/xiaozhi-server/core/providers/vad/silero.py @@ -18,7 +18,7 @@ def __init__(self, config): model="silero_vad", force_reload=False, ) - + self.stop_duration = 0 self.decoder = opuslib_next.Decoder(16000, 1) # 处理空字符串的情况 @@ -36,9 +36,9 @@ def __init__(self, config): # 至少要多少帧才算有语音 self.frame_window_threshold = 3 - def is_vad(self, conn, opus_packet): + def is_vad(self, conn, data): try: - pcm_frame = self.decoder.decode(opus_packet, 960) + pcm_frame = self.decoder.decode(data, 960) conn.client_audio_buffer.extend(pcm_frame) # 将新数据加入缓冲区 # 处理缓冲区中的完整帧(每次处理512采样点) @@ -74,10 +74,11 @@ def is_vad(self, conn, opus_packet): # 如果之前有声音,但本次没有声音,且与上次有声音的时间差已经超过了静默阈值,则认为已经说完一句话 if conn.client_have_voice and not client_have_voice: - stop_duration = time.time() * 1000 - conn.last_activity_time - if stop_duration >= self.silence_threshold_ms: + self.stop_duration = time.time() * 1000 - conn.last_activity_time + if self.stop_duration >= self.silence_threshold_ms: conn.client_voice_stop = True if client_have_voice: + self.stop_duration = 0 conn.client_have_voice = True conn.last_activity_time = time.time() * 1000 @@ -86,3 +87,13 @@ def is_vad(self, conn, opus_packet): logger.bind(tag=TAG).info(f"解码错误: {e}") except Exception as e: logger.bind(tag=TAG).error(f"Error processing audio packet: {e}") + + def is_eou(self, conn, text) : + """End of Utterance(话语结束检测),是基于语义理解的自动判断用户发言是否结束的技术""" + return conn.client_voice_stop + + def get_silence_duration(self, conn) : + """返回语音静音时长,单位ms""" + if conn.client_voice_stop: + return self.stop_duration + return 0 diff --git a/main/xiaozhi-server/core/providers/vad/volcengine.py b/main/xiaozhi-server/core/providers/vad/volcengine.py new file mode 100644 index 0000000000..181cf182d8 --- /dev/null +++ b/main/xiaozhi-server/core/providers/vad/volcengine.py @@ -0,0 +1,158 @@ +""" +This module provides a Voice Activity Detection (VAD) provider using the Volcengine service. 
diff --git a/main/xiaozhi-server/core/providers/vad/volcengine.py b/main/xiaozhi-server/core/providers/vad/volcengine.py
new file mode 100644
index 0000000000..181cf182d8
--- /dev/null
+++ b/main/xiaozhi-server/core/providers/vad/volcengine.py
@@ -0,0 +1,158 @@
+"""
+This module provides a Voice Activity Detection (VAD) provider using the Volcengine service.
+
+It combines basic VAD (such as Silero) with semantic End-of-Utterance (EOU) detection,
+using embeddings to determine whether the user has finished speaking.
+"""
+import openai
+
+from config.logger import setup_logging
+from core.providers.vad.base import VADProviderBase
+from core.utils.vad import create_instance
+
+TAG = __name__
+logger = setup_logging()
+
+
+class VADProvider(VADProviderBase):
+    """
+    Implements a VAD provider based on Volcengine for semantic End-of-Utterance (EOU) detection.
+
+    This class uses a base VAD model (e.g., Silero) for initial voice activity detection
+    and leverages a Volcengine embedding model to decide whether the user's speech
+    constitutes a complete thought, allowing for a more natural conversation flow.
+    """
+
+    def __init__(self, config: dict):
+        """
+        Initializes the Volcengine VAD provider.
+
+        Args:
+            config (dict): Configuration dictionary containing settings for the base VAD,
+                           semantic detection, and Volcengine API credentials.
+        """
+        logger.bind(tag=TAG).info(f"init VAD_volcengine: config:{config}")
+        config["model_dir"] = "models/snakers4_silero-vad"
+        self.base_vad_model = create_instance("silero", config)
+        min_silence_duration_ms = config.get("min_silence_duration_ms", "1000")
+        max_silence_duration_ms = config.get("max_silence_duration_ms", "3000")
+
+        self.semantic_only = config.get("semantic_only", False)
+        self.min_silence_threshold_ms = (
+            int(min_silence_duration_ms) if min_silence_duration_ms else 1000
+        )
+        self.max_silence_threshold_ms = (
+            int(max_silence_duration_ms) if max_silence_duration_ms else 3000
+        )
+        self.api_key = config.get("api_key")
+        self.model_name = config.get("model_name")
+        self.host = config.get("host", "ai-gateway.vei.volces.com")
+
+        self.base_url = f"https://{self.host}/v1"
+        self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
+        self.cached_text = ""
+        self.cached_embedding = None
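For reference, a configuration sketch matching the constructor above. All values are placeholders; the embedding model name in particular is hypothetical and depends on what the gateway exposes.

```python
vad_config = {
    "api_key": "your-gateway-key",             # placeholder
    "model_name": "your-eou-embedding-model",  # hypothetical model name
    "host": "ai-gateway.vei.volces.com",
    "min_silence_duration_ms": "1000",
    "max_silence_duration_ms": "3000",
    "semantic_only": False,
}
```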
+    def is_vad(self, conn, data: bytes):
+        """
+        Performs basic Voice Activity Detection.
+
+        If semantic_only is True, this check is bypassed. Otherwise, it delegates
+        to the base VAD model.
+
+        Args:
+            conn: The connection object.
+            data: The audio data chunk.
+
+        Returns:
+            bool: True if voice activity is detected, False otherwise.
+        """
+        if self.semantic_only:
+            return True
+        return self.base_vad_model.is_vad(conn, data)
+
+    def is_eou(self, conn, text: str):
+        """
+        Determines whether the end of an utterance (EOU) has been reached.
+
+        This method combines semantic analysis from an embedding model with the
+        silence duration: the longer the user has been silent, the lower the
+        semantic confidence required to declare the utterance finished.
+
+        Args:
+            conn: The connection object.
+            text (str): The transcribed text so far.
+
+        Returns:
+            bool: True if the utterance is considered complete, False otherwise.
+        """
+        silence_duration = self.get_silence_duration(conn)
+        logger.bind(tag=TAG).debug(f"silence_duration: {silence_duration}")
+        # If the text is empty, EOU is determined solely by the silence duration.
+        if not text or not text.strip():
+            return silence_duration >= self.max_silence_threshold_ms
+
+        # For semantic checks, we need the embedding.
+        embedding, is_cached = self._get_embedding(text)
+        is_stop = embedding[1] > 0.5
+        if not is_cached or silence_duration >= self.max_silence_threshold_ms:
+            logger.bind(tag=TAG).info(
+                f"EOU Result: text:{text} embedding:{embedding} semantic_stop:{is_stop} "
+                f"silence_duration:{silence_duration} cache:{is_cached}"
+            )
+        if self.semantic_only:
+            return is_stop
+
+        if silence_duration <= self.min_silence_threshold_ms / 2:
+            # Very short silence: never interrupt the user.
+            return False
+        elif silence_duration <= self.min_silence_threshold_ms:
+            # Short silence: require high confidence to stop.
+            return embedding[1] > 0.9
+        elif silence_duration <= self.max_silence_threshold_ms:
+            # Medium silence: require medium confidence to stop.
+            return embedding[1] > 0.8
+        else:
+            # Force a stop once the user has been silent long enough.
+            return True
+
+    def get_silence_duration(self, conn):
+        """
+        Gets the current silence duration from the base VAD model.
+
+        Args:
+            conn: The connection object.
+
+        Returns:
+            int: The duration of silence in milliseconds.
+        """
+        return self.base_vad_model.get_silence_duration(conn)
+
+    def _get_embedding(self, text: str):
+        """
+        Retrieves the text embedding from the Volcengine model.
+
+        Args:
+            text (str): The input text to get the embedding for.
+
+        Returns:
+            tuple: (embedding vector, bool indicating whether it came from the cache).
+
+        Raises:
+            Exception: If the API call to the embedding model fails.
+        """
+        if not text or not text.strip():
+            return [1.0, 0.0], False
+        if self.cached_text == text:
+            return self.cached_embedding, True
+        try:
+            logger.bind(tag=TAG).debug(f"Calling embedding model: model: {self.model_name}, input: {text}")
+            response = self.client.embeddings.create(
+                model=self.model_name,
+                encoding_format="float",
+                input=text,
+            )
+            embedding = response.data[0].embedding
+            self.cached_text = text
+            self.cached_embedding = embedding
+            return embedding, False
+        except Exception as e:
+            logger.bind(tag=TAG).error(f"Embedding model call failed: {str(e)}")
+            raise
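The threshold ladder in `is_eou` can be restated as a pure function, which may make the behavior easier to test in isolation. A sketch, with `p_stop` standing for `embedding[1]` and the default thresholds from the constructor:

```python
def eou_decision(p_stop: float, silence_ms: int,
                 min_ms: int = 1000, max_ms: int = 3000) -> bool:
    # Mirrors VADProvider.is_eou: the longer the silence, the lower the
    # semantic confidence needed to declare the utterance finished.
    if silence_ms <= min_ms / 2:
        return False         # barely any pause: never cut the user off
    if silence_ms <= min_ms:
        return p_stop > 0.9  # short pause: need high confidence
    if silence_ms <= max_ms:
        return p_stop > 0.8  # medium pause: medium confidence
    return True              # long pause: stop regardless of semantics
```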
+ """ + self.api_key = config.get("api_key") + self.model_name = config.get("model_name") + self.host = config.get("host") + if self.host is None: + self.host = "ai-gateway.vei.volces.com" + + self.base_url = f"https://{self.host}/v1" + self.model_name = config.get("model_name") + + + # Default parameters for the model, with type converters. + param_defaults = { + "max_tokens": (500, int), + "temperature": (0.7, lambda x: round(float(x), 1)), + "top_p": (1.0, lambda x: round(float(x), 1)), + } + + # Set model parameters from config, falling back to defaults. + for param, (default, converter) in param_defaults.items(): + value = config.get(param) + try: + setattr( + self, + param, + converter(value) if value not in (None, "") else default, + ) + except (ValueError, TypeError): + setattr(self, param, default) + + + check_model_key("VLLM", self.api_key) + self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) + + def response(self, question, base64_image): + """ + Sends a request to the Volcengine VLLM service with a question and an image. + + Args: + question (str): The text question to ask the model. + base64_image (str): The base64-encoded image data to be sent with the question. + + Returns: + str: The text response from the model. + + Raises: + Exception: If there is an error during the API call. + """ + try: + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + { + "type": "image_url", + "image_url": {"url": f"{base64_image}"}, + }, + ], + } + ] + + response = self.client.chat.completions.create( + model=self.model_name, messages=messages, stream=False + ) + + return response.choices[0].message.content + + except Exception as e: + logger.bind(tag=TAG).error(f"Error in response generation: {e}") + raise diff --git a/main/xiaozhi-server/core/utils/util.py b/main/xiaozhi-server/core/utils/util.py index bc778558f0..5797b8cbf5 100644 --- a/main/xiaozhi-server/core/utils/util.py +++ b/main/xiaozhi-server/core/utils/util.py @@ -862,52 +862,6 @@ def opus_datas_to_wav_bytes(opus_datas, sample_rate=16000, channels=1): return wav_buffer.getvalue() -def check_vad_update(before_config, new_config): - if ( - new_config.get("selected_module") is None - or new_config["selected_module"].get("VAD") is None - ): - return False - update_vad = False - current_vad_module = before_config["selected_module"]["VAD"] - new_vad_module = new_config["selected_module"]["VAD"] - current_vad_type = ( - current_vad_module - if "type" not in before_config["VAD"][current_vad_module] - else before_config["VAD"][current_vad_module]["type"] - ) - new_vad_type = ( - new_vad_module - if "type" not in new_config["VAD"][new_vad_module] - else new_config["VAD"][new_vad_module]["type"] - ) - update_vad = current_vad_type != new_vad_type - return update_vad - - -def check_asr_update(before_config, new_config): - if ( - new_config.get("selected_module") is None - or new_config["selected_module"].get("ASR") is None - ): - return False - update_asr = False - current_asr_module = before_config["selected_module"]["ASR"] - new_asr_module = new_config["selected_module"]["ASR"] - current_asr_type = ( - current_asr_module - if "type" not in before_config["ASR"][current_asr_module] - else before_config["ASR"][current_asr_module]["type"] - ) - new_asr_type = ( - new_asr_module - if "type" not in new_config["ASR"][new_asr_module] - else new_config["ASR"][new_asr_module]["type"] - ) - update_asr = current_asr_type != new_asr_type - return update_asr - - def filter_sensitive_info(config: 
diff --git a/main/xiaozhi-server/core/utils/util.py b/main/xiaozhi-server/core/utils/util.py
index bc778558f0..5797b8cbf5 100644
--- a/main/xiaozhi-server/core/utils/util.py
+++ b/main/xiaozhi-server/core/utils/util.py
@@ -862,52 +862,6 @@ def opus_datas_to_wav_bytes(opus_datas, sample_rate=16000, channels=1):
     return wav_buffer.getvalue()
 
 
-def check_vad_update(before_config, new_config):
-    if (
-        new_config.get("selected_module") is None
-        or new_config["selected_module"].get("VAD") is None
-    ):
-        return False
-    update_vad = False
-    current_vad_module = before_config["selected_module"]["VAD"]
-    new_vad_module = new_config["selected_module"]["VAD"]
-    current_vad_type = (
-        current_vad_module
-        if "type" not in before_config["VAD"][current_vad_module]
-        else before_config["VAD"][current_vad_module]["type"]
-    )
-    new_vad_type = (
-        new_vad_module
-        if "type" not in new_config["VAD"][new_vad_module]
-        else new_config["VAD"][new_vad_module]["type"]
-    )
-    update_vad = current_vad_type != new_vad_type
-    return update_vad
-
-
-def check_asr_update(before_config, new_config):
-    if (
-        new_config.get("selected_module") is None
-        or new_config["selected_module"].get("ASR") is None
-    ):
-        return False
-    update_asr = False
-    current_asr_module = before_config["selected_module"]["ASR"]
-    new_asr_module = new_config["selected_module"]["ASR"]
-    current_asr_type = (
-        current_asr_module
-        if "type" not in before_config["ASR"][current_asr_module]
-        else before_config["ASR"][current_asr_module]["type"]
-    )
-    new_asr_type = (
-        new_asr_module
-        if "type" not in new_config["ASR"][new_asr_module]
-        else new_config["ASR"][new_asr_module]["type"]
-    )
-    update_asr = current_asr_type != new_asr_type
-    return update_asr
-
-
 def filter_sensitive_info(config: dict) -> dict:
     """
     过滤配置中的敏感信息
diff --git a/main/xiaozhi-server/core/websocket_server.py b/main/xiaozhi-server/core/websocket_server.py
index 09ba796156..884cdb1e91 100644
--- a/main/xiaozhi-server/core/websocket_server.py
+++ b/main/xiaozhi-server/core/websocket_server.py
@@ -4,7 +4,6 @@
 from core.connection import ConnectionHandler
 from config.config_loader import get_config_from_api
 from core.utils.modules_initialize import initialize_modules
-from core.utils.util import check_vad_update, check_asr_update
 
 TAG = __name__
 
@@ -100,20 +99,15 @@ async def update_config(self) -> bool:
                 self.logger.bind(tag=TAG).error("获取新配置失败")
                 return False
             self.logger.bind(tag=TAG).info(f"获取新配置成功")
-            # 检查 VAD 和 ASR 类型是否需要更新
-            update_vad = check_vad_update(self.config, new_config)
-            update_asr = check_asr_update(self.config, new_config)
-            self.logger.bind(tag=TAG).info(
-                f"检查VAD和ASR类型是否需要更新: {update_vad} {update_asr}"
-            )
+            # Apply the new config
             self.config = new_config
             # 重新初始化组件
             modules = initialize_modules(
                 self.logger,
                 new_config,
-                update_vad,
-                update_asr,
+                "VAD" in new_config["selected_module"],
+                "ASR" in new_config["selected_module"],
                 "LLM" in new_config["selected_module"],
                 False,
                 "Memory" in new_config["selected_module"],
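Note the behavioral shift in this last hunk: the removed helpers reloaded VAD/ASR only when the configured provider *type* changed, whereas the new membership test reinitializes a module whenever it is selected at all. A small sketch of the new condition:

```python
# Old behavior (removed): reload only when the provider type changed.
# New behavior: reload whenever the module appears in selected_module.
new_config = {"selected_module": {"VAD": "VAD_volcengine", "ASR": "ASR_VolceAIGateway"}}

update_vad = "VAD" in new_config["selected_module"]  # True  -> VAD reinitialized
update_llm = "LLM" in new_config["selected_module"]  # False -> LLM left untouched
```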