From 9ff49e49250b315c95d07032632d76ecabd71128 Mon Sep 17 00:00:00 2001 From: Michal Kulakowski Date: Mon, 16 Mar 2026 16:09:16 +0100 Subject: [PATCH] Support kokoro model --- Dockerfile.redhat | 16 +- Dockerfile.ubuntu | 15 +- Makefile | 2 + src/audio/audio_utils.cpp | 33 +- src/audio/audio_utils.hpp | 1 + src/audio/speech_to_text/s2t_servable.cpp | 1 - src/audio/text_to_speech/BUILD | 1 + src/audio/text_to_speech/t2s_calculator.cc | 58 +- src/audio/text_to_speech/t2s_calculator.proto | 2 + src/audio/text_to_speech/t2s_servable.cpp | 27 +- src/audio/text_to_speech/t2s_servable.hpp | 10 +- tts_asr_roundtrip.py | 495 ++++++++++++++++++ 12 files changed, 634 insertions(+), 27 deletions(-) create mode 100644 tts_asr_roundtrip.py diff --git a/Dockerfile.redhat b/Dockerfile.redhat index 3a59009f21..e905def8b5 100644 --- a/Dockerfile.redhat +++ b/Dockerfile.redhat @@ -109,6 +109,7 @@ SHELL ["/bin/bash", "-xo", "pipefail", "-c"] ARG JOBS=40 ARG VERBOSE_LOGS=OFF ARG LTO_ENABLE=OFF +ARG ESPEAK=1 # hadolint ignore=DL3041 RUN dnf install -y -d6 \ @@ -129,6 +130,10 @@ RUN dnf install -y -d6 \ python3.12-pip \ libicu-devel && \ dnf clean all +RUN if [ "$ESPEAK" == "1" ] ; then \ + dnf install -y espeak-ng espeak-ng-libs || dnf install -y espeak-ng-libs ; \ + dnf clean all ; \ + fi WORKDIR / @@ -234,11 +239,11 @@ RUN git clone https://github.com/$ov_tokenizers_org/openvino_tokenizers.git /ope fi WORKDIR /openvino_genai/ -ARG ov_genai_branch=master -ARG ov_genai_org=openvinotoolkit +ARG ov_genai_branch=kokoro_tts +ARG ov_genai_repo=https://github.com/RyanMetcalfeInt8/openvino.genai.git # hadolint ignore=DL3003 RUN if [ "$ov_use_binary" == "0" ]; then true ; else exit 0 ; fi ; \ - git clone https://github.com/$ov_genai_org/openvino.genai /openvino_genai && cd /openvino_genai && git checkout $ov_genai_branch && git submodule update --init --recursive && \ + git clone $ov_genai_repo /openvino_genai && cd /openvino_genai && git checkout $ov_genai_branch && git submodule update --init 
--recursive && \ cmake -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DCMAKE_CXX_FLAGS=" ${SDL_OPS} ${LTO_CXX_FLAGS} " -DCMAKE_SHARED_LINKER_FLAGS="${LTO_LD_FLAGS}" -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DENABLE_SYSTEM_ICU="True" -DBUILD_TOKENIZERS=OFF -DENABLE_SAMPLES=OFF -DENABLE_TOOLS=OFF -DENABLE_TESTS=OFF -DENABLE_XGRAMMAR=ON -S ./ -B ./build/ && \ cmake --build ./build/ --parallel $JOBS && cp /openvino_genai/build/openvino_genai/lib*.so* /opt/intel/openvino/runtime/lib/intel64/ && \ cp -r /openvino_genai/src/cpp/include/* /opt/intel/openvino/runtime/include/ && \ @@ -393,6 +398,7 @@ LABEL "maintainer"="dariusz.trawinski@intel.com" ARG INSTALL_RPMS_FROM_URL= ARG INSTALL_DRIVER_VERSION="24.52.32224" ARG GPU=0 +ARG ESPEAK=1 ARG debug_bazel_flags= LABEL bazel-build-flags=${debug_bazel_flags} LABEL supported-devices="CPU=1 GPU=${GPU}" @@ -407,6 +413,10 @@ COPY ./install_redhat_gpu_drivers.sh /install_gpu_drivers.sh # hadolint ignore=DL3003,DL3041,SC2164,SC1091 RUN if [ -f /usr/bin/dnf ] ; then export DNF_TOOL=dnf ; echo -e "max_parallel_downloads=8\nretries=50" >> /etc/dnf/dnf.conf ; else export DNF_TOOL=microdnf ; fi ; \ $DNF_TOOL upgrade --setopt=install_weak_deps=0 --nodocs -y ; \ + if [ "$ESPEAK" == "1" ] ; then \ + $DNF_TOOL install -y espeak-ng espeak-ng-libs --setopt=install_weak_deps=0 --nodocs || \ + $DNF_TOOL install -y espeak-ng-libs --setopt=install_weak_deps=0 --nodocs ; \ + fi ; \ if [ "$GPU" == "1" ] ; then \ source /install_gpu_drivers.sh && rm -rf /install_gpu_drivers.sh; \ fi ; \ diff --git a/Dockerfile.ubuntu b/Dockerfile.ubuntu index f7e57e380c..6a3761a60d 100644 --- a/Dockerfile.ubuntu +++ b/Dockerfile.ubuntu @@ -95,6 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive SHELL ["/bin/bash", "-xo", "pipefail", "-c"] ARG debug_bazel_flags="--strip=always --config=mp_on_py_on --//:distro=ubuntu" +ARG ESPEAK=1 RUN if [ "$BASE_OS" == "ubuntu24" ] ; then apt-get update && \ apt-get install -y software-properties-common --no-install-recommends; add-apt-repository 
'ppa:deadsnakes/ppa' -y && \ apt-get clean && rm -rf /var/lib/apt/lists/* ; fi @@ -124,6 +125,10 @@ RUN apt-get update && apt-get install --no-install-recommends -y \ vim && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* +RUN if [ "$ESPEAK" == "1" ] ; then \ + apt-get update && apt-get install -y --no-install-recommends espeak-ng && \ + apt-get clean && rm -rf /var/lib/apt/lists/* ; \ + fi # on ubuntu 24.04 python3.12 is used as default python for ovms build and release # TF build needs python3.10 with numpy as it does not support python3.12 RUN python3.10 -m pip install "numpy<2.0.0" --no-cache-dir @@ -220,12 +225,12 @@ RUN if [ "$ov_use_binary" == "0" ]; then true ; else exit 0 ; fi ; \ if ! [[ $debug_bazel_flags == *"_py_off"* ]]; then \ cp build/python/* /opt/intel/openvino/python/openvino_tokenizers/ ; \ fi -ARG ov_genai_branch=master -ARG ov_genai_org=openvinotoolkit +ARG ov_genai_branch=kokoro_tts +ARG ov_genai_repo=https://github.com/RyanMetcalfeInt8/openvino.genai.git WORKDIR /openvino_genai/ # hadolint ignore=DL3003 RUN if [ "$ov_use_binary" == "0" ]; then \ - git clone https://github.com/$ov_genai_org/openvino.genai /openvino_genai && cd /openvino_genai && git checkout $ov_genai_branch && git submodule update --init --recursive && \ + git clone $ov_genai_repo /openvino_genai && cd /openvino_genai && git checkout $ov_genai_branch && git submodule update --init --recursive && \ cmake -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DCMAKE_CXX_FLAGS=" ${SDL_OPS} " -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DENABLE_SYSTEM_ICU="True" -DBUILD_TOKENIZERS=OFF -DENABLE_SAMPLES=OFF -DENABLE_TOOLS=OFF -DENABLE_TESTS=OFF -DENABLE_XGRAMMAR=ON -S ./ -B ./build/ && \ cmake --build ./build/ --parallel $JOBS && cp /openvino_genai/build/openvino_genai/lib*.so* /opt/intel/openvino/runtime/lib/intel64/ && \ cp -r /openvino_genai/src/cpp/include/* /opt/intel/openvino/runtime/include/ && \ @@ -395,6 +400,7 @@ ARG INSTALL_RPMS_FROM_URL= ARG INSTALL_DRIVER_VERSION="24.26.30049" ARG GPU=0 ARG 
NPU=0 +ARG ESPEAK=1 ENV DEBIAN_FRONTEND=noninteractive ARG debug_bazel_flags= LABEL bazel-build-flags=${debug_bazel_flags} @@ -413,6 +419,9 @@ COPY ./install_ubuntu_gpu_drivers.sh /tmp/install_gpu_drivers.sh # hadolint ignore=DL3003,SC2164 RUN apt-get update ; \ apt-get install -y --no-install-recommends curl ca-certificates libxml2 || exit 1; \ + if [ "$ESPEAK" == "1" ] ; then \ + apt-get install -y --no-install-recommends espeak-ng || exit 1; \ + fi ; \ if [ "$GPU" == "1" ] ; then \ /tmp/install_gpu_drivers.sh ; \ fi ; \ diff --git a/Makefile b/Makefile index 6d7c5d2918..f0638f3dc0 100644 --- a/Makefile +++ b/Makefile @@ -61,6 +61,7 @@ BUILD_TESTS ?= 0 RUN_GPU_TESTS ?= GPU ?= 0 NPU ?= 0 +ESPEAK ?= 1 BUILD_NGINX ?= 0 MEDIAPIPE_DISABLE ?= 0 PYTHON_DISABLE ?= 0 @@ -237,6 +238,7 @@ BUILD_ARGS = --build-arg http_proxy=$(HTTP_PROXY)\ --build-arg BASE_OS=$(BASE_OS)\ --build-arg INSTALL_RPMS_FROM_URL=$(INSTALL_RPMS_FROM_URL)\ --build-arg INSTALL_DRIVER_VERSION=$(INSTALL_DRIVER_VERSION)\ + --build-arg ESPEAK=$(ESPEAK)\ --build-arg RELEASE_BASE_IMAGE=$(BASE_IMAGE_RELEASE)\ --build-arg JOBS=$(JOBS)\ --build-arg CAPI_FLAGS=$(CAPI_FLAGS)\ diff --git a/src/audio/audio_utils.cpp b/src/audio/audio_utils.cpp index 77b38e70df..f2281b85ce 100644 --- a/src/audio/audio_utils.cpp +++ b/src/audio/audio_utils.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright 2025 Intel Corporation +// Copyright 2026 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -22,6 +22,7 @@ #include "src/logging.hpp" #include #include +#include #include #include #pragma warning(push) @@ -188,3 +189,33 @@ void prepareAudioOutput(void** ppData, size_t& pDataSize, uint16_t bitsPerSample auto outputPreparationTime = (timer.elapsed(OUTPUT_PREPARATION)) / 1000; SPDLOG_LOGGER_DEBUG(t2s_calculator_logger, "Output preparation time: {} ms", outputPreparationTime); } + +void prepareAudioOutputKokoro(void** ppData, size_t& pDataSize, size_t speechSize, const float* waveformPtr) { + enum : unsigned int { + OUTPUT_PREPARATION, + TIMER_END + }; + Timer timer; + timer.start(OUTPUT_PREPARATION); + + drwav_data_format format; + format.container = drwav_container_riff; + format.format = DR_WAVE_FORMAT_IEEE_FLOAT; + format.channels = 1; + format.sampleRate = 24000; // Kokoro native sample rate + format.bitsPerSample = 32; + drwav wav; + + auto status = drwav_init_memory_write(&wav, ppData, &pDataSize, &format, nullptr); + if (status == DRWAV_FALSE) { + throw std::runtime_error("Failed to initialize WAV writer"); + } + drwav_uint64 framesWritten = drwav_write_pcm_frames(&wav, speechSize, waveformPtr); + if (framesWritten != speechSize) { + throw std::runtime_error("Failed to write all frames"); + } + drwav_uninit(&wav); + timer.stop(OUTPUT_PREPARATION); + auto outputPreparationTime = (timer.elapsed(OUTPUT_PREPARATION)) / 1000; + SPDLOG_LOGGER_DEBUG(t2s_calculator_logger, "Output preparation time: {} ms", outputPreparationTime); +} \ No newline at end of file diff --git a/src/audio/audio_utils.hpp b/src/audio/audio_utils.hpp index cbeea8b457..d3ef577c38 100644 --- a/src/audio/audio_utils.hpp +++ b/src/audio/audio_utils.hpp @@ -25,3 +25,4 @@ bool isWavBuffer(const std::string buf); std::vector readWav(const std::string_view& wavData); std::vector readMp3(const std::string_view& mp3Data); void prepareAudioOutput(void** ppData, size_t& pDataSize, uint16_t bitsPerSample, size_t speechSize, const float* waveformPtr); +void prepareAudioOutputKokoro(void** 
ppData, size_t& pDataSize, size_t speechSize, const float* waveformPtr); \ No newline at end of file diff --git a/src/audio/speech_to_text/s2t_servable.cpp b/src/audio/speech_to_text/s2t_servable.cpp index 451a8bb316..332c33dd34 100644 --- a/src/audio/speech_to_text/s2t_servable.cpp +++ b/src/audio/speech_to_text/s2t_servable.cpp @@ -35,7 +35,6 @@ namespace ovms { namespace { constexpr size_t ISO_LANG_CODE_MAX = 3; } - SttServable::SttServable(const ::mediapipe::S2tCalculatorOptions& nodeOptions, const std::string& graphPath) { auto fsModelsPath = std::filesystem::path(nodeOptions.models_path()); if (fsModelsPath.is_relative()) { diff --git a/src/audio/text_to_speech/BUILD b/src/audio/text_to_speech/BUILD index 2a494f6e16..a3d10f9893 100644 --- a/src/audio/text_to_speech/BUILD +++ b/src/audio/text_to_speech/BUILD @@ -37,6 +37,7 @@ ovms_cc_library( srcs = ["t2s_calculator.cc", "tts_node_initializer.cpp"], deps = [ + "//third_party:genai", "@mediapipe//mediapipe/framework:calculator_framework", "//src:httppayload", "//src:libovmslogging", diff --git a/src/audio/text_to_speech/t2s_calculator.cc b/src/audio/text_to_speech/t2s_calculator.cc index f8f4912f0d..634ee80502 100644 --- a/src/audio/text_to_speech/t2s_calculator.cc +++ b/src/audio/text_to_speech/t2s_calculator.cc @@ -28,6 +28,8 @@ #include "src/client_connection.hpp" #include "src/http_payload.hpp" #include "src/logging.hpp" +#include "openvino/genai/speech_generation/text2speech_pipeline.hpp" +#include "openvino/openvino.hpp" #include #include @@ -63,6 +65,8 @@ static absl::Status checkClientDisconnected(const ovms::HttpPayload& payload, co class T2sCalculator : public CalculatorBase { static const std::string INPUT_TAG_NAME; static const std::string OUTPUT_TAG_NAME; + std::string defaultLanguage = "en-us"; + float defaultSpeed = 1.0f; public: static absl::Status GetContract(CalculatorContract* cc) { @@ -81,6 +85,13 @@ class T2sCalculator : public CalculatorBase { absl::Status Open(CalculatorContext* cc) final 
{ SPDLOG_LOGGER_DEBUG(t2s_calculator_logger, "T2sCalculator [Node: {}] Open start", cc->NodeName()); + const auto& options = cc->Options(); + if (options.has_language() && !options.language().empty()) { + defaultLanguage = options.language(); + } + if (options.has_speed()) { + defaultSpeed = options.speed(); + } return absl::OkStatus(); } @@ -113,26 +124,49 @@ class T2sCalculator : public CalculatorBase { if (streamIt != payload.parsedJson->MemberEnd()) { return absl::InvalidArgumentError("streaming is not supported"); } + SPDLOG_LOGGER_DEBUG(t2s_calculator_logger, "Parsing TTS request parameters"); std::optional voiceName; auto voiceIt = payload.parsedJson->FindMember("voice"); - if (voiceIt != payload.parsedJson->MemberEnd() && voiceIt->value.IsString()) { + if (voiceIt != payload.parsedJson->MemberEnd()) { + if (!voiceIt->value.IsString()) { + return absl::InvalidArgumentError("voice field is not a string"); + } voiceName = voiceIt->value.GetString(); - if (pipe->voices.find(voiceName.value()) == pipe->voices.end()) - return absl::InvalidArgumentError(absl::StrCat("Requested voice not available: ", voiceName.value())); } - + std::string language = defaultLanguage; + auto languageIt = payload.parsedJson->FindMember("language"); + if (languageIt != payload.parsedJson->MemberEnd()) { + if (!languageIt->value.IsString()) { + return absl::InvalidArgumentError("language field is not a string"); + } + language = languageIt->value.GetString(); + } + float speed = defaultSpeed; + auto speedIt = payload.parsedJson->FindMember("speed"); + if (speedIt != payload.parsedJson->MemberEnd()) { + if (!speedIt->value.IsNumber()) { + return absl::InvalidArgumentError("speed field is not a number"); + } + speed = speedIt->value.GetFloat(); + } ov::genai::Text2SpeechDecodedResults generatedSpeech; std::unique_lock lock(pipe->ttsPipelineMutex); auto disconnectStatus = checkClientDisconnected(payload, cc->NodeName(), "before generation"); if (!disconnectStatus.ok()) return disconnectStatus; - + ov::Tensor 
speakerEmbedding; + std::string selectedVoice = "af_alloy"; if (voiceName.has_value()) { - generatedSpeech = pipe->ttsPipeline->generate(inputIt->value.GetString(), pipe->voices[voiceName.value()]); - } else { - generatedSpeech = pipe->ttsPipeline->generate(inputIt->value.GetString()); + selectedVoice = voiceName.value(); + auto speakerIt = pipe->voices.find(selectedVoice); + if (speakerIt != pipe->voices.end()) { + speakerEmbedding = speakerIt->second; + } } - auto bitsPerSample = generatedSpeech.speeches[0].get_element_type().bitwidth(); + ov::AnyMap properties{{"voice", selectedVoice}, {"language", language}, {"speed", speed}}; + generatedSpeech = pipe->ttsPipeline->generate(inputIt->value.GetString(), speakerEmbedding, properties); + SPDLOG_LOGGER_DEBUG(t2s_calculator_logger, "TTS generation finished"); + //auto bitsPerSample = generatedSpeech.speeches[0].get_element_type().bitwidth(); auto speechSize = generatedSpeech.speeches[0].get_size(); ov::Tensor cpuTensor(generatedSpeech.speeches[0].get_element_type(), generatedSpeech.speeches[0].get_shape()); // copy results to release inference request @@ -143,7 +177,9 @@ class T2sCalculator : public CalculatorBase { return disconnectStatus; void* ppData; size_t pDataSize; - prepareAudioOutput(&ppData, pDataSize, bitsPerSample, speechSize, cpuTensor.data()); + SPDLOG_LOGGER_DEBUG(t2s_calculator_logger, "Preparing WAV audio output"); + prepareAudioOutputKokoro(&ppData, pDataSize, speechSize, cpuTensor.data()); + SPDLOG_LOGGER_DEBUG(t2s_calculator_logger, "WAV audio output prepared"); output = std::make_unique(reinterpret_cast(ppData), pDataSize); drwav_free(ppData, NULL); } else { @@ -151,6 +187,8 @@ class T2sCalculator : public CalculatorBase { } } catch (ov::AssertFailure& e) { return absl::InvalidArgumentError(e.what()); + } catch (std::runtime_error& e) { + return absl::InvalidArgumentError(e.what()); } catch (...) 
{ return absl::InvalidArgumentError("Response generation failed"); } diff --git a/src/audio/text_to_speech/t2s_calculator.proto b/src/audio/text_to_speech/t2s_calculator.proto index efea722c3d..5bfce7811f 100644 --- a/src/audio/text_to_speech/t2s_calculator.proto +++ b/src/audio/text_to_speech/t2s_calculator.proto @@ -40,4 +40,6 @@ message T2sCalculatorOptions { required string path = 2; } repeated SpeakerEmbeddings voices = 4; + optional string language = 5 [default = "en-us"]; + optional float speed = 6 [default = 1.0]; } diff --git a/src/audio/text_to_speech/t2s_servable.cpp b/src/audio/text_to_speech/t2s_servable.cpp index c782c9346d..a179dd6288 100644 --- a/src/audio/text_to_speech/t2s_servable.cpp +++ b/src/audio/text_to_speech/t2s_servable.cpp @@ -19,8 +19,8 @@ #include #include #include +#include -#include "openvino/genai/whisper_pipeline.hpp" #include "openvino/genai/speech_generation/text2speech_pipeline.hpp" #include "src/audio/text_to_speech/t2s_calculator.pb.h" #include "src/status.hpp" @@ -31,7 +31,15 @@ namespace ovms { -static ov::Tensor read_speaker_embedding(const std::filesystem::path& file_path) { +static size_t getShapeElementsCount(const ov::Shape& shape) { + size_t elementsCount = 1; + for (const auto dim : shape) { + elementsCount *= dim; + } + return elementsCount; +} + +static ov::Tensor read_speaker_embedding(const std::filesystem::path& file_path, const ov::Shape& expectedShape) { std::ifstream input(file_path, std::ios::binary); if (input.fail()) { std::stringstream ss; @@ -48,12 +56,16 @@ static ov::Tensor read_speaker_embedding(const std::filesystem::path& file_path) if (buffer_size % sizeof(float) != 0) { throw std::runtime_error("File size is not a multiple of float size."); } - size_t num_floats = buffer_size / sizeof(float); - if (num_floats != 512) { - throw std::runtime_error("File must contain speaker embedding including 512 32-bit floats."); + const size_t numFloats = buffer_size / sizeof(float); + const size_t 
expectedElements = getShapeElementsCount(expectedShape); + if (numFloats != expectedElements) { + std::stringstream ss; + ss << "File must contain speaker embedding with " << expectedElements + << " 32-bit floats. Got: " << numFloats; + throw std::runtime_error(ss.str()); } - ov::Tensor floats_tensor(ov::element::f32, ov::Shape{1, num_floats}); + ov::Tensor floats_tensor(ov::element::f32, expectedShape); input.read(reinterpret_cast(floats_tensor.data()), buffer_size); if (input.fail()) { throw std::runtime_error("Failed to read all data from file."); @@ -76,10 +88,11 @@ TtsServable::TtsServable(const std::string& modelDir, const std::string& targetD throw std::runtime_error("Error during plugin_config option parsing"); } ttsPipeline = std::make_shared(parsedModelsPath.string(), targetDevice, config); + const ov::Shape speakerEmbeddingShape = ttsPipeline->get_speaker_embedding_shape(); for (auto voice : graphVoices) { if (!std::filesystem::exists(voice.path())) throw std::runtime_error{"Requested voice speaker embeddings file does not exist: " + voice.path()}; - voices[voice.name()] = read_speaker_embedding(voice.path()); + voices[voice.name()] = read_speaker_embedding(voice.path(), speakerEmbeddingShape); } } } // namespace ovms diff --git a/src/audio/text_to_speech/t2s_servable.hpp b/src/audio/text_to_speech/t2s_servable.hpp index 6d192edcfb..6f5249baa4 100644 --- a/src/audio/text_to_speech/t2s_servable.hpp +++ b/src/audio/text_to_speech/t2s_servable.hpp @@ -16,15 +16,21 @@ #pragma once -#include "openvino/genai/speech_generation/text2speech_pipeline.hpp" #include "src/audio/text_to_speech/t2s_calculator.pb.h" +#include #include +#include #include #include -namespace ovms { +#include "openvino/runtime/tensor.hpp" + +namespace ov::genai { +class Text2SpeechPipeline; +} +namespace ovms { class TtsServable { public: std::shared_ptr ttsPipeline; diff --git a/tts_asr_roundtrip.py b/tts_asr_roundtrip.py new file mode 100644 index 0000000000..c726e84d48 --- /dev/null +++ 
b/tts_asr_roundtrip.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import random +import sys +import time +from difflib import SequenceMatcher +from pathlib import Path +from urllib.request import Request, urlopen + +from openai import OpenAI + +CHINESE_PROMPTS = [ + "Kokoro 是一系列体积虽小但功能强大的 TTS 模型。", + "今天天气很好,我们去公园散步吧。", + "请把这份报告在下午三点前发送给我。", + "人工智能正在改变我们的工作和生活方式。", + "这个模型支持多种语言,包括中文、英文和日语。", + "如果出现错误,请重试或联系管理员。", + "会议将在下周二上午九点准时开始。", + "请确认你的收货地址和联系电话是否正确。", + "项目进度需要每周更新一次并提交给经理。", + "为了提高性能,我们需要优化推理流程。", + "系统已成功部署到生产环境,监控正常。", + "这句话用于测试中文语音合成的清晰度。", + "请阅读并同意服务条款和隐私政策。", + "模型输出的音频应当自然流畅且可理解。", + "数据备份已完成,请检查日志确认结果。", + "请在两分钟内完成系统重启,并确认服务恢复。", + "客户反馈延迟较高,我们需要检查网络链路。", + "今天是星期五,记得提交本周的工作总结。", + "该功能已进入灰度发布阶段,请关注指标变化。", + "日志中出现多次超时错误,请检查依赖服务。", + "请将版本号更新为 1.2.3,并生成发布说明。", + "系统负载过高,建议临时扩容两台实例。", + "请确认验证码已发送到用户手机号。", + "数据库备份完成后请验证备份完整性。", + "优化缓存命中率可以显著提升响应速度。", + "模型推理耗时过长,需要排查瓶颈。", + "我们计划在下月上线新的支付流程。", + "请核对发票信息,确保金额与订单一致。", + "设备离线超过 30 分钟,请检查供电。", + "用户输入包含特殊字符,请做好过滤处理。", + "请确认接口文档已同步更新。", + "今天的会议取消,改为周三上午十点。", + "该功能支持多语言切换,请验证中文显示。", + "请检查邮件是否被误判为垃圾邮件。", + "请在测试环境验证修复结果,再合入主干。", +] +ENGLISH_PROMPTS = [ + "The HTTP 404 error indicates the requested resource wasn't found on the server.", + "My phone number is +1-555-123-4567, and my email is john.doe@example.com.", + "The CPU utilization reached 98.7% at 3:42 AM, triggering an automated alert.", + "Professor Smith's lecture on quantum mechanics is scheduled for December 15th, 2024.", + "The API endpoint responds in approximately 127 milliseconds with a 200 OK status.", + "NVIDIA's GeForce RTX 4090 GPU features 24GB of GDDR6X memory.", + "The PostgreSQL database crashed at 10:15 PM UTC due to out-of-memory errors.", + "Dr. 
Williams recommended acetaminophen 500mg three times daily for pain management.", + "The SHA-256 hash of the file is 3a4b5c6d7e8f9g0h1i2j3k4l5m6n7o8p9q0r1s2t3u4v.", + "Mount Everest's peak stands at 8,848.86 meters or 29,031.7 feet above sea level.", + "The XML configuration file references namespace xmlns:xsi='http://www.w3.org'.", + "Flight BA2490 departed London Heathrow at 14:25 GMT, arriving in New York JFK at 17:15 EST.", + "The Schrödinger equation describes quantum-mechanical wave functions: iℏ ∂Ψ/∂t = ĤΨ.", + "NASA's Artemis III mission aims to land astronauts near the lunar south pole.", + "The DNS server at IP address 192.168.1.1 failed to resolve www.example.com.", + "JavaScript's async/await syntax simplifies promise-based asynchronous code handling.", + "The Wi-Fi password is 'MyS3cur3P@ssw0rd!' with uppercase, numbers, and special characters.", + "Mrs. O'Brien's restaurant serves crème brûlée and jalapeño poppers as appetizers.", + "The cryptocurrency wallet address is 0x742d35Cc6634C0532925a3b844Bc9e7595f0bEb5.", + "Tokyo's coordinates are 35.6762° N latitude and 139.6503° E longitude.", + "The RESTful API uses OAuth 2.0 authentication with JWT bearer tokens.", + "Linux kernel version 6.5.13 includes patches for CVE-2024-12345 vulnerability.", + "The chemical formula for sulfuric acid is H₂SO₄, commonly used in batteries.", + "Amazon Web Services S3 bucket 'prod-data-backup-2024' exceeded 15TB storage.", + "The SQL query 'SELECT * FROM users WHERE id IN (1,2,3)' returned zero rows.", + "Dr. García's Ph.D. thesis explored non-Euclidean geometry in n-dimensional spaces.", + "The HTTPS certificate for *.mydomain.com expires on March 31st at 11:59 PM.", + "UNESCO's World Heritage site #1347 was designated in Kyoto, Japan.", + "The docker-compose.yml file defines three microservices: frontend, backend, and Redis cache.", + "Mr. 
Müller's BMW X5 accelerates from 0 to 60 mph in just 4.3 seconds.", + "The grep command 'grep -r TODO *.py | wc -l' found 237 occurrences.", + "Wellington, New Zealand's capital, experiences winds exceeding 50 km/h regularly.", + "The TCP/IP packet loss rate increased to 12.4% during the 10 PM network spike.", + "Einstein's E=mc² equation relates energy, mass, and the speed of light squared.", + "The JSON payload contains nested arrays: {'users': [{'id': 1, 'name': 'Alice'}]}.", +] + +ITALIAN_PROMPTS = [ + "Kokoro è una serie di modelli TTS leggeri ma potenti.", + "Oggi il tempo è bello, andiamo a fare una passeggiata nel parco.", + "Per favore, invia questo rapporto entro le tre del pomeriggio.", + "L'intelligenza artificiale sta cambiando il modo in cui lavoriamo e viviamo.", + "Questo modello supporta più lingue, inclusi inglese, cinese e giapponese.", + "Se si verifica un errore, prova di nuovo o contatta l'amministratore.", + "La riunione inizierà alle nove in punto martedì prossimo.", + "Per favore, conferma che l'indirizzo di spedizione e il numero di telefono siano corretti.", + "Lo stato del progetto deve essere aggiornato settimanalmente e inviato al responsabile.", + "Per migliorare le prestazioni, è necessario ottimizzare la pipeline di inferenza.", + "Il sistema è stato distribuito con successo nell'ambiente di produzione.", + "Questa frase viene utilizzata per testare la chiarezza della sintesi vocale italiana.", + "Si prega di leggere e accettare i termini di servizio e l'informativa sulla privacy.", + "L'audio di output del modello dovrebbe essere naturale e facile da comprendere.", + "Il backup dei dati è stato completato, verificare i log per confermare i risultati.", + "Si prega di completare il riavvio del sistema entro due minuti.", + "Il feedback dei clienti mostra una latenza elevata, è necessario controllare il collegamento di rete.", + "Oggi è venerdì, ricorda di inviare il riepilogo del lavoro di questa settimana.", + "Questa funzione è 
entrata nella fase di rilascio graduato.", + "Nel log compaiono più errori di timeout, si prega di controllare i servizi dipendenti.", + "Si prega di aggiornare il numero di versione a 1.2.3 e generare le note di rilascio.", + "Il carico del sistema è troppo elevato, si consiglia di espandere temporaneamente due istanze.", + "Si prega di confermare che il codice di verifica sia stato inviato al telefono dell'utente.", + "Dopo il backup del database, verificare l'integrità del backup.", + "L'ottimizzazione del tasso di hit della cache può migliorare significativamente la velocità di risposta.", + "L'inferenza del modello impiega troppo tempo, è necessario indagare il collo di bottiglia.", + "Stiamo pianificando il lancio di un nuovo processo di pagamento il mese prossimo.", + "Si prega di verificare le informazioni della fattura per assicurarsi che l'importo corrisponda all'ordine.", + "Il dispositivo è offline da più di 30 minuti, si prega di controllare l'alimentazione.", + "L'input dell'utente contiene caratteri speciali, si prega di gestire il filtraggio correttamente.", + "Si prega di confermare che la documentazione dell'interfaccia sia stata sincronizzata.", + "La riunione di oggi è annullata, riprogrammata per mercoledì alle 10.", + "Questa funzione supporta il cambio di più lingue, si prega di verificare la visualizzazione italiana.", + "Si prega di controllare se l'email è stata identificata erroneamente come spam.", + "Si prega di verificare la correzione nell'ambiente di test prima di unire al ramo principale.", +] + +SPANISH_PROMPTS = [ + "Kokoro es una serie de modelos TTS ligeros pero potentes.", + "Hoy el clima es agradable, vamos a pasear por el parque.", + "Por favor, envía este informe antes de las tres de la tarde.", + "La inteligencia artificial está cambiando la forma en que trabajamos y vivimos.", + "Este modelo admite varios idiomas, incluidos inglés, chino y japonés.", + "Si ocurre un error, intenta de nuevo o comunícate con el 
administrador.", + "La reunión comenzará puntualmente a las nueve el martes próximo.", + "Por favor, confirma que tu dirección de envío y número de teléfono sean correctos.", + "El progreso del proyecto debe actualizarse semanalmente y enviarse al gerente.", + "Para mejorar el rendimiento, necesitamos optimizar la tubería de inferencia.", + "El sistema se ha implementado correctamente en el entorno de producción.", + "Esta oración se utiliza para probar la claridad de la síntesis de voz en español.", + "Por favor, lee y acepta los términos de servicio y la política de privacidad.", + "La salida de audio del modelo debe ser natural y fácil de entender.", + "La copia de seguridad de datos se ha completado, verifica los registros para confirmar los resultados.", + "Por favor, completa el reinicio del sistema en dos minutos.", + "Los comentarios de los clientes muestran alta latencia, necesitamos verificar el enlace de red.", + "Hoy es viernes, recuerda enviar el resumen del trabajo de esta semana.", + "Esta función ha entrado en fase de lanzamiento gradual.", + "Varios errores de tiempo de espera aparecen en los registros, verifica los servicios dependientes.", + "Por favor, actualiza el número de versión a 1.2.3 y genera las notas de la versión.", + "La carga del sistema es demasiado alta, se sugiere expandir temporalmente dos instancias.", + "Por favor, confirma que el código de verificación se haya enviado al teléfono del usuario.", + "Después de completar la copia de seguridad de la base de datos, verifica la integridad.", + "Optimizar la tasa de aciertos de caché puede mejorar significativamente la velocidad de respuesta.", + "La inferencia del modelo toma demasiado tiempo, necesitamos investigar el cuello de botella.", + "Planeamos lanzar un nuevo proceso de pago el próximo mes.", + "Por favor, verifica la información de la factura para asegurar que el monto coincida con el pedido.", + "El dispositivo ha estado desconectado durante más de 30 minutos, verifica el 
suministro de energía.", + "La entrada del usuario contiene caracteres especiales, maneja el filtrado correctamente.", + "Por favor, confirma que la documentación de la interfaz se haya sincronizado.", + "La reunión de hoy se ha cancelado, reprogramada para el miércoles a las 10 de la mañana.", + "Esta función admite el cambio de varios idiomas, verifica la pantalla en español.", + "Por favor, verifica si el correo ha sido identificado erróneamente como spam.", + "Por favor, verifica la corrección en el entorno de prueba antes de fusionar con la rama principal.", +] + +GERMAN_PROMPTS = [ + "Kokoro ist eine Reihe leichter, aber leistungsstarker TTS-Modelle.", + "Heute ist das Wetter schön, lass uns im Park spazieren gehen.", + "Bitte sende diesen Bericht bis drei Uhr nachmittags.", + "Künstliche Intelligenz verändert unsere Arbeits- und Lebensweise.", + "Dieses Modell unterstützt mehrere Sprachen, darunter Englisch, Chinesisch und Japanisch.", + "Wenn ein Fehler auftritt, versuche es erneut oder kontaktiere den Administrator.", + "Die Besprechung beginnt nächsten Dienstag pünktlich um neun Uhr.", + "Bitte bestätige, dass deine Lieferadresse und Telefonnummer korrekt sind.", + "Der Projektfortschritt muss wöchentlich aktualisiert und an den Manager gesendet werden.", + "Um die Leistung zu verbessern, müssen wir die Inferenz-Pipeline optimieren.", + "Das System wurde erfolgreich in der Produktionsumgebung bereitgestellt.", + "Dieser Satz wird verwendet, um die Klarheit der deutschen Sprachsynthese zu testen.", + "Bitte lies und akzeptiere die Nutzungsbedingungen und die Datenschutzrichtlinie.", + "Die Audioausgabe des Modells sollte natürlich und leicht verständlich sein.", + "Die Datensicherung wurde abgeschlossen, bitte prüfe die Protokolle zur Bestätigung.", + "Bitte führe den Neustart des Systems innerhalb von zwei Minuten durch.", + "Das Kundenfeedback zeigt eine hohe Latenz, wir müssen die Netzwerkverbindung überprüfen.", + "Heute ist Freitag, denk daran, die 
Zusammenfassung dieser Woche einzureichen.", + "Diese Funktion befindet sich jetzt in der schrittweisen Freigabephase.", + "In den Protokollen erscheinen mehrere Zeitüberschreitungsfehler, bitte prüfe abhängige Dienste.", + "Bitte aktualisiere die Versionsnummer auf 1.2.3 und erstelle die Versionshinweise.", + "Die Systemlast ist zu hoch, es wird empfohlen, vorübergehend zwei Instanzen zu erweitern.", + "Bitte bestätige, dass der Bestätigungscode an die Telefonnummer des Benutzers gesendet wurde.", + "Nach Abschluss der Datenbanksicherung bitte die Integrität des Backups prüfen.", + "Die Optimierung der Cache-Trefferquote kann die Reaktionsgeschwindigkeit deutlich verbessern.", + "Die Modellinferenz dauert zu lange, wir müssen den Engpass untersuchen.", + "Wir planen, im nächsten Monat einen neuen Zahlungsprozess einzuführen.", + "Bitte überprüfe die Rechnungsinformationen, damit der Betrag mit der Bestellung übereinstimmt.", + "Das Gerät ist seit mehr als 30 Minuten offline, bitte prüfe die Stromversorgung.", + "Die Benutzereingabe enthält Sonderzeichen, bitte behandle die Filterung korrekt.", + "Bitte bestätige, dass die Schnittstellendokumentation synchronisiert wurde.", + "Das heutige Meeting wurde abgesagt und auf Mittwoch um zehn Uhr verschoben.", + "Diese Funktion unterstützt den Wechsel zwischen mehreren Sprachen, bitte prüfe die deutsche Anzeige.", + "Bitte prüfe, ob die E-Mail fälschlicherweise als Spam markiert wurde.", + "Bitte verifiziere die Korrektur in der Testumgebung, bevor du in den Hauptzweig zusammenführst.", +] + +CHINESE_PUNCT = ",。!?;:、""''()《》【】—…·、" +LATIN_PUNCT = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" + + +def get_prompts_for_language(language: str) -> list[str]: + """Get prompt list based on language code.""" + language = language.lower() + if language in ("zh", "zh-cn"): + return CHINESE_PROMPTS + elif language in ("en", "en-us", "en-gb"): + return ENGLISH_PROMPTS + elif language in ("it", "it-it"): + return ITALIAN_PROMPTS + elif 
language in ("es", "es-es"): + return SPANISH_PROMPTS + elif language in ("de", "de-de"): + return GERMAN_PROMPTS + else: + # Default to English for unknown languages + return ENGLISH_PROMPTS + + +def normalize_text(text: str, language: str = "en") -> str: + if not text: + return "" + text = text.strip().lower() + + language = language.lower() + if language in ("zh", "zh-cn"): + # Chinese: remove Chinese punctuation and whitespace + remove_chars = set(CHINESE_PUNCT) + remove_chars.update({" ", "\t", "\n", "\r"}) + for ch in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~": + remove_chars.add(ch) + else: + # English, Italian, Spanish, German: remove Latin punctuation and whitespace + remove_chars = set(LATIN_PUNCT) + remove_chars.update({" ", "\t", "\n", "\r"}) + # Also remove any Chinese punctuation that might appear + remove_chars.update(set(CHINESE_PUNCT)) + + return "".join(ch for ch in text if ch not in remove_chars) + + +def similarity(a: str, b: str) -> float: + return SequenceMatcher(a=a, b=b).ratio() + + +def cer(reference: str, hypothesis: str) -> float: + if not reference and not hypothesis: + return 0.0 + if not reference: + return 1.0 + if not hypothesis: + return 1.0 + + ref_len = len(reference) + hyp_len = len(hypothesis) + prev = list(range(hyp_len + 1)) + curr = [0] * (hyp_len + 1) + + for i in range(1, ref_len + 1): + curr[0] = i + r_char = reference[i - 1] + for j in range(1, hyp_len + 1): + h_char = hypothesis[j - 1] + cost = 0 if r_char == h_char else 1 + curr[j] = min( + prev[j] + 1, + curr[j - 1] + 1, + prev[j - 1] + cost, + ) + prev, curr = curr, prev + + distance = prev[hyp_len] + return distance / ref_len + + +def tts_request(endpoint: str, model: str, voice: str, prompt: str, language: str) -> bytes: + url = endpoint.rstrip("/") + "/audio/speech" + payload = { + "model": model, + "voice": voice, + "input": prompt, + } + if language: + payload["language"] = language + data = json.dumps(payload).encode("utf-8") + req = Request(url, data=data, 
headers={"Content-Type": "application/json"}) + with urlopen(req, timeout=120) as resp: + return resp.read() + + +def split_text_into_chunks(text: str, max_chars: int) -> list[str]: + if max_chars <= 0: + return [text] + text = text.strip() + if len(text) <= max_chars: + return [text] + + sentences = [] + buf = [] + for ch in text: + buf.append(ch) + if ch in "。!?;\n": + sentence = "".join(buf).strip() + if sentence: + sentences.append(sentence) + buf = [] + if buf: + sentence = "".join(buf).strip() + if sentence: + sentences.append(sentence) + + chunks = [] + current = "" + for s in sentences: + if not current: + current = s + continue + if len(current) + len(s) <= max_chars: + current += s + else: + chunks.append(current) + current = s + if current: + chunks.append(current) + + if not chunks: + chunks = [text[i : i + max_chars] for i in range(0, len(text), max_chars)] + return chunks + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Send prompts to TTS, then transcribe and compare results." + ) + parser.add_argument("--endpoint", required=True, help="Base URL, e.g. 
http://localhost:8122/v3") + parser.add_argument("--tts-model", default="kokoro", help="TTS model name") + parser.add_argument("--asr-model", default="whisper", help="ASR model name") + parser.add_argument("--voice", default=None, help="Voice name") + parser.add_argument("--language", default="en", help="Language code (default: en, options: en, zh, it, es, de)") + parser.add_argument("--limit", type=int, default=5, help="Number of prompts to test") + parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling") + parser.add_argument("--output-dir", default="tts_asr_output", help="Output directory") + parser.add_argument("--save-audio", action="store_true", help="Save WAV files to output directory") + parser.add_argument("--text", default=None, help="Single text to send (overrides prompt list)") + parser.add_argument("--text-file", default=None, help="Path to a text file to send (overrides prompt list)") + parser.add_argument("--max-chars", type=int, default=300, help="Max chars per TTS request for single text") + args = parser.parse_args() + + if args.limit <= 0: + print("--limit must be > 0", file=sys.stderr) + return 2 + + if args.text and args.text_file: + print("Use only one of --text or --text-file", file=sys.stderr) + return 2 + + if args.text_file: + try: + with open(args.text_file, "r", encoding="utf-8") as f: + single_text = f.read().strip() + except OSError as exc: + print(f"Failed to read --text-file: {exc}", file=sys.stderr) + return 2 + if not single_text: + print("--text-file is empty", file=sys.stderr) + return 2 + prompts = split_text_into_chunks(single_text, args.max_chars) + elif args.text: + prompts = split_text_into_chunks(args.text, args.max_chars) + else: + # Get prompts based on language + prompts = get_prompts_for_language(args.language) + if args.limit < len(prompts): + random.seed(args.seed) + prompts = random.sample(prompts, args.limit) + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) 
+ + client = OpenAI(base_url=args.endpoint, api_key="unused") + + results = [] + total_tts_time = 0.0 + total_asr_time = 0.0 + + for idx, prompt in enumerate(prompts, start=1): + print(f"[{idx}/{len(prompts)}] {prompt}") + + wav_path = out_dir / f"{idx:02d}.wav" + try: + tts_start = time.time() + audio_bytes = tts_request( + endpoint=args.endpoint, + model=args.tts_model, + voice=args.voice, + prompt=prompt, + language=args.language, + ) + tts_time = time.time() - tts_start + total_tts_time += tts_time + + with open(wav_path, "wb") as f: + f.write(audio_bytes) + except Exception as exc: + print(f" TTS failed: {exc}", file=sys.stderr) + results.append({ + "prompt": prompt, + "transcript": "", + "similarity": 0.0, + "tts_time": 0.0, + "asr_time": 0.0, + "error": f"tts: {exc}", + }) + continue + + try: + asr_start = time.time() + with open(wav_path, "rb") as audio_file: + transcript = client.audio.transcriptions.create( + model=args.asr_model, + file=audio_file, + ) + asr_time = time.time() - asr_start + total_asr_time += asr_time + transcript_text = transcript.text or "" + except Exception as exc: + print(f" ASR failed: {exc}", file=sys.stderr) + results.append({ + "prompt": prompt, + "transcript": "", + "similarity": 0.0, + "tts_time": tts_time, + "asr_time": 0.0, + "error": f"asr: {exc}", + }) + if not args.save_audio and wav_path.exists(): + wav_path.unlink(missing_ok=True) + continue + + n_prompt = normalize_text(prompt, args.language) + n_trans = normalize_text(transcript_text, args.language) + score = similarity(n_prompt, n_trans) + cer_score = cer(n_prompt, n_trans) + + results.append({ + "prompt": prompt, + "transcript": transcript_text, + "similarity": score, + "cer": cer_score, + "tts_time": tts_time, + "asr_time": asr_time, + "error": "", + }) + + print(f" Transcript: {transcript_text}") + print(f" Similarity: {score:.3f}") + print(f" CER: {cer_score:.3f}") + print(f" TTS time: {tts_time:.3f}s, ASR time: {asr_time:.3f}s\n") + + if not args.save_audio and 
wav_path.exists(): + wav_path.unlink(missing_ok=True) + + if results: + avg = sum(r["similarity"] for r in results) / len(results) + avg_cer = sum(r.get("cer", 1.0) for r in results) / len(results) + else: + avg = 0.0 + avg_cer = 1.0 + + exact = sum(1 for r in results if normalize_text(r["prompt"], args.language) == normalize_text(r["transcript"], args.language)) + + avg_tts_time = total_tts_time / len(results) if results else 0.0 + avg_asr_time = total_asr_time / len(results) if results else 0.0 + + print("=" * 60) + print(f"Completed {len(results)} items") + print(f"Exact matches: {exact}") + print(f"Average similarity: {avg:.3f}") + print(f"Average CER: {avg_cer:.3f}") + print(f"Average TTS time: {avg_tts_time:.3f}s") + print(f"Average ASR time: {avg_asr_time:.3f}s") + print(f"Total TTS time: {total_tts_time:.3f}s") + print(f"Total ASR time: {total_asr_time:.3f}s") + print(f"Total processing time: {total_tts_time + total_asr_time:.3f}s") + print("Output directory:", out_dir) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())