diff --git a/Makefile b/Makefile
index 8ad322def80..3c0eac14bce 100644
--- a/Makefile
+++ b/Makefile
@@ -91,7 +91,7 @@
 #
 # ==============================================================================
 
-.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help
+.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help
 
 help:
 	@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -103,6 +103,8 @@ help:
 	@echo "  voxtral_realtime-cpu    - Build Voxtral Realtime runner with CPU backend"
 	@echo "  voxtral_realtime-metal  - Build Voxtral Realtime runner with Metal backend (macOS only)"
 	@echo "  voxtral_realtime-mlx    - Build Voxtral Realtime runner with MLX backend"
+	@echo "  voxtral_tts-cpu         - Build Voxtral TTS runner with CPU backend"
+	@echo "  voxtral_tts-cuda        - Build Voxtral TTS runner with CUDA backend"
 	@echo "  whisper-cuda            - Build Whisper runner with CUDA backend"
 	@echo "  whisper-cuda-debug      - Build Whisper runner with CUDA backend (debug mode)"
 	@echo "  whisper-cpu             - Build Whisper runner with CPU backend"
@@ -396,6 +398,24 @@ gemma3-cpu:
 	@echo "✓ Build complete!"
 	@echo "  Binary: cmake-out/examples/models/gemma3/gemma3_e2e_runner"
 
+voxtral_tts-cpu:
+	@echo "==> Building and installing ExecuTorch..."
+	cmake --workflow --preset llm-release
+	@echo "==> Building Voxtral TTS runner (CPU)..."
+	cd examples/models/voxtral_tts && cmake --workflow --preset voxtral-tts-cpu
+	@echo ""
+	@echo "✓ Build complete!"
+	@echo "  Binary: cmake-out/examples/models/voxtral_tts/voxtral_tts_runner"
+
+voxtral_tts-cuda:
+	@echo "==> Building and installing ExecuTorch with CUDA..."
+	cmake --workflow --preset llm-release-cuda
+	@echo "==> Building Voxtral TTS runner with CUDA..."
+	cd examples/models/voxtral_tts && cmake --workflow --preset voxtral-tts-cuda
+	@echo ""
+	@echo "✓ Build complete!"
+	@echo "  Binary: cmake-out/examples/models/voxtral_tts/voxtral_tts_runner"
+
 qwen3_5_moe-cuda:
 	@echo "==> Building and installing ExecuTorch with CUDA..."
 	cmake --workflow --preset llm-release-cuda
diff --git a/examples/models/voxtral_tts/CMakeLists.txt b/examples/models/voxtral_tts/CMakeLists.txt
new file mode 100644
index 00000000000..bfd17cd9810
--- /dev/null
+++ b/examples/models/voxtral_tts/CMakeLists.txt
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +cmake_minimum_required(VERSION 3.24) +project(voxtral_tts) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +# gflags +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) +find_package(gflags REQUIRED) + +# ExecuTorch +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) +find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) +executorch_target_link_options_shared_lib(executorch) + +set(link_libraries executorch gflags) + +# Common ops +list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) +executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) + +# CPU-only builds need quantized and custom ops +if(NOT EXECUTORCH_BUILD_CUDA) + list(APPEND link_libraries quantized_ops_lib custom_ops) + executorch_target_link_options_shared_lib(quantized_ops_lib) + executorch_target_link_options_shared_lib(custom_ops) +endif() + +# XNNPACK +if(TARGET xnnpack_backend) + set(xnnpack_backend_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod) + if(TARGET kleidiai) + list(APPEND xnnpack_backend_libs kleidiai) + endif() + list(APPEND link_libraries ${xnnpack_backend_libs}) + executorch_target_link_options_shared_lib(xnnpack_backend) +endif() + +# LLM runner extension +if(NOT TARGET extension_llm_runner) + message( + FATAL_ERROR + "ExecuTorch must be installed with EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER enabled." + ) +endif() + +if(ANDROID) + list(APPEND link_libraries log) +endif() + +list( + APPEND + link_libraries + extension_llm_runner + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor +) + +# CUDA backend +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + list(APPEND link_libraries aoti_cuda_backend) + if(NOT MSVC) + executorch_target_link_options_shared_lib(aoti_cuda_backend) + endif() +endif() + +# Metal backend +if(EXECUTORCH_BUILD_METAL) + list(APPEND link_libraries metal_backend) + executorch_target_link_options_shared_lib(metal_backend) +endif() + +# Tokenizer +list(APPEND link_libraries tokenizers::tokenizers) + +add_executable( + voxtral_tts_runner main.cpp voxtral_tts_runner.cpp wav_writer.cpp +) +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(voxtral_tts_runner) + if(NOT APPLE AND NOT MSVC) + target_link_options(voxtral_tts_runner PRIVATE "LINKER:-s") + endif() +endif() + +target_include_directories( + voxtral_tts_runner PUBLIC ${_common_include_directories} + ${EXECUTORCH_ROOT}/third-party/json/include +) +target_link_libraries(voxtral_tts_runner PUBLIC ${link_libraries}) +target_compile_options(voxtral_tts_runner PUBLIC ${_common_compile_options}) diff --git a/examples/models/voxtral_tts/CMakePresets.json b/examples/models/voxtral_tts/CMakePresets.json new file mode 100644 index 00000000000..424c84c5730 --- /dev/null +++ b/examples/models/voxtral_tts/CMakePresets.json @@ -0,0 +1,90 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "voxtral-tts-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/voxtral_tts", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "voxtral-tts-cpu", + "displayName": "Voxtral TTS runner (CPU)", + "inherits": [ + 
"voxtral-tts-base" + ] + }, + { + "name": "voxtral-tts-cuda", + "displayName": "Voxtral TTS runner (CUDA)", + "inherits": [ + "voxtral-tts-base" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": [ + "Linux", + "Windows" + ] + } + } + ], + "buildPresets": [ + { + "name": "voxtral-tts-cpu", + "displayName": "Build Voxtral TTS runner (CPU)", + "configurePreset": "voxtral-tts-cpu", + "configuration": "Release", + "targets": [ + "voxtral_tts_runner" + ] + }, + { + "name": "voxtral-tts-cuda", + "displayName": "Build Voxtral TTS runner (CUDA)", + "configurePreset": "voxtral-tts-cuda", + "configuration": "Release", + "targets": [ + "voxtral_tts_runner" + ] + } + ], + "workflowPresets": [ + { + "name": "voxtral-tts-cpu", + "displayName": "Voxtral TTS (CPU)", + "steps": [ + { + "type": "configure", + "name": "voxtral-tts-cpu" + }, + { + "type": "build", + "name": "voxtral-tts-cpu" + } + ] + }, + { + "name": "voxtral-tts-cuda", + "displayName": "Voxtral TTS (CUDA)", + "steps": [ + { + "type": "configure", + "name": "voxtral-tts-cuda" + }, + { + "type": "build", + "name": "voxtral-tts-cuda" + } + ] + } + ] +} diff --git a/examples/models/voxtral_tts/README.md b/examples/models/voxtral_tts/README.md new file mode 100644 index 00000000000..ea7706cb367 --- /dev/null +++ b/examples/models/voxtral_tts/README.md @@ -0,0 +1,275 @@ +# Voxtral TTS + +Self-contained ExecuTorch implementation of +[Voxtral-4B-TTS-2603](https://huggingface.co/mistralai/Voxtral-4B-TTS-2603), +a ~4B parameter text-to-speech model that produces 24 kHz mono audio from +text. Weights are loaded directly from the HuggingFace safetensors +checkpoint. Supports CPU (portable + XNNPACK) and CUDA backends. With `--streaming`, the CUDA 4w export runs at **RTF 0.31x on RTX 5080 — 3× faster than real-time** with 2.6 s time-to-first-audio. + +## Overview + +The pipeline has two stages: **export** (Python, once) and **inference** +(C++ runner, repeated). Export converts the HuggingFace checkpoint into a +`model.pte` (LM + flow head, 5 methods) plus a `codec_decoder.pte`. At +inference time, the C++ runner loads both `.pte` files, the tokenizer, and a +voice embedding, then synthesizes a `.wav` file. + +The model has three components: +1. **Mistral 4B LLM decoder** — autoregressive text-to-hidden-states +2. **Flow Matching Head** (3-layer transformer) — hidden states to 37 + audio codebook tokens per frame via a 7-step Euler ODE +3. **Codec Decoder** (Conv1d / ConvTranspose1d stack + 8 transformer + layers) — codebook tokens to 24 kHz waveform + +## Prerequisites + +- ExecuTorch installed from source (see [building from source](../../../docs/source/using-executorch-building-from-source.md)) +- Model weights downloaded from HuggingFace. The directory should contain + `params.json`, `consolidated.safetensors`, `tekken.json`, and + `voice_embedding/` with one or more `.pt` voice files. + ```bash + huggingface-cli download mistralai/Voxtral-4B-TTS-2603 \ + --local-dir ~/models/Voxtral-4B-TTS-2603 + ``` +- For CUDA: NVIDIA GPU with CUDA 12.8 or 12.9 toolkit (tested on A100 80GB + / sm_80 and RTX 5080 / sm_120). + Note: CUDA 13 is not supported (CUB 3.0 incompatibility in + `backends/cuda/runtime/shims/sort.cu`). + +## Export + +Export produces `model.pte` and `codec_decoder.pte`. For CUDA, it also +produces `aoti_cuda_blob.ptd` and `codec_aoti_cuda_blob.ptd` containing +the compiled CUDA kernels and weights. 
+
+```bash
+# CPU (XNNPACK, FP32)
+python export_voxtral_tts.py \
+  --model-path ~/models/Voxtral-4B-TTS-2603 \
+  --backend xnnpack \
+  --output-dir ./voxtral_tts_exports
+
+# CUDA, 4-bit weight-only quant (recommended — sub-real-time on A100)
+python export_voxtral_tts.py \
+  --model-path ~/models/Voxtral-4B-TTS-2603 \
+  --backend cuda \
+  --qlinear 4w \
+  --output-dir ./voxtral_tts_exports_cuda_4w
+```
+
+`--dtype` is auto-promoted to `bf16` and `--qlinear-packing-format` is
+auto-set to `tile_packed_to_4d` when `--backend cuda --qlinear 4w` is
+selected (required by AOTI's `_weight_int4pack_mm` kernel).
+
+### Options
+
+| Flag | Default | Description |
+|------|---------|-------------|
+| `--model-path` | (required) | Local directory with `params.json` + `consolidated.safetensors` |
+| `--backend` | `xnnpack` | `portable`, `xnnpack`, `cuda`, `cuda-windows` |
+| `--dtype` | `fp32` | Model dtype: `fp32` or `bf16` (auto-promoted to bf16 when CUDA + `--qlinear`) |
+| `--output-dir` | `./voxtral_tts_exports` | Output directory |
+| `--max-seq-len` | `4096` | KV cache length |
+| `--qlinear` | (none) | Linear layer quantization: `4w`, `8w`, `8da4w`, `8da8w` |
+| `--qlinear-group-size` | `32` | Group size for linear quantization |
+| `--qlinear-packing-format` | (auto) | `tile_packed_to_4d` (auto-set for CUDA + 4w) |
+| `--decoder-qlinear-scope` | `all` | Scope decoder quant to `all`, `attention`, `feed_forward`, or `none` |
+| `--qlinear-codec` | (none) | Quantize codec decoder linears: `4w`, `8w` |
+| `--qembedding` | (none) | Embedding quantization: `4w`, `8w` (XNNPACK: not yet supported) |
+| `--streaming` | off | Enable streaming codec chunking metadata |
+
+### CUDA quantization configs
+
+Validated on A100, `seed=42`, `"Hello, how are you today?"`:
+
+| Config | model.ptd | LM time | Total wall | E2E RTF | Notes |
+|---|---|---|---|---|---|
+| `--backend cuda` | 15.8 GB | 11.5 s | 178 s | 51x | FP32 weights, codec on portable CPU |
+| **`--backend cuda --qlinear 4w`** | **3.4 GB** | **2.1 s** | **3.7 s** | **0.88x** ⚡ | int4 weights, codec on CUDA |
+
+### Streaming
+
+`--streaming` emits codec chunks as they are decoded rather than batching the
+full audio at the end. The first chunk covers only ~0.4 s of audio, which
+keeps the prefill delay short; 2 s chunks then follow continuously. This
+decouples time-to-first-audio from total synthesis length and enables live
+piped playback.
+
+Measured on RTX 5080 (sm_120, warm Triton autotune cache):
+
+| Prompt | Audio | Wall clock | **RTF** | Time-to-first-audio |
+|---|---|---|---|---|
+| 24 tokens | 10.3 s | 3.85 s | **0.31x** ⚡⚡ (~3.2× real-time) | ~2.6 s |
+
+Live playback (pipe raw f32le PCM to `ffplay` or `aplay`):
+
+```bash
+unset CPATH
+export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
+
+cmake-out/examples/models/voxtral_tts/voxtral_tts_runner \
+  --model voxtral_tts_exports_cuda_4w/model.pte \
+  --data_path voxtral_tts_exports_cuda_4w/aoti_cuda_blob.ptd \
+  --codec voxtral_tts_exports_cuda_4w/codec_decoder.pte \
+  --codec_data_path voxtral_tts_exports_cuda_4w/codec_aoti_cuda_blob.ptd \
+  --tokenizer ~/models/Voxtral-4B-TTS-2603/tekken.json \
+  --voice ~/models/Voxtral-4B-TTS-2603/voice_embedding/neutral_female.pt \
+  --text "Hello, how are you today?" \
+  --streaming --speaker \
+  | ffplay -f f32le -ar 24000 -ac 1 -nodisp -autoexit -
+```
+
+Or `aplay`: replace `| ffplay ...` with `| aplay -f FLOAT_LE -r 24000 -c 1`.
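+
+If you prefer to consume the `--speaker` stream from Python rather than
+piping into `ffplay`/`aplay`, a minimal stdlib-only sketch follows. It
+assumes the paths from the command above (adjust them to your export and
+model directories) and converts the raw f32le samples to 16-bit PCM while
+writing the WAV incrementally:
+
+```python
+import array
+import subprocess
+import wave
+
+cmd = [
+    "cmake-out/examples/models/voxtral_tts/voxtral_tts_runner",
+    "--model", "voxtral_tts_exports_cuda_4w/model.pte",
+    "--data_path", "voxtral_tts_exports_cuda_4w/aoti_cuda_blob.ptd",
+    "--codec", "voxtral_tts_exports_cuda_4w/codec_decoder.pte",
+    "--codec_data_path", "voxtral_tts_exports_cuda_4w/codec_aoti_cuda_blob.ptd",
+    "--tokenizer", "tekken.json",    # from the model weights directory
+    "--voice", "neutral_female.pt",  # from the model's voice_embedding/
+    "--text", "Hello, how are you today?",
+    "--streaming", "--speaker",
+]
+proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
+buf = b""
+with wave.open("streamed.wav", "wb") as wav:
+    wav.setnchannels(1)   # mono
+    wav.setsampwidth(2)   # 16-bit PCM output
+    wav.setframerate(24000)
+    while data := proc.stdout.read(4 * 24000):  # ~1 s of f32 samples
+        buf += data
+        n = len(buf) - len(buf) % 4       # a pipe read can split a float
+        f32 = array.array("f", buf[:n])   # f32le on little-endian hosts
+        buf = buf[n:]
+        s16 = array.array(
+            "h", (int(max(-1.0, min(1.0, s)) * 32767) for s in f32)
+        )
+        wav.writeframes(s16.tobytes())
+proc.wait()
+```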
+ +### XNNPACK quantization configs + +| Config | Scope | model.pte | RTF (long prompt) | +|---|---|---|---| +| `--qlinear 8da4w --decoder-qlinear-scope feed_forward` | FFN only | 7.0 GB | 2.6x | +| `--qlinear 8da8w` | all decoder | 5.7 GB | 1.9x | +| `--qlinear 8da4w` | all decoder | 4.3 GB | 2.0x | + +## Build + +ExecuTorch must be installed from source first (see +[Prerequisites](#prerequisites)). The `make` target handles building the +core libraries and the runner binary. + +```bash +# CUDA (recommended) +make voxtral_tts-cuda + +# CPU +make voxtral_tts-cpu +``` + +This builds ExecuTorch with the requested backend, then the runner binary +at `cmake-out/examples/models/voxtral_tts/voxtral_tts_runner`. + +## Run + +The runner requires: +- `model.pte` — exported LM + flow head (see [Export](#export)) +- `codec_decoder.pte` — exported codec +- `tekken.json` — tokenizer from the model weights directory +- A `.pt` voice embedding from the model's `voice_embedding/` directory + +For CUDA also pass `--data_path` and `--codec_data_path` for the AOTI +delegate `.ptd` files. + +```bash +# CUDA, full pipeline +unset CPATH # see Troubleshooting +export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH + +cmake-out/examples/models/voxtral_tts/voxtral_tts_runner \ + --model voxtral_tts_exports_cuda_4w/model.pte \ + --data_path voxtral_tts_exports_cuda_4w/aoti_cuda_blob.ptd \ + --codec voxtral_tts_exports_cuda_4w/codec_decoder.pte \ + --codec_data_path voxtral_tts_exports_cuda_4w/codec_aoti_cuda_blob.ptd \ + --tokenizer ~/models/Voxtral-4B-TTS-2603/tekken.json \ + --voice ~/models/Voxtral-4B-TTS-2603/voice_embedding/neutral_female.pt \ + --text "Hello, how are you today?" \ + --output output.wav \ + --seed 42 \ + --max_new_tokens 200 +``` + +Output is **24 kHz mono 16-bit PCM**. Listen with `ffplay output.wav` or +`aplay output.wav`. + +Or use the one-shot script that does export + build + run end to end: + +```bash +bash examples/models/voxtral_tts/run_cuda_e2e.sh ~/models/Voxtral-4B-TTS-2603 +``` + +### Options + +| Flag | Default | Description | +|------|---------|-------------| +| `--model` | `model.pte` | Path to exported `model.pte` (LM + flow head) | +| `--data_path` | (none) | Path to LM `.ptd` (required for CUDA) | +| `--codec` | `codec_decoder.pte` | Path to exported codec `.pte` | +| `--codec_data_path` | (none) | Path to codec `.ptd` (required for CUDA codec export) | +| `--tokenizer` | `tekken.json` | Path to tokenizer JSON from the base model | +| `--voice` | (required) | Path to voice embedding `.pt` | +| `--text` | (required) | Prompt text to synthesize | +| `--output` | `output.wav` | Output WAV file path | +| `--seed` | `42` | RNG seed (semantic sampling + flow noise) | +| `--temperature` | `0.0` | Sampling temperature (0 = greedy) | +| `--max_new_tokens` | `2048` | Max audio frames to generate (~12.5 frames/sec) | +| `--streaming` | off | Chunked codec emission for lower per-chunk latency | +| `--speaker` | off | Pipe raw f32le PCM to stdout (e.g. `... --speaker \| aplay -f FLOAT_LE -r 24000 -c 1`) | + +### Available voices + +`neutral_female`, `neutral_male`, `casual_female`, `casual_male`, +`cheerful_female`, `ar_male`, `de_female`, `de_male`, `es_female`, +`es_male`, `fr_female`, `fr_male` (under `voice_embedding/` in the model +directory). + +## License + +The ExecuTorch example code in this directory is licensed under the +BSD-style license found in the repository root. 
+ +The [Voxtral-4B-TTS-2603](https://huggingface.co/mistralai/Voxtral-4B-TTS-2603) +model weights and bundled voice embeddings are licensed by Mistral AI under +[CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) +(Creative Commons Attribution-NonCommercial 4.0 International). This means: + +- **Attribution required**: credit Mistral AI and the voice data sources + (EARS, CML-TTS, IndicVoices-R, Arabic Natural Audio datasets). +- **NonCommercial only**: the model weights may not be used for commercial + purposes. + +The NC restriction originates from the voice reference datasets used to train +the model. See the +[model card](https://huggingface.co/mistralai/Voxtral-4B-TTS-2603) for +full details. + +## Troubleshooting + +- **`__cudaLaunch was not declared` during build**: `CPATH` is polluted with + CUDA 13's include path. `unset CPATH` and rebuild. CUDA 13's + `crt/host_runtime.h` has a 2-arg `__cudaLaunch` macro incompatible with + nvcc 12.8's stub generation. +- **`GLIBCXX_3.4.30 not found` at runner load time**: AOTI `.so` files + require a newer libstdc++ than `/lib64/libstdc++.so.6`. Set + `LD_LIBRARY_PATH=$CONDA_PREFIX/lib` before launching the runner. +- **`aoti_cuda_backend` target not found at link time**: the parent + ExecuTorch was built without CUDA. Use `make voxtral_tts-cuda` (which + builds with `EXECUTORCH_BUILD_CUDA=ON`) instead of running cmake by hand. +- **`cannot find -lcuda` during `pip install -e .` or export (WSL2)**: the + CUDA toolkit doesn't ship `libcuda.so` — on WSL2 the driver lib lives at + `/usr/lib/wsl/lib/`. Prepend it (or `/usr/local/cuda/lib64/stubs`) to + `LIBRARY_PATH` before invoking pip / the export script. +- **First call takes ~30–50 s**: Triton autotunes the LM matmul kernels on + first run, then caches per-process. The runner's `warmup()` amortizes + this so the first user-visible synth pays the cost once. +- **`pip install -e .` after pulling source changes**: the default + `install_executorch.sh` does `pip install .`. Repo edits to + `examples/models/voxtral_tts/` won't take effect until you reinstall as + editable. + +## Pre-exported artifacts + +For users who want to skip the export step, ready-to-run CUDA artifacts +are available on the HuggingFace hub at +[`younghan-meta/Voxtral-4B-TTS-2603-ExecuTorch-CUDA`](https://huggingface.co/younghan-meta/Voxtral-4B-TTS-2603-ExecuTorch-CUDA). + +ExecuTorch's CUDA backend uses AOTInductor, which bakes pre-compiled +cubins for the export-time GPU's compute capability into `*.ptd`. Cubins +are not compatible across architectures, so the repo ships per-arch +subfolders: + +| Folder | Compute capability | Example GPUs | +|---|---|---| +| `sm80/` | `sm_80` (Ampere) | A100, A30 | +| `sm120/` | `sm_120` (Blackwell) | RTX 5080, RTX 5090 | + +Find your GPU's arch with `nvidia-smi --query-gpu=compute_cap --format=csv`, +then `hf download ... --include 'sm80/*'` (or `sm120`). If your arch isn't +shipped, re-export on the target GPU with the command above — the AOTI +compile step writes cubins for the local arch. diff --git a/examples/models/voxtral_tts/__init__.py b/examples/models/voxtral_tts/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/examples/models/voxtral_tts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/examples/models/voxtral_tts/export_voxtral_tts.py b/examples/models/voxtral_tts/export_voxtral_tts.py new file mode 100644 index 00000000000..7ebc89bace1 --- /dev/null +++ b/examples/models/voxtral_tts/export_voxtral_tts.py @@ -0,0 +1,720 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Export Voxtral-4B-TTS-2603 to ExecuTorch. + +Produces two .pte files: + model.pte (multi-method, like voxtral_realtime): + - token_embedding: token_ids (1, S) -> embeds (1, S, 3072) + - audio_token_embedding: codes (1, 37, S) -> embeds (1, S, 3072) + - text_decoder: embeds (1, S, 3072) + cache_pos (S,) -> hidden (1, S, 3072) + - semantic_head: hidden (1, 3072) -> code (1,) + - predict_velocity: x_t (1, 36) + t_idx (1,) + hidden (1, 3072) -> v_t (1, 36) + + codec_decoder.pte (single method): + - forward: codes (1, 37, T) -> waveform (1, 1, T*1920) + +Usage: + python export_voxtral_tts.py --model-path ~/models/Voxtral-4B-TTS-2603 --qlinear 4w + python export_voxtral_tts.py --model-path ~/models/Voxtral-4B-TTS-2603 --qlinear 4w --qembedding 4w +""" + +import argparse +import os +from pathlib import Path + +import torch +import torch.nn as nn +from executorch.examples.models.voxtral_tts.model import load_model +from executorch.examples.models.voxtral_tts.voice import load_voice_from_model_dir +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass +from executorch.extension.llm.export.quantize import quantize_model_ +from torch.export import Dim, export + +# --------------------------------------------------------------------------- +# Export wrappers +# --------------------------------------------------------------------------- + + +class TokenEmbeddingExport(nn.Module): + def __init__(self, model): + super().__init__() + self.tok_embeddings = model.decoder.tok_embeddings + + def forward(self, token_ids: torch.Tensor) -> torch.Tensor: + return self.tok_embeddings(token_ids) + + +class AudioTokenEmbeddingExport(nn.Module): + def __init__(self, model): + super().__init__() + self.audio_token_embedding = model.audio_token_embedding + + def forward(self, codes: torch.Tensor) -> torch.Tensor: + return self.audio_token_embedding(codes) + + +class TextDecoderExport(nn.Module): + def __init__(self, model): + super().__init__() + self.decoder = model.decoder + + def forward( + self, input_embeds: torch.Tensor, cache_position: torch.Tensor + ) -> torch.Tensor: + return self.decoder(input_embeds, cache_position) + + +class SemanticHeadExport(nn.Module): + def __init__(self, model): + super().__init__() + self.flow_head = model.flow_head + + def forward(self, hidden: torch.Tensor) -> torch.Tensor: + return self.flow_head.semantic_logits(hidden) + + +class PredictVelocityExport(nn.Module): + def __init__(self, model): + super().__init__() + self.flow_head = model.flow_head + + def forward( + self, + x_t: torch.Tensor, + t_idx: torch.Tensor, + hidden: torch.Tensor, + ) -> torch.Tensor: + return self.flow_head.predict_velocity(x_t, t_idx, hidden) + + +class CodecDecoderExport(nn.Module): + def __init__(self, model): + super().__init__() + self.codec_decoder = model.codec_decoder + + def forward(self, codes: torch.Tensor) -> torch.Tensor: + return 
self.codec_decoder(codes) + + +# --------------------------------------------------------------------------- +# Quantization policy +# --------------------------------------------------------------------------- + + +def _export_lm_pte(model, args, device: str) -> None: + """Export model.pte (LM + flow head, 5 methods).""" + print("\n" + "=" * 60) + print("Exporting model.pte (5 methods)") + print("=" * 60) + programs, metadata = export_model( + model, + args.max_seq_len, + streaming=args.streaming, + device=device, + ) + et_model = lower_to_executorch(programs, metadata, backend=args.backend) + + model_pte = os.path.join(args.output_dir, "model.pte") + print(f"\nSaving to {model_pte}...") + with open(model_pte, "wb") as f: + et_model.write_to_file(f) + size_mb = os.path.getsize(model_pte) / (1024 * 1024) + print(f"Saved model.pte ({size_mb:.1f} MB)") + + # CUDA backend emits a .ptd containing the AOTI .so + weights. + if et_model._tensor_data: + et_model.write_tensor_data_to_file(args.output_dir) + print(f"Saved model tensor data to {args.output_dir}/") + + +def _export_codec_pte(model, args, device: str) -> None: + """Export codec_decoder.pte (single forward method). + + Codec convs are expressed as unfold + matmul / matmul + Fold + (model.py:_conv1d_as_matmul / _conv_transpose1d_as_matmul) so AOTI's CUDA + backend can lower them via Triton mm kernels. CodecAttention uses an + additive ALiBi mask which is fine for ATen SDPA when triton_kernel_mode=OFF. + """ + print("\n" + "=" * 60) + print("Exporting codec_decoder.pte") + print("=" * 60) + codec_programs, codec_metadata = export_codec_decoder( + model, + max_codec_frames=args.max_codec_frames, + qlinear_codec=args.qlinear_codec, + qlinear_codec_group_size=args.qlinear_codec_group_size, + device=device, + ) + codec_triton_mode = "OFF" if args.backend in ("cuda", "cuda-windows") else "ON" + et_codec = lower_to_executorch( + codec_programs, + codec_metadata, + backend=args.backend, + triton_kernel_mode=codec_triton_mode, + ) + + codec_pte = os.path.join(args.output_dir, "codec_decoder.pte") + print(f"\nSaving to {codec_pte}...") + with open(codec_pte, "wb") as f: + et_codec.write_to_file(f) + size_mb = os.path.getsize(codec_pte) / (1024 * 1024) + print(f"Saved codec_decoder.pte ({size_mb:.1f} MB)") + + if et_codec._tensor_data: + # Rename the codec's data blob so it doesn't collide with the LM's + # `aoti_cuda_blob.ptd` (both default to the same filename). + renamed = {} + for k, v in et_codec._tensor_data.items(): + new_key = "codec_" + k if ("aoti_cuda" in k or k.startswith("model")) else k + renamed[new_key] = v + et_codec._tensor_data = renamed + et_codec.write_tensor_data_to_file(args.output_dir) + print(f"Saved codec tensor data to {args.output_dir}/") + + +def _apply_cuda_arg_defaults(parser, args, backend_for_export: str) -> None: + """Auto-set CUDA-specific defaults: tile_packed_to_4d packing + bf16 dtype. + + Both are required for the AOTI _weight_int4pack_mm kernel path. Promoted + automatically (with a print) so users don't have to remember the rule; + explicit incompatible values are rejected via parser.error(). + """ + if backend_for_export == "cuda" and args.qlinear == "4w": + if args.qlinear_packing_format is None: + args.qlinear_packing_format = "tile_packed_to_4d" + print( + "Auto-selected --qlinear-packing-format=tile_packed_to_4d " + "(required by _weight_int4pack_mm on CUDA)." 
+ ) + elif args.qlinear_packing_format != "tile_packed_to_4d": + parser.error( + "--qlinear=4w on CUDA requires " + "--qlinear-packing-format=tile_packed_to_4d" + ) + + if backend_for_export == "cuda" and args.qlinear and args.dtype == "fp32": + print( + f"Auto-promoting --dtype to bf16 (CUDA --qlinear={args.qlinear} " + "needs bf16 weights for the int-pack kernels)." + ) + args.dtype = "bf16" + + +def resolve_effective_quantization( + *, + backend: str, + qlinear: str | None, + qembedding: str | None, +) -> dict[str, str | None]: + warning = None + effective_qembedding = qembedding + if backend == "xnnpack" and qembedding is not None: + warning = ( + "XNNPACK runtime does not register quantized embedding kernels yet; " + "disabling embedding quantization for this export." + ) + effective_qembedding = None + return { + "qlinear": qlinear, + "qembedding": effective_qembedding, + "warning": warning, + } + + +# --------------------------------------------------------------------------- +# Export functions +# --------------------------------------------------------------------------- + + +def export_model( + model, + max_seq_len, + streaming=False, + device="cpu", +): + """Export LLM + acoustic head as a single multi-method model.pte. + + Quantization must be applied to the model BEFORE calling this function. + """ + programs = {} + param_dtype = next(model.parameters()).dtype + config = model.config + + # 1. Text decoder + print("\nExporting text_decoder...") + text_decoder = TextDecoderExport(model) + text_decoder.eval() + seq_dim = Dim("seq_len", min=1, max=max_seq_len) + sample_embeds = torch.randn(1, 4, config.dim, dtype=param_dtype, device=device) + sample_pos = torch.arange(4, dtype=torch.long, device=device) + programs["text_decoder"] = export( + text_decoder, + (sample_embeds, sample_pos), + dynamic_shapes={ + "input_embeds": {1: seq_dim}, + "cache_position": {0: seq_dim}, + }, + strict=True, + ) + print(f" text_decoder exported (sample: {sample_embeds.shape})") + + # 2. Token embedding + print("\nExporting token_embedding...") + tok_emb = TokenEmbeddingExport(model) + tok_emb.eval() + tok_seq_dim = Dim("tok_seq_len", min=1, max=max_seq_len) + sample_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device=device) + programs["token_embedding"] = export( + tok_emb, + (sample_ids,), + dynamic_shapes={"token_ids": {1: tok_seq_dim}}, + strict=True, + ) + print(f" token_embedding exported (sample: {sample_ids.shape})") + + # 3. Audio token embedding + print("\nExporting audio_token_embedding...") + audio_tok_emb = AudioTokenEmbeddingExport(model) + audio_tok_emb.eval() + sample_audio_codes = torch.zeros( + 1, config.n_codebooks, 1, dtype=torch.long, device=device + ) + programs["audio_token_embedding"] = export( + audio_tok_emb, + (sample_audio_codes,), + strict=True, + ) + print(" audio_token_embedding exported " f"(sample: {sample_audio_codes.shape})") + + # 4. Semantic head + print("\nExporting semantic_head...") + sem_head = SemanticHeadExport(model) + sem_head.eval() + sample_hidden = torch.randn(1, config.dim, dtype=param_dtype, device=device) + programs["semantic_head"] = export( + sem_head, + (sample_hidden,), + strict=True, + ) + print(f" semantic_head exported (sample: {sample_hidden.shape})") + + # 5. 
Predict velocity + print("\nExporting predict_velocity...") + vel_pred = PredictVelocityExport(model) + vel_pred.eval() + sample_xt = torch.randn(1, config.acoustic_dim, dtype=param_dtype, device=device) + sample_tidx = torch.tensor([0], dtype=torch.long, device=device) + sample_hv = torch.randn(1, config.dim, dtype=param_dtype, device=device) + programs["predict_velocity"] = export( + vel_pred, + (sample_xt, sample_tidx, sample_hv), + strict=True, + ) + print(" predict_velocity exported") + + # Tells the runner whether to stage fp32 buffers as bf16 before each + # AOTI execute (1 = bf16 model, 0 = fp32 model). bf16 happens for + # quantized exports (--qlinear); fp32 is the default mixed-precision path. + lm_input_is_bf16 = 1 if param_dtype == torch.bfloat16 else 0 + + # Determine the default voice embedding length from the real voice asset + # instead of baking in casual_male-specific metadata. + voice_embed_len = 0 + model_dir = Path(model.config_path) if hasattr(model, "config_path") else None + if model_dir: + try: + v, _ = load_voice_from_model_dir( + model_dir, + None, + dim=config.dim, + ) + voice_embed_len = v.shape[0] + except Exception: + voice_embed_len = 0 + + metadata = { + "sample_rate": config.sampling_rate, + "n_decoding_steps": config.n_decoding_steps, + "cfg_alpha_x100": int(config.cfg_alpha * 100), + "n_acoustic_codebook": config.acoustic_dim, + "semantic_codebook_size": config.semantic_codebook_size, + "acoustic_levels": config.acoustic_levels, + "vocab_size": config.vocab_size, + "max_seq_len": max_seq_len, + "dim": config.dim, + "downsample_factor": config.downsample_factor, + "n_codebooks": config.n_codebooks, + "end_audio_code": 1, + "empty_audio_code": 0, + "n_special_tokens": 2, + "streaming": 1 if streaming else 0, + "streaming_chunk_frames": 25, + "streaming_initial_chunk": 5, + "streaming_left_context": 25, + "audio_token_id": config.audio_token_id, + "begin_audio_token_id": config.begin_audio_token_id, + "text_to_audio_token_id": config.text_to_audio_token_id, + "repeat_audio_text_token_id": config.repeat_audio_text_token_id, + "voice_embed_len": voice_embed_len, + "lm_input_is_bf16": lm_input_is_bf16, + } + + return programs, metadata + + +def export_codec_decoder( + model, + max_codec_frames=256, + qlinear_codec=None, + qlinear_codec_group_size=None, + device="cpu", +): + """Export codec decoder as a separate .pte.""" + from executorch.extension.llm.export.quantize import quantize_model_ + + config = model.config + + print("\nExporting codec_decoder...") + codec_dec = CodecDecoderExport(model) + codec_dec.eval() + + if qlinear_codec: + print(f" Quantizing codec ({qlinear_codec})...") + quantize_model_( + codec_dec, + qlinear_config=qlinear_codec, + qlinear_group_size=qlinear_codec_group_size, + ) + + sample_codes = torch.zeros( + 1, config.n_codebooks, max_codec_frames, dtype=torch.long, device=device + ) + # Static export: the codec's transformer/conv stages introduce tight + # divisibility constraints under dynamic_shapes (upsample stride/kernel + # math). Keeping the input static at max_codec_frames avoids those + # constraint violations. The runner pads to max_codec_frames, but the + # codec's transformer is only locally bidirectional (window<=16) so the + # ALiBi-windowed attention contaminates a small boundary region; choose + # max_codec_frames close to the expected per-utterance frame count to + # minimize how many trailing zero codes the model attends to. 
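+    # For scale: at the defaults, 256 frames * 1920 samples/frame at 24 kHz
+    # comes to ~20.5 s of audio per codec invocation.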
+ programs = {"forward": export(codec_dec, (sample_codes,), strict=True)} + print( + f" codec_decoder exported (codes: {sample_codes.shape}, " + f"waveform: {max_codec_frames * config.downsample_factor} samples)" + ) + + metadata = { + "max_codec_frames": max_codec_frames, + "downsample_factor": config.downsample_factor, + "n_codebooks": config.n_codebooks, + "sample_rate": config.sampling_rate, + "codec_supports_exact_frames": 0, + } + + return programs, metadata + + +def apply_model_quantization( + model, + *, + qlinear: str | None, + qlinear_group_size: int | None, + qlinear_packing_format: str | None, + qembedding: str | None, + qembedding_group_size: int | None, + decoder_qlinear_scope: str = "all", +) -> None: + if qlinear: + qlinear_kwargs = { + "qlinear_config": qlinear, + "qlinear_group_size": qlinear_group_size, + "qlinear_packing_format": qlinear_packing_format, + } + if decoder_qlinear_scope == "all": + quantize_model_(model.decoder, **qlinear_kwargs) + elif decoder_qlinear_scope == "attention": + for layer in model.decoder.layers: + quantize_model_(layer.attention, **qlinear_kwargs) + elif decoder_qlinear_scope == "feed_forward": + for layer in model.decoder.layers: + quantize_model_(layer.feed_forward, **qlinear_kwargs) + elif decoder_qlinear_scope != "none": + raise ValueError( + f"Unsupported decoder_qlinear_scope: {decoder_qlinear_scope}" + ) + quantize_model_( + model.flow_head, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + qlinear_packing_format=qlinear_packing_format, + skip_incompatible_shapes=True, + ) + + if qembedding: + tok_emb_wrapper = TokenEmbeddingExport(model) + quantize_model_( + tok_emb_wrapper, + qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, + ) + audio_tok_emb_wrapper = AudioTokenEmbeddingExport(model) + quantize_model_( + audio_tok_emb_wrapper, + qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, + ) + + +def lower_to_executorch(programs, metadata, backend="xnnpack", triton_kernel_mode="ON"): + """Lower exported programs to ExecuTorch. + + Args: + triton_kernel_mode: For CUDA backend only. "ON" replaces ATen SDPA with + Triton sdpa kernel (required for the LM decoder). "OFF" disables + replacement so the codec's additive ALiBi mask SDPA can lower + (Triton sdpa kernel only accepts bool masks). + """ + mutable_buffer_passes = [InitializedMutableBufferPass(["k_cache", "v_cache"])] + if backend == "xnnpack": + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackDynamicallyQuantizedPartitioner, + XnnpackPartitioner, + ) + + print(f"\nLowering to ExecuTorch with XNNPACK ({len(programs)} methods)...") + partitioner = { + key: [XnnpackDynamicallyQuantizedPartitioner(), XnnpackPartitioner()] + for key in programs + } + elif backend in ("cuda", "cuda-windows"): + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir.backend.compile_spec_schema import CompileSpec + + print( + f"\nLowering to ExecuTorch with CUDA " + f"{'(Windows) ' if backend == 'cuda-windows' else ''}" + f"({len(programs)} methods, triton_kernel_mode={triton_kernel_mode})..." + ) + # NB: conv1d_to_conv2d is applied inside CudaBackend.preprocess via its + # decomposition_table. Doing it here too triggers an extra run_decompositions + # pass that leaks unbacked symbols on the 26-layer Mistral text_decoder. 
+ + partitioner = {} + for key in programs: + compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] + compile_specs.append( + CompileSpec("triton_kernel_mode", triton_kernel_mode.encode("utf-8")) + ) + if backend == "cuda-windows": + compile_specs.append(CompileSpec("platform", b"windows")) + partitioner[key] = [CudaPartitioner(compile_specs)] + else: + print(f"\nLowering to ExecuTorch (portable, {len(programs)} methods)...") + partitioner = [] + + et_prog = to_edge_transform_and_lower( + programs, + partitioner=partitioner, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=metadata, + ) + + return et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + passes=mutable_buffer_passes, + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + share_mutable_buffers=True, + ), + emit_mutable_buffer_names=True, + ), + ) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + import sys + + parser = argparse.ArgumentParser(description="Export Voxtral TTS to ExecuTorch") + parser.add_argument( + "--model-path", + required=True, + help="Directory with params.json + consolidated.safetensors", + ) + parser.add_argument( + "--backend", + default="xnnpack", + choices=["portable", "xnnpack", "cuda", "cuda-windows"], + help="Backend (default: xnnpack). cuda/cuda-windows compile via " + "AOTInductor and emit model.pte + model.ptd.", + ) + parser.add_argument( + "--output-dir", + default="./voxtral_tts_exports", + help="Output directory (default: ./voxtral_tts_exports)", + ) + parser.add_argument( + "--export-target", + default="all", + choices=["all", "model", "codec"], + help="Which artifacts to export (default: all).", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=4096, + help="KV cache length (default: 4096)", + ) + parser.add_argument( + "--max-codec-frames", + type=int, + default=256, + help="Max codec frames for decoder (default: 256 = ~20s audio)", + ) + parser.add_argument( + "--qlinear", + default=None, + choices=["4w", "8w", "8da4w", "8da8w"], + help="Quantize ALL linear layers (LLM + acoustic head).", + ) + parser.add_argument( + "--qlinear-group-size", + type=int, + default=None, + help="Group size for linear quantization.", + ) + parser.add_argument( + "--qlinear-packing-format", + default=None, + help="Packing format for 4w quantization.", + ) + parser.add_argument( + "--decoder-qlinear-scope", + default="all", + choices=["all", "attention", "feed_forward", "none"], + help="Limit decoder linear quantization to a specific decoder sub-scope.", + ) + parser.add_argument( + "--qlinear-codec", + default=None, + choices=["4w", "8w"], + help="Quantize codec decoder linear layers.", + ) + parser.add_argument( + "--qlinear-codec-group-size", + type=int, + default=None, + help="Group size for codec linear quantization.", + ) + parser.add_argument( + "--qembedding", + default=None, + choices=["4w", "8w"], + help="Quantize embedding layers.", + ) + parser.add_argument( + "--qembedding-group-size", + type=int, + default=None, + help="Group size for embedding quantization.", + ) + parser.add_argument( + "--streaming", + action="store_true", + help="Enable streaming codec chunking metadata.", + ) + parser.add_argument( + "--dtype", + default="fp32", + choices=["fp32", "bf16"], + help="Model 
dtype (default: fp32).", + ) + args = parser.parse_args() + + backend_for_export = "cuda" if args.backend == "cuda-windows" else args.backend + _apply_cuda_arg_defaults(parser, args, backend_for_export) + + os.makedirs(args.output_dir, exist_ok=True) + model_dtype = {"fp32": torch.float32, "bf16": torch.bfloat16}[args.dtype] + device = "cuda" if backend_for_export == "cuda" else "cpu" + + sys.stdout.reconfigure(line_buffering=True) + + print("Loading model...") + model = load_model( + args.model_path, + max_seq_len=args.max_seq_len, + dtype=model_dtype, + backend=backend_for_export, + ) + model.config_path = Path(args.model_path) + + if device == "cuda": + print("Moving model to CUDA...") + model.cuda() + + quant_plan = resolve_effective_quantization( + backend=backend_for_export, + qlinear=args.qlinear, + qembedding=args.qembedding, + ) + effective_qlinear = quant_plan["qlinear"] + effective_qembedding = quant_plan["qembedding"] + if quant_plan["warning"]: + print(f"\nWarning: {quant_plan['warning']}") + + if effective_qlinear or effective_qembedding: + if effective_qlinear: + print( + f"\nQuantizing linear layers ({effective_qlinear}, " + f"decoder scope={args.decoder_qlinear_scope})..." + ) + if effective_qembedding: + print(f"Quantizing embedding ({effective_qembedding})...") + apply_model_quantization( + model, + qlinear=effective_qlinear, + qlinear_group_size=args.qlinear_group_size, + qlinear_packing_format=args.qlinear_packing_format, + qembedding=effective_qembedding, + qembedding_group_size=args.qembedding_group_size, + decoder_qlinear_scope=args.decoder_qlinear_scope, + ) + + if args.export_target in ("all", "model"): + _export_lm_pte(model, args, device) + + if args.export_target in ("all", "codec"): + _export_codec_pte(model, args, device) + + print("\n" + "=" * 60) + print("DONE") + print("=" * 60) + for f in sorted(os.listdir(args.output_dir)): + if f.endswith(".pte"): + s = os.path.getsize(os.path.join(args.output_dir, f)) / (1024 * 1024) + print(f" {f}: {s:.1f} MB") + + +if __name__ == "__main__": + main() diff --git a/examples/models/voxtral_tts/main.cpp b/examples/models/voxtral_tts/main.cpp new file mode 100644 index 00000000000..336f60718cc --- /dev/null +++ b/examples/models/voxtral_tts/main.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Voxtral TTS runner CLI. 
+ *
+ * Usage:
+ *   voxtral_tts_runner --model model.pte --codec codec_decoder.pte \
+ *     --tokenizer tekken.json --text "Hello world" --output output.wav
+ */
+
+#include "voxtral_tts_runner.h"
+
+#include <chrono>
+#include <csignal>
+#include <iostream>
+
+#include <gflags/gflags.h>
+
+DEFINE_string(model, "model.pte", "Path to model.pte (LLM + acoustic head)");
+DEFINE_string(codec, "codec_decoder.pte", "Path to codec_decoder.pte");
+DEFINE_string(tokenizer, "tekken.json", "Path to tokenizer JSON");
+DEFINE_string(
+    data_path,
+    "",
+    "Optional path to model.ptd (CUDA backend; AOTI .so + weights).");
+DEFINE_string(
+    codec_data_path,
+    "",
+    "Optional path to codec_decoder.ptd (CUDA backend).");
+DEFINE_string(text, "", "Text to synthesize");
+DEFINE_string(
+    voice,
+    "",
+    "Voice preset name or path to .pt/.bin voice embedding "
+    "(default: neutral_female).");
+DEFINE_string(output, "output.wav", "Output WAV file path");
+DEFINE_string(
+    trace_json,
+    "",
+    "Optional path to write a structured parity trace JSON.");
+DEFINE_int32(seed, 42, "Random seed for semantic sampling and flow noise");
+DEFINE_double(temperature, 0.0, "Sampling temperature (0 = greedy)");
+DEFINE_int32(max_new_tokens, 2048, "Max audio frames to generate");
+DEFINE_bool(streaming, false, "Use streaming mode with chunked codec decoding");
+DEFINE_bool(
+    speaker,
+    false,
+    "Write raw f32le PCM to stdout for live playback. "
+    "Pipe to: aplay -f FLOAT_LE -r 24000 -c 1, or ffplay -f f32le -ar 24000");
+
+static volatile bool g_interrupted = false;
+static void signal_handler(int) {
+  g_interrupted = true;
+}
+
+int main(int argc, char** argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  if (FLAGS_text.empty()) {
+    std::cerr << "Error: --text is required" << std::endl;
+    return 1;
+  }
+
+  std::signal(SIGINT, signal_handler);
+
+  // When --speaker is active, keep stdout clean for PCM — log to stderr.
+  auto& log = FLAGS_speaker ? std::cerr : std::cout;
+
+  log << "Voxtral TTS" << std::endl;
+  log << "  Model: " << FLAGS_model << std::endl;
+  log << "  Codec: " << FLAGS_codec << std::endl;
+  log << "  Tokenizer: " << FLAGS_tokenizer << std::endl;
+  log << "  Text: \"" << FLAGS_text << "\"" << std::endl;
+  log << "  Output: " << FLAGS_output << std::endl;
+  log << "  Seed: " << FLAGS_seed << std::endl;
+  log << "  Mode: "
+      << (FLAGS_speaker ? "streaming+speaker"
+          : FLAGS_streaming ? "streaming"
+                            : "offline")
+      << std::endl;
+
+  auto load_start = std::chrono::high_resolution_clock::now();
+
+  voxtral_tts::VoxtralTTSRunner runner(
+      FLAGS_model,
+      FLAGS_codec,
+      FLAGS_tokenizer,
+      FLAGS_data_path,
+      FLAGS_codec_data_path);
+  runner.set_trace_output_path(FLAGS_trace_json);
+  runner.set_seed(static_cast<uint64_t>(FLAGS_seed));
+
+  auto load_end = std::chrono::high_resolution_clock::now();
+  auto load_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
+                     load_end - load_start)
+                     .count();
+  log << "Model loaded in " << load_ms << "ms" << std::endl;
+
+  if (FLAGS_streaming || FLAGS_speaker) {
+    auto callback = [&](const float* samples, std::size_t count) {
+      if (FLAGS_speaker) {
+        // Write raw f32le PCM to stdout for live playback.
+        std::cout.write(
+            reinterpret_cast<const char*>(samples), count * sizeof(float));
+        std::cout.flush();
+      }
+      log << "  Chunk: " << count << " samples ("
+          << static_cast<float>(count) / 24000.0f << "s)" << std::endl;
+    };
+    runner.synthesize_streaming(
+        FLAGS_text,
+        FLAGS_voice,
+        FLAGS_output,
+        callback,
+        static_cast<float>(FLAGS_temperature),
+        FLAGS_max_new_tokens);
+  } else {
+    runner.synthesize_offline(
+        FLAGS_text,
+        FLAGS_voice,
+        FLAGS_output,
+        static_cast<float>(FLAGS_temperature),
+        FLAGS_max_new_tokens);
+  }
+
+  return 0;
+}
diff --git a/examples/models/voxtral_tts/model.py b/examples/models/voxtral_tts/model.py
new file mode 100644
index 00000000000..0fb7641a062
--- /dev/null
+++ b/examples/models/voxtral_tts/model.py
@@ -0,0 +1,1609 @@
+# Voxtral-4B-TTS-2603 reference implementation for ExecuTorch.
+# Based on the Mistral model released under the CC-BY-NC-4.0 license.
+# See https://huggingface.co/mistralai/Voxtral-4B-TTS-2603
+
+"""Voxtral-4B-TTS-2603 eager model for ExecuTorch.
+
+Three-component architecture:
+  1. Mistral LLM backbone (~4B params) — autoregressive text-to-hidden-states
+  2. FlowMatchingHead — hidden states to 37 audio codebook tokens per frame
+  3. CodecDecoder — codebook tokens to 24kHz waveform
+
+See the plan document for architecture details.
+"""
+
+import json
+import math
+from dataclasses import dataclass
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from executorch.extension.llm.custom_ops import custom_ops as _custom_ops  # noqa: F401
+
+# ---------------------------------------------------------------------------
+# Config
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class VoxtralTTSConfig:
+    # LLM (Mistral backbone)
+    dim: int = 3072
+    n_layers: int = 26
+    n_heads: int = 32
+    n_kv_heads: int = 8
+    head_dim: int = 128
+    hidden_dim: int = 9216
+    vocab_size: int = 131072
+    rope_theta: float = 1_000_000.0
+    norm_eps: float = 1e-5
+    # Acoustic transformer (flow matching head) — defaults match 4B checkpoint
+    at_dim: int = 3072
+    at_n_layers: int = 3
+    at_n_heads: int = 32
+    at_n_kv_heads: int = 8
+    at_head_dim: int = 128
+    at_hidden_dim: int = 9216
+    at_norm_eps: float = 1e-5
+    at_use_biases: bool = False
+    n_decoding_steps: int = 7
+    cfg_alpha: float = 1.2
+    noise_scale: float = 1.0
+    audio_token_id: int = 24
+    begin_audio_token_id: int = 25
+    text_to_audio_token_id: int = 36
+    repeat_audio_text_token_id: int = 35
+    # Codebooks
+    semantic_codebook_size: int = 8192
+    semantic_dim: int = 256
+    acoustic_levels: int = 21
+    acoustic_dim: int = 36
+    # Codec decoder
+    codec_dim: int = 1024
+    codec_hidden_dim: int = 4096
+    codec_n_heads: int = 8
+    codec_n_kv_heads: int = 8
+    codec_head_dim: int = 128
+    codec_norm_eps: float = 1e-2
+    codec_qk_norm_eps: float = 1e-6
+    codec_sliding_window: int = 16
+    codec_patch_size: int = 240
+    codec_use_biases: bool = False
+    codec_layer_scale: bool = True
+    codec_conv_weight_norm: bool = True
+    codec_causal: bool = True
+    codec_half_attn_window_upon_downsampling: bool = True
+    codec_decoder_transformer_lengths: tuple[int, ...] = (2, 2, 2, 2)
+    codec_decoder_convs_kernels: tuple[int, ...] = (3, 4, 4, 4)
+    codec_decoder_convs_strides: tuple[int, ...]
= (1, 2, 2, 2) + sampling_rate: int = 24000 + # Runtime + max_seq_len: int = 4096 + backend: str = "xnnpack" + + @staticmethod + def from_params_json(path: str) -> "VoxtralTTSConfig": + with open(path) as f: + p = json.load(f) + + mm = p.get("multimodal", {}) + audio_model = mm.get("audio_model_args", {}) + at_args = audio_model.get("acoustic_transformer_args", {}) + tokenizer_args = mm.get("audio_tokenizer_args", {}) + audio_enc = audio_model.get("audio_encoding_args", {}) + + # Parse codebook sizes from comma-separated string or individual fields + if "codebook_sizes" in audio_model: + cb_sizes = [int(c) for c in audio_model["codebook_sizes"].split(",")] + semantic_cb_size = cb_sizes[0] + acoustic_cb_size = cb_sizes[1] if len(cb_sizes) > 1 else 21 + n_acoustic = len(cb_sizes) - 1 + else: + semantic_cb_size = audio_model.get("semantic_codebook_size", 8192) + acoustic_cb_size = audio_model.get("acoustic_codebook_size", 21) + n_acoustic = audio_model.get("n_acoustic_codebook", 36) + + def _str2tuple(s: str) -> tuple[int, ...]: + return tuple(int(x) for x in s.split(",")) + + return VoxtralTTSConfig( + dim=p["dim"], + n_layers=p["n_layers"], + n_heads=p["n_heads"], + n_kv_heads=p["n_kv_heads"], + head_dim=p["head_dim"], + hidden_dim=p["hidden_dim"], + vocab_size=p["vocab_size"], + rope_theta=p["rope_theta"], + norm_eps=p["norm_eps"], + at_dim=at_args.get("dim", 3072), + at_n_layers=at_args.get("n_layers", 3), + at_n_heads=at_args.get("n_heads", 32), + at_n_kv_heads=at_args.get("n_kv_heads", 8), + at_head_dim=at_args.get("head_dim", 128), + at_hidden_dim=at_args.get("hidden_dim", 9216), + at_norm_eps=at_args.get("norm_eps", 1e-5), + at_use_biases=at_args.get("use_biases", False), + n_decoding_steps=at_args.get("n_decoding_steps", 7), + audio_token_id=audio_model.get("audio_token_id", 24), + begin_audio_token_id=audio_model.get("begin_audio_token_id", 25), + text_to_audio_token_id=audio_model.get("text_to_audio_token_id", 36), + semantic_codebook_size=semantic_cb_size, + acoustic_levels=acoustic_cb_size, + acoustic_dim=n_acoustic, + codec_dim=tokenizer_args.get("dim", 1024), + codec_hidden_dim=tokenizer_args.get("hidden_dim", 4096), + codec_n_heads=tokenizer_args.get("n_heads", 8), + codec_n_kv_heads=tokenizer_args.get("n_kv_heads", 8), + codec_head_dim=tokenizer_args.get("head_dim", 128), + codec_norm_eps=tokenizer_args.get("norm_eps", 1e-2), + codec_qk_norm_eps=tokenizer_args.get("qk_norm_eps", 1e-6), + codec_sliding_window=tokenizer_args.get("attn_sliding_window_size", 16), + codec_patch_size=tokenizer_args.get("pretransform_patch_size", 240), + codec_use_biases=tokenizer_args.get("use_biases", False), + codec_layer_scale=tokenizer_args.get("layer_scale", True), + codec_conv_weight_norm=tokenizer_args.get("conv_weight_norm", True), + codec_causal=tokenizer_args.get("causal", True), + codec_half_attn_window_upon_downsampling=tokenizer_args.get( + "half_attn_window_upon_downsampling", True + ), + codec_decoder_transformer_lengths=_str2tuple( + tokenizer_args.get("decoder_transformer_lengths_str", "2,2,2,2") + ), + codec_decoder_convs_kernels=_str2tuple( + tokenizer_args.get("decoder_convs_kernels_str", "3,4,4,4") + ), + codec_decoder_convs_strides=_str2tuple( + tokenizer_args.get("decoder_convs_strides_str", "1,2,2,2") + ), + sampling_rate=audio_enc.get("sampling_rate", 24000), + semantic_dim=tokenizer_args.get("semantic_dim", 256), + ) + + @property + def n_codebooks(self) -> int: + return 1 + self.acoustic_dim # 1 semantic + N acoustic + + @property + def downsample_factor(self) -> int: + 
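+        """Samples per codec frame: patch_size * prod(conv strides) = 240 * 8 = 1920."""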
return self.codec_patch_size * math.prod(self.codec_decoder_convs_strides) + + @property + def frame_rate(self) -> float: + return self.sampling_rate / self.downsample_factor + + +# --------------------------------------------------------------------------- +# Shared building blocks +# --------------------------------------------------------------------------- + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.dim = dim + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.rms_norm(x, (self.dim,), self.weight, self.eps) + + +def precompute_freqs_cis( + head_dim: int, max_len: int, theta: float +) -> tuple[torch.Tensor, torch.Tensor]: + """Pairwise interleaved RoPE matching mistral_inference's complex convention.""" + freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) + t = torch.arange(max_len, dtype=torch.float) + emb = torch.outer(t, freqs) # (max_len, head_dim/2) + cos = emb.cos().repeat_interleave(2, dim=-1) # (max_len, head_dim) + sin = emb.sin().repeat_interleave(2, dim=-1) + return cos, sin + + +def _rotate_interleave(x: torch.Tensor) -> torch.Tensor: + """Pairwise rotation on adjacent pairs: (-x1, x0, -x3, x2, ...).""" + x = x.unflatten(-1, (-1, 2)) + x = torch.stack((-x[..., 1], x[..., 0]), dim=-1) + return x.flatten(-2) + + +def apply_rotary_emb( + q: torch.Tensor, + k: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + fc = freqs_cos.unsqueeze(0).unsqueeze(2) + fs = freqs_sin.unsqueeze(0).unsqueeze(2) + q_float = q.float() + k_float = k.float() + q_out = q_float * fc + _rotate_interleave(q_float) * fs + k_out = k_float * fc + _rotate_interleave(k_float) * fs + return q_out.type_as(q), k_out.type_as(k) + + +# --------------------------------------------------------------------------- +# LLM decoder components +# --------------------------------------------------------------------------- + + +class KVCache(nn.Module): + """KV cache in [B, S, H, D] layout for torch.ops.llama.update_cache.""" + + def __init__(self, max_seq_len: int, n_kv_heads: int, head_dim: int): + super().__init__() + self.max_seq_len = max_seq_len + cache_shape = (1, max_seq_len, n_kv_heads, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape)) + self.register_buffer("v_cache", torch.zeros(cache_shape)) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_seq_len) + torch.ops.llama.update_cache(k_val, self.k_cache, start_pos) + torch.ops.llama.update_cache(v_val, self.v_cache, start_pos) + return self.k_cache, self.v_cache + + +def _build_causal_mask_bool( + input_pos: torch.Tensor, max_seq_len: int, device: torch.device +) -> torch.Tensor: + """Bool causal mask for the CUDA Triton SDPA kernel. + + Shape: [1, 1, T_q, max_seq_len]. Position i in the query attends to cache + positions [0, input_pos[i]] inclusive — every other slot is masked. This + is critical for the CUDA path because StaticKVCache is preallocated to + max_seq_len with zeros at unwritten positions; without the mask, queries + would attend to those zeros and corrupt the hidden state. + + The XNNPACK path uses torch.ops.llama.custom_sdpa, which infers the prefix + length from start_pos internally and doesn't need this mask. 
+ """ + k_pos = torch.arange(max_seq_len, device=device) + mask = input_pos.unsqueeze(1) >= k_pos.unsqueeze(0) + return mask.unsqueeze(0).unsqueeze(0) + + +class StaticKVCache(nn.Module): + """KV cache in [B, H, S, D] layout, stored bf16 for the CUDA Triton SDPA. + + The buffer is bf16 because torch.ops.triton.sdpa only accepts bf16 K/V. + The model's weights and activations stay in their native dtype (fp32 by + default); inputs to update() get cast to bf16 at write time. This isolates + the bf16 requirement to the SDPA path so semantic_head, predict_velocity, + and the LM MLPs keep fp32 precision. + """ + + def __init__(self, max_seq_len: int, n_kv_heads: int, head_dim: int): + super().__init__() + self.max_seq_len = max_seq_len + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + cache_shape = (1, n_kv_heads, max_seq_len, head_dim) + # Cache buffers are always bf16 — required by the Triton SDPA kernel. + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=torch.bfloat16)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=torch.bfloat16)) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + k_val = k_val.to(torch.bfloat16).transpose(1, 2) + v_val = v_val.to(torch.bfloat16).transpose(1, 2) + self.k_cache.index_copy_(2, input_pos, k_val) + self.v_cache.index_copy_(2, input_pos, v_val) + return self.k_cache, self.v_cache + + +class SDPA(nn.Module): + """Scaled dot-product attention using torch.ops.llama.custom_sdpa.""" + + def __init__(self, n_heads: int, head_dim: int): + super().__init__() + self.dim = n_heads * head_dim + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + seqlen: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + input_dtype = q.dtype + q = q.to(dtype=torch.float32) + k = k.to(dtype=torch.float32) + v = v.to(dtype=torch.float32) + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + if mask is not None: + y = torch.ops.llama.custom_sdpa( + q, + k, + v, + start_pos, + mask.to(dtype=torch.float32), + 0, + False, + ) + else: + y = torch.ops.llama.custom_sdpa(q, k, v, start_pos, None, 0, True) + return y.view(bsz, seqlen, self.dim).to(dtype=input_dtype) + + +class StandardSDPA(nn.Module): + """Scaled dot-product attention using ExecuTorch's Triton SDPA kernel. + + Used for the CUDA (AOTI) backend. Q arrives in [B, S, H_q, D] in the + model's native dtype (typically fp32); K/V arrive in [B, H_kv, S, D] bf16 + (StaticKVCache stores bf16 because the Triton SDPA kernel demands it). + + Q is cast to bf16 just before the kernel and the output is cast back to + Q's original dtype, so the rest of the LM stays in fp32. The mask is + a [1, 1, T_q, max_seq_len] bool — without it, queries would attend to + the unwritten (zero) cache slots beyond input_pos and corrupt the hidden + state from frame 0. + + F.scaled_dot_product_attention is avoided because its decomposition leaks + unbacked symbols during AOTI's re-trace pass. + """ + + def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.dim = n_heads * head_dim + self.enable_gqa = n_heads != n_kv_heads + # Trigger triton::sdpa op registration. The kernel is invoked via the + # registered op below so we don't pay an import cost in non-CUDA builds. 
+ import executorch.backends.cuda.triton.kernels # noqa: F401 + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + seqlen: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + in_dtype = q.dtype + q = q.transpose(1, 2).to( + torch.bfloat16 + ) # [B, S, H_q, D] -> [B, H_q, S, D] bf16 + y = torch.ops.triton.sdpa( + q, + k, + v, + mask, + 0.0, + False, + 0.0, + self.enable_gqa, + ) + y = y.to(in_dtype).transpose(1, 2).contiguous() + return y.view(bsz, seqlen, self.dim) + + +class LMAttention(nn.Module): + """GQA with RoPE, KV cache, and SDPA. No biases. + + Backend selection: + - "xnnpack"/"portable" (default): KVCache (BSHD) + custom SDPA op. + - "cuda": StaticKVCache (BHSD) + StandardSDPA (F.scaled_dot_product_attention). + The custom_sdpa op cannot lower through AOTInductor. + """ + + def __init__(self, config: VoxtralTTSConfig): + super().__init__() + self.n_heads = config.n_heads + self.n_kv_heads = config.n_kv_heads + self.head_dim = config.head_dim + self.dim = config.dim + self.backend = config.backend + + self.wq = nn.Linear(config.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, config.dim, bias=False) + + if self.backend == "cuda": + self.kv_cache = StaticKVCache( + config.max_seq_len, self.n_kv_heads, self.head_dim + ) + self.sdpa = StandardSDPA(self.n_heads, self.n_kv_heads, self.head_dim) + else: + self.kv_cache = KVCache(config.max_seq_len, self.n_kv_heads, self.head_dim) + self.sdpa = SDPA(self.n_heads, self.head_dim) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + B, T, _ = x.shape + q = self.wq(x).view(B, T, self.n_heads, self.head_dim) + k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) + v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) + q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + k, v = self.kv_cache.update(input_pos, k, v) + y = self.sdpa(input_pos, q, k, v, B, T, attn_mask) + return self.wo(y) + + +class LMMLP(nn.Module): + """SwiGLU FFN. No biases.""" + + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class MistralDecoderLayer(nn.Module): + """Decoder layer with standard pre-norm (no adaptive RMSNorm for TTS).""" + + def __init__(self, config: VoxtralTTSConfig): + super().__init__() + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + self.attention = LMAttention(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.feed_forward = LMMLP(config.dim, config.hidden_dim) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + x = x + self.attention( + self.attention_norm(x), freqs_cos, freqs_sin, input_pos, attn_mask + ) + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + +class MistralDecoder(nn.Module): + """Mistral LM decoder. 
Returns hidden states (no lm_head projection).""" + + def __init__(self, config: VoxtralTTSConfig): + super().__init__() + self.config = config + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config) for _ in range(config.n_layers)] + ) + self.norm = RMSNorm(config.dim, config.norm_eps) + + freqs_cos, freqs_sin = precompute_freqs_cis( + config.head_dim, config.max_seq_len, config.rope_theta + ) + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + def forward( + self, + input_embeds: torch.Tensor, + input_pos: torch.Tensor, + ) -> torch.Tensor: + freqs_cos = self.freqs_cos[input_pos] + freqs_sin = self.freqs_sin[input_pos] + + # CUDA: must explicitly mask unwritten cache slots — see _build_causal_mask_bool. + # XNNPACK / portable: custom_sdpa handles the prefix internally. + attn_mask = None + if self.config.backend == "cuda": + attn_mask = _build_causal_mask_bool( + input_pos, self.config.max_seq_len, input_embeds.device + ) + + x = input_embeds + for layer in self.layers: + x = layer(x, freqs_cos, freqs_sin, input_pos, attn_mask) + + return self.norm(x) + + +# --------------------------------------------------------------------------- +# Flow matching head (acoustic transformer) +# --------------------------------------------------------------------------- + +# Special token IDs for audio codebooks (0-indexed) +EMPTY_AUDIO_ID = 0 +END_AUDIO_ID = 1 +N_SPECIAL_TOKENS = 2 + + +class AudioTokenEmbedding(nn.Module): + """Embed one semantic+acoustic frame back into the LLM hidden space.""" + + def __init__(self, config: VoxtralTTSConfig): + super().__init__() + self.codebook_sizes = [ + config.semantic_codebook_size + N_SPECIAL_TOKENS, + *[ + config.acoustic_levels + N_SPECIAL_TOKENS + for _ in range(config.acoustic_dim) + ], + ] + total_vocab_size = sum(self.codebook_sizes) + padded_vocab_size = 128 * ((total_vocab_size + 127) // 128) + self.embeddings = nn.Embedding(padded_vocab_size, config.dim) + self.register_buffer("offsets", self.make_offsets(), persistent=False) + + def make_offsets(self) -> torch.Tensor: + offsets = [] + offset = 0 + for size in self.codebook_sizes: + offsets.append(offset) + offset += size + return torch.tensor(offsets, dtype=torch.long) + + def forward(self, codes: torch.Tensor) -> torch.Tensor: + offsets = self.offsets.view(1, -1, 1) + return self.embeddings(codes + offsets).sum(dim=1) + + +class TimeEmbedding(nn.Module): + """Sinusoidal embedding for flow matching timestep.""" + + def __init__(self, dim: int, theta: float = 10000.0): + super().__init__() + inv_freq = torch.exp( + -math.log(theta) * torch.arange(dim // 2).float() / (dim // 2) + ) + self.register_buffer("inv_freq", inv_freq, persistent=True) + + def forward(self, t: torch.Tensor) -> torch.Tensor: + emb = torch.einsum("bi, j -> bj", t, self.inv_freq) + return torch.cat((emb.cos(), emb.sin()), dim=-1) + + +class BidirectionalAttention(nn.Module): + """Full (non-causal) attention with GQA. 
No positional encoding.""" + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + use_biases: bool = False, + ): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.repeats = n_heads // n_kv_heads + + self.wq = nn.Linear(dim, n_heads * head_dim, bias=use_biases) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=use_biases) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=use_biases) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() == 2: + bsz, seqlen = 1, x.shape[0] + else: + bsz, seqlen, _ = x.shape + + q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim) + k = self.wk(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim) + v = self.wv(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim) + + # GQA expansion + if self.repeats > 1: + k = k.unsqueeze(3).expand(-1, -1, -1, self.repeats, -1).flatten(2, 3) + v = v.unsqueeze(3).expand(-1, -1, -1, self.repeats, -1).flatten(2, 3) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Scale and compute attention (bidirectional, no mask) + scale = self.head_dim**-0.5 + attn = (q * scale) @ k.transpose(-2, -1) + attn = attn.softmax(-1) + y = attn @ v + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(y) + + +class AcousticFeedForward(nn.Module): + """SwiGLU FFN for the acoustic transformer.""" + + def __init__(self, dim: int, hidden_dim: int, use_biases: bool = False): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=use_biases) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class AcousticTransformerBlock(nn.Module): + def __init__(self, config: VoxtralTTSConfig, layer_id: int): + super().__init__() + self.attention = BidirectionalAttention( + config.at_dim, + config.at_n_heads, + config.at_n_kv_heads, + config.at_head_dim, + config.at_use_biases, + ) + self.feed_forward = AcousticFeedForward( + config.at_dim, config.at_hidden_dim, config.at_use_biases + ) + self.attention_norm = RMSNorm(config.at_dim, config.at_norm_eps) + self.ffn_norm = RMSNorm(config.at_dim, config.at_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attention(self.attention_norm(x)) + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + +class FlowMatchingHead(nn.Module): + """Generates audio codebook tokens from LLM hidden states via flow matching ODE. + + Per frame: produces 1 semantic code (argmax) + N acoustic codes (7-step Euler ODE). + The predict_velocity method is exported separately for the C++ runner to call + in a loop. 
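+
+    Rough sketch of that loop (illustrative pseudo-code, not the runner's
+    exact implementation; randn/clamp/round are shorthand, and cfg_alpha,
+    noise_scale and timesteps come from the config / registered buffers):
+
+        x = noise_scale * randn(B, acoustic_dim)
+        for i in range(n_decoding_steps):              # 7 Euler steps
+            dt = timesteps[i + 1] - timesteps[i]
+            v_cond = predict_velocity(x, i, llm_hidden)
+            v_uncond = predict_velocity(x, i, zeros_like(llm_hidden))
+            v = cfg_alpha * v_cond + (1 - cfg_alpha) * v_uncond
+            x = x + v * dt
+        codes = round((clamp(x, -1, 1) + 1) / 2 * (acoustic_levels - 1))
+        codes += N_SPECIAL_TOKENS                      # shift past EMPTY/END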
+ """ + + def __init__(self, config: VoxtralTTSConfig): + super().__init__() + self.config = config + + # Projections + self.input_projection = nn.Linear( + config.acoustic_dim, config.at_dim, bias=False + ) + self.time_projection = nn.Linear(config.at_dim, config.at_dim, bias=False) + self.llm_projection = nn.Linear(config.dim, config.at_dim, bias=False) + + # Semantic codebook head + padded_semantic_size = 128 * ( + (config.semantic_codebook_size + N_SPECIAL_TOKENS + 127) // 128 + ) + self.semantic_codebook_output = nn.Linear( + config.at_dim, padded_semantic_size, bias=config.at_use_biases + ) + self.padded_semantic_size = padded_semantic_size + + # Acoustic codebook head (predicts velocity vector) + self.acoustic_codebook_output = nn.Linear( + config.at_dim, config.acoustic_dim, bias=False + ) + + # Transformer layers + self.layers = nn.ModuleDict( + { + str(i): AcousticTransformerBlock(config, i) + for i in range(config.at_n_layers) + } + ) + self.norm = RMSNorm(config.at_dim, config.at_norm_eps) + + # Time embedding + self.time_embedding = TimeEmbedding(config.at_dim) + + # Pre-compute timestep table for export + self.register_buffer( + "_timesteps", + torch.linspace(0, 1, config.n_decoding_steps + 1), + persistent=False, + ) + + def forward_layers(self, h: torch.Tensor) -> torch.Tensor: + for i in range(self.config.at_n_layers): + h = self.layers[str(i)](h) + return h + + def predict_velocity( + self, + x_t: torch.Tensor, + t_idx: torch.Tensor, + llm_hidden: torch.Tensor, + ) -> torch.Tensor: + """Single velocity prediction step for the flow matching ODE. + + Args: + x_t: (B, acoustic_dim) current noisy state + t_idx: (B,) timestep index into self._timesteps + llm_hidden: (B, llm_dim) hidden state from LLM + Returns: + v_t: (B, acoustic_dim) predicted velocity + """ + t = self._timesteps[t_idx].unsqueeze(-1) # (B, 1) + t_emb = self.time_embedding(t).to(llm_hidden.dtype) + t_emb = self.time_projection(t_emb) + llm_proj = self.llm_projection(llm_hidden) + + inp = self.input_projection(x_t.to(llm_hidden.dtype)).unsqueeze(1) + t_tok = t_emb.unsqueeze(1) + ctx_tok = llm_proj.unsqueeze(1) + h = torch.cat([inp, t_tok, ctx_tok], dim=1) # (B, 3, at_dim) + + h = self.forward_layers(h) + h = self.norm(h) + return self.acoustic_codebook_output(h[:, 0, :]) + + def semantic_head(self, llm_hidden: torch.Tensor) -> torch.Tensor: + """Predict semantic codebook token (greedy argmax). + + Args: + llm_hidden: (B, llm_dim) hidden state from LLM + Returns: + code: (B,) semantic codebook index + """ + logit = self.semantic_logits(llm_hidden) + return logit.argmax(dim=-1) + + def semantic_logits(self, llm_hidden: torch.Tensor) -> torch.Tensor: + """Raw masked logits for semantic code prediction.""" + logit = self.semantic_codebook_output(llm_hidden).float() + logit[:, EMPTY_AUDIO_ID] = float("-inf") + logit[:, (N_SPECIAL_TOKENS + self.config.semantic_codebook_size) :] = float( + "-inf" + ) + return logit + + def forward(self, llm_hidden: torch.Tensor) -> torch.Tensor: + """Full forward: semantic code + flow matching ODE -> all codes. + + Used for eager validation. The C++ runner calls predict_velocity + and semantic_head separately. 
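+
+        With the default config this is a 7-step Euler integration over a
+        (B, 36) state, followed by rounding to 21 FSQ levels per dimension
+        and the +2 special-token shift.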
+ + Args: + llm_hidden: (B, llm_dim) + Returns: + codes: (B, 1 + acoustic_dim) = (B, 37) per frame + """ + B = llm_hidden.shape[0] + semantic_code = self.semantic_head(llm_hidden).unsqueeze(1) # (B, 1) + + should_decode = semantic_code.squeeze(1) != END_AUDIO_ID + + # Flow matching ODE + x = torch.randn(B, self.config.acoustic_dim, device=llm_hidden.device) + x = x.to(llm_hidden.dtype) * self.config.noise_scale + + timesteps = self._timesteps.to(llm_hidden.dtype) + llm_zero = torch.zeros_like(llm_hidden) + + for i in range(len(timesteps) - 1): + t = timesteps[i] + dt = timesteps[i + 1] - timesteps[i] + + t_emb = self.time_embedding(t.view(-1, 1).repeat(B, 1)).to(llm_hidden.dtype) + t_emb = self.time_projection(t_emb) + + # CFG: batch cond + uncond + x_batched = torch.cat([x, x], dim=0) + llm_batched = torch.cat([llm_hidden, llm_zero], dim=0) + t_emb_batched = torch.cat([t_emb, t_emb], dim=0) + llm_proj = self.llm_projection(llm_batched) + + inp = self.input_projection(x_batched.to(llm_hidden.dtype)).unsqueeze(1) + t_tok = t_emb_batched.unsqueeze(1) + ctx_tok = llm_proj.unsqueeze(1) + h = torch.cat([inp, t_tok, ctx_tok], dim=1) + + h = self.forward_layers(h) + h = self.norm(h) + v_all = self.acoustic_codebook_output(h[:, 0, :]) + + v_cond, v_uncond = v_all[:B], v_all[B:] + v = self.config.cfg_alpha * v_cond + (1 - self.config.cfg_alpha) * v_uncond + x = x + v * dt + + # Quantize + x = torch.clamp(x, -1, 1) + scaled = ((x + 1) / 2) * (self.config.acoustic_levels - 1) + acoustic_codes = scaled.round().long() + acoustic_codes[~should_decode] = EMPTY_AUDIO_ID + acoustic_codes = acoustic_codes + N_SPECIAL_TOKENS + + return torch.cat([semantic_code, acoustic_codes], dim=1) + + +# --------------------------------------------------------------------------- +# Codec decoder components +# --------------------------------------------------------------------------- + + +def _conv1d_as_matmul( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: int, + dilation: int, +) -> torch.Tensor: + """Math-identical replacement for F.conv1d, expressed as unfold + matmul. + + AOTI's CUDA backend has Triton matmul kernels and aoti_torch_cuda_mm + runtime shims, but no kernels for aten.convolution.default at the codec's + ConvTranspose shapes (and no aoti_torch_cuda_convolution shim). This + reformulation lets the codec lower onto the same fast Triton path the LM + already uses. + + x: (B, C_in, L_in) + weight: (C_out, C_in, K) + bias: (C_out,) or None + Returns (B, C_out, L_out) where L_out = (L_in - K_eff) // stride + 1 + """ + b, c_in, l_in = x.shape + c_out, _, k = weight.shape + x4 = x.unsqueeze(-1) # (B, C_in, L_in, 1) + unf = F.unfold( + x4, + kernel_size=(k, 1), + dilation=(dilation, 1), + stride=(stride, 1), + ) + # F.unfold returns (B, C_in * K, L_out); each column flattens a window in + # (channel-major, then kernel) order, matching weight.reshape(C_out, -1). + unf = unf.transpose(1, 2) # (B, L_out, C_in*K) + w_flat = weight.reshape(c_out, -1) + y = unf @ w_flat.t() # (B, L_out, C_out) + if bias is not None: + y = y + bias + return y.transpose(1, 2).contiguous() # (B, C_out, L_out) + + +def _conv_transpose1d_as_matmul( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: int, +) -> torch.Tensor: + """Math-identical replacement for F.conv_transpose1d, via matmul + Fold. + + Same motivation as _conv1d_as_matmul — moves the codec off aten.convolution + onto the matmul path that AOTI's CUDA backend handles cleanly. 
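+
+    Equivalence can be sanity-checked in eager mode (illustrative shapes):
+
+        w = torch.randn(8, 4, 5)                  # (C_in, C_out, K)
+        x = torch.randn(1, 8, 16)                 # (B, C_in, L_in)
+        ref = F.conv_transpose1d(x, w, stride=2)
+        alt = _conv_transpose1d_as_matmul(x, w, None, stride=2)
+        torch.testing.assert_close(ref, alt)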
+ + x: (B, C_in, L_in) + weight: (C_in, C_out, K) (PyTorch ConvTranspose1d layout) + bias: (C_out,) or None + Returns (B, C_out, L_out) where L_out = (L_in - 1) * stride + K + """ + b, c_in, l_in = x.shape + _, c_out, k = weight.shape + # Reshape weight to (C_in, C_out * K). For each input position l_in we want + # to produce a (C_out * K) vector that fold then scatters into the right + # output positions with stride-based overlap-add. + w_flat = weight.reshape(c_in, c_out * k) + # (B, C_in, L_in) -> (B, L_in, C_in) for batched matmul. + y = x.transpose(1, 2) @ w_flat # (B, L_in, C_out * K) + y = y.transpose(1, 2) # (B, C_out * K, L_in) — F.fold's expected order + l_out = (l_in - 1) * stride + k + # F.fold scatters each column's K values into output positions + # [l_in * stride .. l_in * stride + K) and accumulates overlaps. + y4 = F.fold( + y, + output_size=(l_out, 1), + kernel_size=(k, 1), + stride=(stride, 1), + ) # (B, C_out, L_out, 1) + out = y4.squeeze(-1) # (B, C_out, L_out) + if bias is not None: + out = out + bias[None, :, None] + return out + + +def _pad1d( + x: torch.Tensor, + paddings: tuple[int, int], + mode: str = "constant", + value: float = 0.0, +) -> torch.Tensor: + length = x.shape[-1] + padding_left, padding_right = paddings + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + return F.pad(x, paddings, mode, value) + + +class CodecCausalConv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + pad_mode: str = "reflect", + use_weight_norm: bool = True, + use_bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + bias=use_bias, + ) + if use_weight_norm: + self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv) + self.pad_mode = pad_mode + self._stride = stride + self._effective_kernel_size = (kernel_size - 1) * dilation + 1 + self._padding_total = self._effective_kernel_size - self._stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_frames = ( + x.shape[-1] - self._effective_kernel_size + self._padding_total + ) / self._stride + 1 + target_length = (math.ceil(n_frames) - 1) * self._stride + ( + self._effective_kernel_size - self._padding_total + ) + extra_padding = target_length - x.shape[-1] + x = _pad1d(x, (self._padding_total, extra_padding), mode=self.pad_mode) + # Use unfold + matmul instead of F.conv1d so AOTI CUDA can lower this. 
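+        # The call below is the math-identical, lowering-friendly form of
+        #   F.conv1d(x, self.conv.weight, self.conv.bias,
+        #            stride=self._stride, dilation=self.conv.dilation[0])
+        # applied after the causal padding above.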
+ return _conv1d_as_matmul( + x, + self.conv.weight, + self.conv.bias, + stride=self._stride, + dilation=self.conv.dilation[0], + ) + + +class CodecCausalConvTranspose1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + use_weight_norm: bool = True, + use_bias: bool = True, + trim_ratio: float = 1.0, + ): + super().__init__() + self.conv = nn.ConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + bias=use_bias, + ) + if use_weight_norm: + self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv) + self.trim_ratio = trim_ratio + + def forward(self, x: torch.Tensor) -> torch.Tensor: + kernel_size = self.conv.kernel_size[0] + stride = self.conv.stride[0] + total_padding = kernel_size - stride + # matmul + Fold replacement for F.conv_transpose1d (AOTI CUDA path). + out = _conv_transpose1d_as_matmul( + x, self.conv.weight, self.conv.bias, stride=stride + ) + right_padding = math.ceil(total_padding * self.trim_ratio) + left_padding = total_padding - right_padding + return out[..., left_padding : out.shape[-1] - right_padding] + + +def _get_alibi_slopes(n_heads: int) -> torch.Tensor: + def _slopes_power_of_2(n: int) -> torch.Tensor: + r = 2.0 ** (-8.0 / n) + return torch.tensor([r ** (i + 1) for i in range(n)], dtype=torch.float32) + + if math.log2(n_heads).is_integer(): + return _slopes_power_of_2(n_heads) + m = 2 ** math.floor(math.log2(n_heads)) + return torch.cat( + [_slopes_power_of_2(m), _slopes_power_of_2(2 * m)[::2][: n_heads - m]] + ) + + +class CodecAttention(nn.Module): + """Causal attention with ALiBi + sliding window for the codec decoder.""" + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + sliding_window: int, + qk_norm: bool = True, + qk_norm_eps: float = 1e-6, + use_biases: bool = False, + causal: bool = True, + ): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.repeats = n_heads // n_kv_heads + self.sliding_window = sliding_window + self.causal = causal + + self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=use_biases) + + if qk_norm: + self.q_norm = RMSNorm(n_heads * head_dim, qk_norm_eps) + self.k_norm = RMSNorm(n_kv_heads * head_dim, qk_norm_eps) + else: + self.q_norm = None + self.k_norm = None + + self.register_buffer( + "alibi_slopes", _get_alibi_slopes(n_heads), persistent=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() == 2: + bsz, seqlen = 1, x.shape[0] + else: + bsz, seqlen, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + if self.q_norm is not None: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + + # Transpose to (B, H, S, D) + q = xq.transpose(1, 2) + k = xk.transpose(1, 2) + v = xv.transpose(1, 2) + + # GQA expansion + if self.repeats > 1: + k = k.repeat_interleave(self.repeats, dim=1) + v = v.repeat_interleave(self.repeats, dim=1) + + # Build ALiBi + causal + sliding window bias + positions = torch.arange(seqlen, device=x.device) + rel_pos = positions.unsqueeze(0) - positions.unsqueeze(1) + alibi_slopes = self.alibi_slopes.to(dtype=x.dtype, device=x.device) + 
attn_bias = alibi_slopes.view(self.n_heads, 1, 1) * rel_pos.unsqueeze(0).to( + x.dtype + ) + + if self.causal: + attn_bias = attn_bias.masked_fill(rel_pos.unsqueeze(0) > 0, float("-inf")) + + window_right = 0 if self.causal else self.sliding_window + outside_window = (rel_pos < -self.sliding_window) | (rel_pos > window_right) + attn_bias = attn_bias.masked_fill(outside_window.unsqueeze(0), float("-inf")) + + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias.unsqueeze(0)) + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(y) + + +class CodecFeedForward(nn.Module): + """SwiGLU FFN for the codec.""" + + def __init__(self, dim: int, hidden_dim: int, use_biases: bool = False): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=use_biases) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class CodecTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + hidden_dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + sliding_window: int, + norm_eps: float, + qk_norm: bool, + qk_norm_eps: float, + use_biases: bool, + layer_scale: bool, + causal: bool, + ): + super().__init__() + self.attention = CodecAttention( + dim, + n_heads, + n_kv_heads, + head_dim, + sliding_window, + qk_norm, + qk_norm_eps, + use_biases, + causal, + ) + self.feed_forward = CodecFeedForward(dim, hidden_dim, use_biases) + self.attention_norm = RMSNorm(dim, norm_eps) + self.ffn_norm = RMSNorm(dim, norm_eps) + + self.use_layer_scale = layer_scale + if layer_scale: + if layer_id < 18: + init_scale = 0.1 + elif layer_id <= 24: + init_scale = 1e-5 + else: + init_scale = 1e-6 + self.attention_scale = nn.Parameter( + torch.full((dim,), init_scale, requires_grad=True) + ) + self.ffn_scale = nn.Parameter( + torch.full((dim,), init_scale, requires_grad=True) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r = self.attention(self.attention_norm(x)) + if self.use_layer_scale: + r = self.attention_scale * r + h = x + r + r = self.feed_forward(self.ffn_norm(h)) + if self.use_layer_scale: + r = self.ffn_scale * r + return h + r + + +class CodecTransformer(nn.Module): + """Stack of codec transformer blocks with specified sliding window.""" + + def __init__(self, config: VoxtralTTSConfig, n_layers: int, sliding_window: int): + super().__init__() + self.layers = nn.ModuleDict() + for i in range(n_layers): + self.layers[str(i)] = CodecTransformerBlock( + layer_id=i, + dim=config.codec_dim, + hidden_dim=config.codec_hidden_dim, + n_heads=config.codec_n_heads, + n_kv_heads=config.codec_n_kv_heads, + head_dim=config.codec_head_dim, + sliding_window=sliding_window, + norm_eps=config.codec_norm_eps, + qk_norm=True, + qk_norm_eps=config.codec_qk_norm_eps, + use_biases=config.codec_use_biases, + layer_scale=config.codec_layer_scale, + causal=config.codec_causal, + ) + self.n_layers = n_layers + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for i in range(self.n_layers): + x = self.layers[str(i)](x) + return x + + +# --------------------------------------------------------------------------- +# Codebook quantizer (decode only) +# --------------------------------------------------------------------------- + + +class SemanticCodebook(nn.Module): + """Euclidean distance VQ codebook — decode is just an embedding lookup.""" + + def __init__(self, codebook_size: int, codebook_dim: int): 
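+        # cluster_usage / embedding_sum are presumably the EMA statistics
+        # saved by VQ training; the effective codebook is
+        # embedding_sum / cluster_usage (see the `embedding` property), so
+        # decode is a plain F.embedding lookup into that derived table.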
+ super().__init__() + self.register_buffer("cluster_usage", torch.ones(codebook_size)) + self.register_buffer("embedding_sum", torch.zeros(codebook_size, codebook_dim)) + + @property + def embedding(self) -> torch.Tensor: + return self.embedding_sum / self.cluster_usage.clamp(min=1e-5)[:, None] + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """codes: (B, 1, T) -> (B, semantic_dim, T)""" + codes = codes.squeeze(1) # (B, T) + quantized = F.embedding(codes, self.embedding) # (B, T, D) + return quantized.transpose(1, 2) # (B, D, T) + + +class AcousticCodebook(nn.Module): + """Finite Scalar Quantization — decode rescales integers to [-1, 1].""" + + def __init__(self, n_levels: int, dim: int): + super().__init__() + self.n_levels = n_levels + self.dim = dim + + def decode( + self, codes: torch.Tensor, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """codes: (B, dim, T) long -> (B, dim, T) float in [-1, 1]""" + return (codes.to(dtype) * 2) / (self.n_levels - 1) - 1 + + +class AudioCodebook(nn.Module): + """Combined semantic + acoustic codebook for decode.""" + + def __init__(self, config: VoxtralTTSConfig): + super().__init__() + self.semantic_codebook = SemanticCodebook( + config.semantic_codebook_size, config.semantic_dim + ) + self.acoustic_codebook = AcousticCodebook( + config.acoustic_levels, config.acoustic_dim + ) + self.semantic_dim = config.semantic_dim + self.acoustic_dim = config.acoustic_dim + + def decode( + self, codes: torch.Tensor, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """codes: (B, 1+acoustic_dim, T) -> (B, semantic_dim+acoustic_dim, T)""" + semantic_codes = codes[:, :1, :] + acoustic_codes = codes[:, 1:, :] + sem_emb = self.semantic_codebook.decode(semantic_codes).to(dtype) + aco_emb = self.acoustic_codebook.decode(acoustic_codes, dtype) + return torch.cat([sem_emb, aco_emb], dim=1) + + +# --------------------------------------------------------------------------- +# Full codec decoder +# --------------------------------------------------------------------------- + + +class CodecDecoder(nn.Module): + """Converts codebook tokens to waveform via VQ/FSQ decode + upsampling.""" + + def __init__(self, config: VoxtralTTSConfig): + super().__init__() + self.config = config + self.quantizer = AudioCodebook(config) + latent_dim = config.semantic_dim + config.acoustic_dim + + decoder_blocks: list[nn.Module] = [] + # The encoder starts at codec_sliding_window and halves at each + # downsample. The decoder mirrors this: start at the most-compressed + # window and double at each upsample. 
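+        # Example with the defaults (codec_sliding_window=16, strides 1,2,2,2):
+        # n_upsample = 3, so the four transformer stacks below run with
+        # sliding windows 2, 4, 8 and 16 as each ConvTranspose doubles the
+        # window back up.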
+ n_upsample = sum(1 for s in config.codec_decoder_convs_strides if s > 1) + if config.codec_half_attn_window_upon_downsampling and n_upsample > 0: + cur_window_size = config.codec_sliding_window // (2**n_upsample) + else: + cur_window_size = config.codec_sliding_window + + # First projection: latent_dim -> codec_dim + decoder_blocks.append( + CodecCausalConv1d( + latent_dim, + config.codec_dim, + kernel_size=config.codec_decoder_convs_kernels[0], + stride=config.codec_decoder_convs_strides[0], + pad_mode="replicate", + use_weight_norm=config.codec_conv_weight_norm, + use_bias=False, + ) + ) + if ( + config.codec_half_attn_window_upon_downsampling + and config.codec_decoder_convs_strides[0] > 1 + ): + cur_window_size *= 2 + + for idx, n_layers in enumerate(config.codec_decoder_transformer_lengths): + decoder_blocks.append(CodecTransformer(config, n_layers, cur_window_size)) + if idx + 1 < len(config.codec_decoder_transformer_lengths): + next_k = config.codec_decoder_convs_kernels[idx + 1] + next_s = config.codec_decoder_convs_strides[idx + 1] + if next_k != 1 or next_s != 1: + decoder_blocks.append( + CodecCausalConvTranspose1d( + config.codec_dim, + config.codec_dim, + kernel_size=next_k, + stride=next_s, + use_weight_norm=config.codec_conv_weight_norm, + use_bias=False, + ) + ) + if config.codec_half_attn_window_upon_downsampling and next_s > 1: + cur_window_size *= 2 + + self.decoder_blocks = nn.ModuleList(decoder_blocks) + + self.output_proj = CodecCausalConv1d( + config.codec_dim, + config.codec_patch_size, + kernel_size=7, + use_weight_norm=config.codec_conv_weight_norm, + use_bias=False, + ) + + def forward(self, codes: torch.Tensor) -> torch.Tensor: + """Decode codebook tokens to waveform. + + Args: + codes: (B, n_codebooks, T) integer codes + Returns: + waveform: (B, 1, T * downsample_factor) + """ + # The generator emits all 37 codebooks in the shifted token space where + # 0/1 are EMPTY/END special tokens and normal codes start at +2. + # Match the reference tokenizer path by unshifting every codebook while + # mapping specials/padding back to 0 for decode. + codes_stripped = torch.where( + codes >= N_SPECIAL_TOKENS, + codes - N_SPECIAL_TOKENS, + torch.zeros_like(codes), + ) + + # Match the dtype of the codec's first conv weight so downstream + # bf16/fp32 paths see consistent dtypes. + latent_dtype = ( + codes.dtype + if codes.is_floating_point() + else self.output_proj.conv.weight.dtype + ) + latent = self.quantizer.decode(codes_stripped, dtype=latent_dtype) + + x = latent # (B, D, T) channels-first + for block in self.decoder_blocks: + if isinstance(block, CodecTransformer): + x = x.transpose(1, 2) # (B, D, T) -> (B, T, D) + x = block(x) + x = x.transpose(1, 2) # (B, T, D) -> (B, D, T) + else: + x = block(x) # Conv1d / ConvTranspose1d: stays (B, D, T) + + waveform = self.output_proj(x) # (B, patch_size=240, T') + B, P, T = waveform.shape + # Audio samples are produced frame-by-frame: for each frame t we emit + # P contiguous samples. Interleave time-outer / patch-inner to match + # the reference C codec (`samples[t*P + h] = out_proj[h*T + t]`). 
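+        # e.g. with P = 240: output samples [0, 240) come from frame 0,
+        # samples [240, 480) from frame 1, and so on.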
+ return waveform.transpose(1, 2).reshape(B, 1, T * P) + + +# --------------------------------------------------------------------------- +# Top-level model +# --------------------------------------------------------------------------- + + +class VoxtralTTSModel(nn.Module): + def __init__(self, config: VoxtralTTSConfig): + super().__init__() + self.config = config + self.decoder = MistralDecoder(config) + self.audio_token_embedding = AudioTokenEmbedding(config) + self.flow_head = FlowMatchingHead(config) + self.codec_decoder = CodecDecoder(config) + + +# --------------------------------------------------------------------------- +# Weight loading +# --------------------------------------------------------------------------- + + +def _map_checkpoint_key(ckpt_key: str) -> str | None: + """Map Mistral consolidated checkpoint key to model state_dict key. + + Checkpoint structure: + - layers.N.* -> decoder.layers.N.* + - norm.weight -> decoder.norm.weight + - mm_audio_embeddings.tok_embeddings.weight -> decoder.tok_embeddings.weight + - mm_audio_embeddings.audio_codebook_embeddings.embeddings.weight + -> audio_token_embedding.embeddings.weight + - acoustic_transformer.* -> flow_head.* + - audio_tokenizer.* -> codec_decoder.* + """ + # LLM decoder layers + if ckpt_key.startswith("layers."): + return "decoder." + ckpt_key + + if ckpt_key == "norm.weight": + return "decoder.norm.weight" + + # Token embeddings + if ckpt_key == "mm_audio_embeddings.tok_embeddings.weight": + return "decoder.tok_embeddings.weight" + + if ckpt_key == "mm_audio_embeddings.audio_codebook_embeddings.embeddings.weight": + return "audio_token_embedding.embeddings.weight" + + # Flow matching head (acoustic transformer) + if ckpt_key.startswith("acoustic_transformer."): + suffix = ckpt_key[len("acoustic_transformer.") :] + return "flow_head." + suffix + + # Codec decoder + if ckpt_key.startswith("audio_tokenizer."): + suffix = ckpt_key[len("audio_tokenizer.") :] + return "codec_decoder." + suffix + + # Skip voice embeddings (loaded separately) + if ckpt_key.startswith("mm_audio_embeddings.audio_codebook"): + return None + + return None + + +def _fold_weight_norm(model: nn.Module) -> None: + """Remove weight_norm parametrizations, fusing weight_v + weight_g into weight.""" + for _name, module in model.named_modules(): + if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): + if hasattr(module, "parametrizations"): + torch.nn.utils.parametrize.remove_parametrizations(module, "weight") + + +def _materialize_meta_buffers(model: nn.Module, dtype: torch.dtype) -> None: + """Replace meta-device buffers with zero tensors on CPU. + + Preserves each buffer's declared dtype — StaticKVCache asks for bf16 for + the CUDA Triton SDPA path even when the rest of the model is fp32. + """ + for fqn, buf in list(model.named_buffers()): + if buf.device.type != "meta": + continue + parts = fqn.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + if buf.dtype == torch.bfloat16: + buf_dtype = torch.bfloat16 + elif buf.dtype.is_floating_point: + buf_dtype = dtype + else: + buf_dtype = buf.dtype + parent.register_buffer( + parts[-1], + torch.zeros(buf.shape, dtype=buf_dtype, device="cpu"), + ) + + +def load_model( + model_path: str, + max_seq_len: int = 4096, + dtype: torch.dtype = torch.float32, + backend: str = "xnnpack", +) -> VoxtralTTSModel: + """Load VoxtralTTSModel from a Mistral checkpoint. + + Uses meta-device construction + assign-based loading to minimize peak memory. 
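+
+    Usage sketch (model_path must contain params.json and
+    consolidated.safetensors; the path below is illustrative):
+
+        model = load_model("/path/to/voxtral-tts-checkpoint", backend="cuda")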
+ """ + from safetensors import safe_open + + model_dir = Path(model_path) + config = VoxtralTTSConfig.from_params_json(str(model_dir / "params.json")) + config.max_seq_len = max_seq_len + config.backend = backend + + print( + f"Building model on meta device (dim={config.dim}, layers={config.n_layers}, " + f"at_dim={config.at_dim}, at_layers={config.at_n_layers}, " + f"codec_dim={config.codec_dim}, backend={backend})..." + ) + with torch.device("meta"): + model = VoxtralTTSModel(config) + + # Load weights + ckpt_path = str(model_dir / "consolidated.safetensors") + print(f"Loading weights from {ckpt_path}...") + state_dict = {} + with safe_open(ckpt_path, framework="pt", device="cpu") as f: + for ckpt_key in f.keys(): + model_key = _map_checkpoint_key(ckpt_key) + if model_key is None: + continue + state_dict[model_key] = f.get_tensor(ckpt_key).to(dtype) + + missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True) + + _materialize_meta_buffers(model, dtype) + + # Recompute RoPE + dec_cos, dec_sin = precompute_freqs_cis( + config.head_dim, max_seq_len, config.rope_theta + ) + model.decoder.register_buffer("freqs_cos", dec_cos) + model.decoder.register_buffer("freqs_sin", dec_sin) + + # Recompute audio-token embedding offsets + model.audio_token_embedding.register_buffer( + "offsets", + model.audio_token_embedding.make_offsets(), + persistent=False, + ) + + # Recompute flow-matching timestep embedding buffers + model.flow_head.time_embedding.register_buffer( + "inv_freq", + torch.exp( + -math.log(10000.0) + * torch.arange(config.at_dim // 2, dtype=torch.float32) + / (config.at_dim // 2) + ), + persistent=True, + ) + + # Recompute timesteps + model.flow_head.register_buffer( + "_timesteps", + torch.linspace(0, 1, config.n_decoding_steps + 1), + ) + + # Recompute ALiBi slopes for codec attention + for module in model.codec_decoder.modules(): + if isinstance(module, CodecAttention): + slopes = _get_alibi_slopes(module.n_heads) + module.register_buffer("alibi_slopes", slopes, persistent=False) + + # Recompute semantic codebook embedding + sem = model.codec_decoder.quantizer.semantic_codebook + if sem.embedding_sum.device.type != "meta": + sem.register_buffer( + "_embedding", + sem.embedding_sum / sem.cluster_usage.clamp(min=1e-5)[:, None], + persistent=False, + ) + + # Fold weight_norm in codec decoder + _fold_weight_norm(model.codec_decoder) + + # Validate loading + runtime_prefixes = ( + "decoder.freqs_", + "audio_token_embedding.offsets", + ".kv_cache.", + "flow_head.time_embedding.inv_freq", + "flow_head._timesteps", + ".alibi_slopes", + "._embedding", + ) + actual_missing = set(missing) + expected_missing = { + k for k in actual_missing if any(p in k for p in runtime_prefixes) + } + extra_missing = actual_missing - expected_missing + if extra_missing: + print(f" WARNING: {len(extra_missing)} unexpected missing keys") + for k in sorted(extra_missing)[:20]: + print(f" {k}") + if unexpected: + print(f" WARNING: {len(unexpected)} unexpected keys") + + loaded = len(state_dict) - len(unexpected) + print( + f" Loaded {loaded} tensors ({len(expected_missing)} runtime buffers OK, " + f"{len(extra_missing)} unexpected missing)" + ) + + model.eval() + return model diff --git a/examples/models/voxtral_tts/run_cuda_e2e.sh b/examples/models/voxtral_tts/run_cuda_e2e.sh new file mode 100755 index 00000000000..36d6db6dd50 --- /dev/null +++ b/examples/models/voxtral_tts/run_cuda_e2e.sh @@ -0,0 +1,95 @@ +#!/usr/bin/env bash +# Voxtral TTS — CUDA end-to-end script. 
+# Exports the 4w-quantized full-CUDA pipeline (LM + codec both on GPU) and +# runs the runner. Total wall clock for "Hello, how are you today?" on A100: +# ~3.7 s (LM 2.1 s + codec 0.04 s + load/build). +# +# Usage: +# conda activate et-cuda +# unset CPATH # critical — see README.md "CUDA gotchas" +# export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH +# bash examples/models/voxtral_tts/run_cuda_e2e.sh \ +# [] +# +# Env overrides: +# SKIP_EXPORT=1 — skip export step (use existing artifacts in OUT_DIR) +# SKIP_BUILD=1 — skip cmake build step +# PROMPT="..." — override the synthesis text +# VOICE= — voice embedding name without .pt (default: neutral_female) +# SEED= — RNG seed (default: 42) + +set -euo pipefail + +VOXTRAL_DIR="${1:?usage: $0 []}" +OUT_DIR="${2:-$PWD/voxtral_tts_exports_cuda_4w}" +PROMPT="${PROMPT:-Hello, how are you today?}" +VOICE="${VOICE:-neutral_female}" +SEED="${SEED:-42}" + +if [[ -n "${CPATH:-}" ]]; then + echo "ERROR: CPATH is set ('$CPATH'). Run 'unset CPATH' first." >&2 + echo " It pollutes nvcc's include search and breaks the CUDA backend." >&2 + exit 1 +fi +if ! command -v nvcc >/dev/null; then + echo "ERROR: nvcc not on PATH. Source ~/.bashrc and retry." >&2 + exit 1 +fi +if [[ ! -d "$VOXTRAL_DIR" ]]; then + echo "ERROR: model dir '$VOXTRAL_DIR' does not exist." >&2 + exit 1 +fi + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../../.." && pwd)" +RUNNER="$REPO_ROOT/cmake-out/examples/models/voxtral_tts/voxtral_tts_runner" + +echo "=== 1/4. env check ===" +which nvcc +nvcc --version | tail -3 || true +echo " CUDA_HOME=${CUDA_HOME:-unset}" +echo " CUDAToolkit_ROOT=${CUDAToolkit_ROOT:-unset}" +echo " LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-unset}" +nvidia-smi -L +nvidia-smi --query-gpu=name,driver_version,memory.total,memory.free --format=csv + +echo +echo "=== 2/4. export 4w-quant CUDA model + CUDA codec ===" +if [[ "${SKIP_EXPORT:-0}" == "1" && -f "$OUT_DIR/model.pte" ]]; then + echo " SKIP_EXPORT=1 and $OUT_DIR/model.pte exists — skipping export" +else + mkdir -p "$OUT_DIR" + python "$REPO_ROOT/examples/models/voxtral_tts/export_voxtral_tts.py" \ + --model-path "$VOXTRAL_DIR" \ + --backend cuda --qlinear 4w \ + --output-dir "$OUT_DIR" +fi +echo " output:" +ls -la "$OUT_DIR" + +echo +echo "=== 3/4. build voxtral_tts_runner with EXECUTORCH_BUILD_CUDA=ON ===" +if [[ "${SKIP_BUILD:-0}" == "1" && -x "$RUNNER" ]]; then + echo " SKIP_BUILD=1 and $RUNNER exists — skipping build" +else + ( cd "$REPO_ROOT" && cmake --workflow --preset llm-release-cuda ) + ( cd "$REPO_ROOT/examples/models/voxtral_tts" && cmake --workflow --preset voxtral-tts-cuda ) +fi + +echo +echo "=== 4/4. synth: '$PROMPT' (voice=$VOICE seed=$SEED) ===" +WAV_OUT="${WAV_OUT:-$OUT_DIR/sample.wav}" +"$RUNNER" \ + --model "$OUT_DIR/model.pte" \ + --data_path "$OUT_DIR/aoti_cuda_blob.ptd" \ + --codec "$OUT_DIR/codec_decoder.pte" \ + --codec_data_path "$OUT_DIR/codec_aoti_cuda_blob.ptd" \ + --tokenizer "$VOXTRAL_DIR/tekken.json" \ + --voice "$VOXTRAL_DIR/voice_embedding/${VOICE}.pt" \ + --text "$PROMPT" \ + --output "$WAV_OUT" \ + --seed "$SEED" \ + --max_new_tokens 200 + +echo +echo "DONE. 
Wav: $WAV_OUT" +echo " Listen: ffplay $WAV_OUT (or aplay $WAV_OUT)" diff --git a/examples/models/voxtral_tts/voice.py b/examples/models/voxtral_tts/voice.py new file mode 100644 index 00000000000..0adeb98ef13 --- /dev/null +++ b/examples/models/voxtral_tts/voice.py @@ -0,0 +1,91 @@ +from pathlib import Path + +import numpy as np +import torch + +DEFAULT_VOICE_NAME = "neutral_female" + + +def resolve_voice_asset_path(model_dir: str | Path, voice: str | None) -> Path: + model_dir = Path(model_dir) + voice_name = voice or DEFAULT_VOICE_NAME + candidate = Path(voice_name) + + if candidate.exists(): + return candidate + + voice_dir = model_dir / "voice_embedding" + if candidate.suffix: + local_candidate = voice_dir / candidate.name + if local_candidate.exists(): + return local_candidate + return candidate + + for ext in (".pt", ".bin"): + local_candidate = voice_dir / f"{voice_name}{ext}" + if local_candidate.exists(): + return local_candidate + + return voice_dir / f"{voice_name}.pt" + + +def load_voice_embedding_tensor( + path: str | Path, + dim: int = 3072, + expected_frames_hint: int | None = None, +) -> torch.Tensor: + path = Path(path) + if path.suffix == ".pt": + return torch.load(path, map_location="cpu", weights_only=True).float() + + raw = path.read_bytes() + bf16_row_bytes = dim * 2 + f32_row_bytes = dim * 4 + matches_hint_bf16 = ( + expected_frames_hint is not None + and len(raw) == expected_frames_hint * bf16_row_bytes + ) + matches_hint_f32 = ( + expected_frames_hint is not None + and len(raw) == expected_frames_hint * f32_row_bytes + ) + + if matches_hint_f32: + data = np.frombuffer(raw, dtype=np.float32).copy() + return torch.from_numpy(data).reshape(-1, dim).float() + + if matches_hint_bf16 or len(raw) % bf16_row_bytes == 0: + data = np.frombuffer(raw, dtype=np.uint16).copy() + tensor = torch.from_numpy(data).reshape(-1, dim) + return tensor.view(torch.bfloat16).float() + + if len(raw) % f32_row_bytes == 0: + data = np.frombuffer(raw, dtype=np.float32).copy() + return torch.from_numpy(data).reshape(-1, dim).float() + + raise ValueError( + f"Voice embedding {path} has unsupported size {len(raw)} for dim={dim}" + ) + + +def load_voice_from_model_dir( + model_dir: str | Path, + voice: str | None, + dim: int = 3072, +) -> tuple[torch.Tensor, Path]: + path = resolve_voice_asset_path(model_dir, voice) + expected_frames_hint = None + if path.suffix == ".bin": + pt_peer = path.with_suffix(".pt") + if pt_peer.exists(): + expected_frames_hint = int( + load_voice_embedding_tensor(pt_peer, dim=dim).shape[0] + ) + return ( + load_voice_embedding_tensor( + path, + dim=dim, + expected_frames_hint=expected_frames_hint, + ), + path, + ) diff --git a/examples/models/voxtral_tts/voxtral_tts_runner.cpp b/examples/models/voxtral_tts/voxtral_tts_runner.cpp new file mode 100644 index 00000000000..82c353650f7 --- /dev/null +++ b/examples/models/voxtral_tts/voxtral_tts_runner.cpp @@ -0,0 +1,1469 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "voxtral_tts_runner.h" +#include "wav_writer.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace voxtral_tts { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::extension::TensorPtr; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +namespace { + +using json = nlohmann::json; + +// fp32 ↔ bf16 conversion helpers. Used when feeding the bf16 AOTI methods +// (CUDA backend) from the runner's fp32 buffers, and reading their outputs +// back into fp32. The CPU/portable path keeps everything fp32 and these are +// not invoked. +inline uint16_t fp32_to_bf16(float f) { + uint32_t bits; + std::memcpy(&bits, &f, sizeof(float)); + // Round-to-nearest-even (truncation with rounding bias). + uint32_t rounding_bias = 0x00007FFFu + ((bits >> 16) & 1u); + return static_cast((bits + rounding_bias) >> 16); +} + +inline float bf16_to_fp32(uint16_t b) { + uint32_t bits = static_cast(b) << 16; + float f; + std::memcpy(&f, &bits, sizeof(float)); + return f; +} + +void read_float_tensor(const Tensor& t, float* dst, size_t n) { + if (t.scalar_type() == ScalarType::BFloat16) { + const uint16_t* src = t.const_data_ptr(); + for (size_t i = 0; i < n; ++i) { + dst[i] = bf16_to_fp32(src[i]); + } + } else { + std::memcpy(dst, t.const_data_ptr(), n * sizeof(float)); + } +} + +void write_float_to_bf16(uint16_t* dst, const float* src, size_t n) { + for (size_t i = 0; i < n; ++i) { + dst[i] = fp32_to_bf16(src[i]); + } +} + +int64_t read_metadata_int(Module& m, const char* name, int64_t fallback) { + std::vector empty; + auto result = m.execute(name, empty); + if (result.ok() && !result.get().empty()) { + return result.get()[0].toInt(); + } + return fallback; +} + +bool has_method(Module& m, const char* name) { + auto methods = m.method_names(); + if (!methods.ok()) { + return false; + } + return methods.get().count(name) > 0; +} + +json topk_logits(const float* logits, int64_t vocab_size, int k = 5) { + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + auto cmp = [&](int64_t lhs, int64_t rhs) { + return logits[lhs] > logits[rhs]; + }; + const int64_t topk = std::min(k, vocab_size); + std::partial_sort( + indices.begin(), indices.begin() + topk, indices.end(), cmp); + + json result = json::array(); + for (int64_t i = 0; i < topk; ++i) { + result.push_back({ + {"id", indices[i]}, + {"logit", logits[indices[i]]}, + }); + } + return result; +} + +json waveform_stats(const std::vector& samples) { + json result = { + {"num_samples", samples.size()}, + {"min", 0.0f}, + {"max", 0.0f}, + {"mean_abs", 0.0f}, + {"peak_abs", 0.0f}, + }; + if (samples.empty()) { + return result; + } + + float min_val = std::numeric_limits::infinity(); + float max_val = -std::numeric_limits::infinity(); + double sum_abs = 0.0; + float peak_abs = 0.0f; + for (float sample : samples) { + min_val = std::min(min_val, sample); + max_val = std::max(max_val, sample); + sum_abs += std::abs(sample); + peak_abs = std::max(peak_abs, std::abs(sample)); + } + result["min"] = min_val; + result["max"] = max_val; + result["mean_abs"] = static_cast(sum_abs / samples.size()); + result["peak_abs"] = peak_abs; + return result; +} + +void write_trace_json(const std::string& path, const json& trace) { + std::ofstream file(path); + 
ET_CHECK_MSG(file.is_open(), "Failed to open trace output: %s", path.c_str()); + file << trace.dump(2) << std::endl; +} + +uint16_t read_u16(const unsigned char* ptr) { + uint16_t value = 0; + std::memcpy(&value, ptr, sizeof(value)); + return value; +} + +uint32_t read_u32(const unsigned char* ptr) { + uint32_t value = 0; + std::memcpy(&value, ptr, sizeof(value)); + return value; +} + +bool find_zip_entry( + const std::vector& file_data, + const std::string& target_name, + const unsigned char*& out_data, + size_t& out_size) { + if (file_data.size() < 22) { + return false; + } + + size_t eocd_pos = 0; + bool found_eocd = false; + const size_t lower_bound = + file_data.size() > 65536 ? file_data.size() - 65536 : 0; + for (size_t pos = file_data.size() - 22; pos > lower_bound; --pos) { + if (read_u32(file_data.data() + pos) == 0x06054b50) { + eocd_pos = pos; + found_eocd = true; + break; + } + } + if (!found_eocd) { + return false; + } + + const uint32_t cd_offset = read_u32(file_data.data() + eocd_pos + 16); + size_t pos = cd_offset; + while (pos + 46 < file_data.size()) { + if (read_u32(file_data.data() + pos) != 0x02014b50) { + break; + } + + const uint16_t compression = read_u16(file_data.data() + pos + 10); + const uint32_t comp_size = read_u32(file_data.data() + pos + 20); + const uint32_t uncomp_size = read_u32(file_data.data() + pos + 24); + const uint16_t fname_len = read_u16(file_data.data() + pos + 28); + const uint16_t extra_len = read_u16(file_data.data() + pos + 30); + const uint16_t comment_len = read_u16(file_data.data() + pos + 32); + const uint32_t local_offset = read_u32(file_data.data() + pos + 42); + + const char* fname = + reinterpret_cast(file_data.data() + pos + 46); + if (std::string(fname, fname_len) == target_name) { + if (compression != 0) { + return false; + } + const uint16_t local_fname_len = + read_u16(file_data.data() + local_offset + 26); + const uint16_t local_extra_len = + read_u16(file_data.data() + local_offset + 28); + const size_t data_start = + local_offset + 30 + local_fname_len + local_extra_len; + const size_t entry_size = uncomp_size > 0 ? 
uncomp_size : comp_size; + if (data_start + entry_size > file_data.size()) { + return false; + } + out_data = file_data.data() + data_start; + out_size = entry_size; + return true; + } + + pos += 46 + fname_len + extra_len + comment_len; + } + return false; +} + +float bf16_to_float(uint16_t value) { + uint32_t bits = static_cast(value) << 16; + float result = 0.0f; + std::memcpy(&result, &bits, sizeof(result)); + return result; +} + +void load_bf16_tensor_data( + const uint16_t* bf16_data, + size_t count, + std::vector& out_data) { + out_data.resize(count); + for (size_t i = 0; i < count; ++i) { + out_data[i] = bf16_to_float(bf16_data[i]); + } +} + +bool load_pt_voice_tensor( + const std::filesystem::path& path, + int64_t dim, + std::vector& out_data, + int64_t& out_frames) { + std::ifstream file(path, std::ios::binary | std::ios::ate); + if (!file.is_open()) { + return false; + } + const auto file_size = static_cast(file.tellg()); + file.seekg(0, std::ios::beg); + std::vector file_data(file_size); + file.read(reinterpret_cast(file_data.data()), file_data.size()); + + const char* candidate_paths[] = { + "voice_embed/data/0", "archive/data/0", "data/0"}; + const unsigned char* tensor_data = nullptr; + size_t tensor_size = 0; + bool found = false; + for (const char* candidate : candidate_paths) { + if (find_zip_entry(file_data, candidate, tensor_data, tensor_size)) { + found = true; + break; + } + } + if (!found || + tensor_size % (static_cast(dim) * sizeof(uint16_t)) != 0) { + return false; + } + + out_frames = static_cast( + tensor_size / (static_cast(dim) * sizeof(uint16_t))); + load_bf16_tensor_data( + reinterpret_cast(tensor_data), + static_cast(out_frames) * static_cast(dim), + out_data); + return true; +} + +bool load_bin_voice_tensor( + const std::filesystem::path& path, + int64_t dim, + int64_t expected_frames_hint, + std::vector& out_data, + int64_t& out_frames) { + std::ifstream file(path, std::ios::binary | std::ios::ate); + if (!file.is_open()) { + return false; + } + const auto file_size = static_cast(file.tellg()); + file.seekg(0, std::ios::beg); + std::vector raw(file_size); + file.read(reinterpret_cast(raw.data()), raw.size()); + + const size_t bf16_row_bytes = static_cast(dim) * sizeof(uint16_t); + const size_t f32_row_bytes = static_cast(dim) * sizeof(float); + const bool matches_hint_bf16 = expected_frames_hint > 0 && + file_size == static_cast(expected_frames_hint) * bf16_row_bytes; + const bool matches_hint_f32 = expected_frames_hint > 0 && + file_size == static_cast(expected_frames_hint) * f32_row_bytes; + + if (matches_hint_f32) { + out_frames = expected_frames_hint; + out_data.resize(static_cast(out_frames) * static_cast(dim)); + std::memcpy(out_data.data(), raw.data(), raw.size()); + return true; + } + + if (matches_hint_bf16 || file_size % bf16_row_bytes == 0) { + out_frames = static_cast(file_size / bf16_row_bytes); + load_bf16_tensor_data( + reinterpret_cast(raw.data()), + static_cast(out_frames) * static_cast(dim), + out_data); + return true; + } + + if (file_size % f32_row_bytes == 0) { + out_frames = static_cast(file_size / f32_row_bytes); + out_data.resize(static_cast(out_frames) * static_cast(dim)); + std::memcpy(out_data.data(), raw.data(), raw.size()); + return true; + } + return false; +} + +} // namespace + +VoxtralTTSRunner::VoxtralTTSRunner( + const std::string& model_path, + const std::string& codec_path, + const std::string& tokenizer_path, + const std::string& model_data_path, + const std::string& codec_data_path) + : rng_(42), + 
flow_rng_state_(42),
+      asset_root_dir_(std::filesystem::path(tokenizer_path).parent_path()),
+      model_path_(model_path),
+      model_data_path_(model_data_path),
+      codec_data_path_(codec_data_path),
+      // Mixed-precision CUDA exports keep weights/IO in fp32 — only the SDPA
+      // path is bf16 internally. Runner can stay fp32-native for all methods.
+      // Flip to true if a future export ships fully-bf16 methods.
+      lm_use_bf16_(
+          false) { // Updated from model.pte metadata in load_metadata()
+  // For CUDA backend, the .ptd file holds the AOTI .so and weights.
+  if (!model_data_path_.empty()) {
+    model_ = std::make_unique<Module>(
+        model_path, model_data_path_, Module::LoadMode::Mmap);
+  } else {
+    model_ = std::make_unique<Module>(model_path, Module::LoadMode::Mmap);
+  }
+  ET_CHECK_MSG(model_->load() == Error::Ok, "Failed to load model.");
+
+  if (!codec_data_path_.empty()) {
+    codec_ = std::make_unique<Module>(
+        codec_path, codec_data_path_, Module::LoadMode::Mmap);
+  } else {
+    codec_ = std::make_unique<Module>(codec_path, Module::LoadMode::Mmap);
+  }
+  ET_CHECK_MSG(codec_->load() == Error::Ok, "Failed to load codec decoder.");
+
+  tokenizer_ = ::executorch::extension::llm::load_tokenizer(tokenizer_path);
+  ET_CHECK_MSG(tokenizer_ != nullptr, "Failed to load tokenizer.");
+
+  load_metadata();
+  warmup();
+}
+
+void VoxtralTTSRunner::set_trace_output_path(
+    const std::string& trace_output_path) {
+  trace_output_path_ = trace_output_path;
+}
+
+void VoxtralTTSRunner::set_seed(uint32_t seed) {
+  seed_ = seed;
+  rng_.seed(seed_);
+  // xorshift64 needs a non-zero state; matches voxtral-tts.c tts_rng_seed.
+  flow_rng_state_ =
+      seed_ ? static_cast<uint64_t>(seed_) : 0x12345678ABCDEF01ULL;
+}
+
+namespace {
+// xorshift64 + Box-Muller, matching voxtral-tts.c voxtral_tts_kernels.c:644-668
+// so flow-matching x0 noise is bit-identical to the C reference under same
+// seed.
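+// uniform01_xs: shifting the 64-bit state right by 11 keeps the top 53 bits,
+// and multiplying by 1 / 2^53 (2^53 = 9007199254740992) maps them into [0, 1).
+// randn_xs is the Box-Muller transform, z = sqrt(-2 ln u1) * cos(2*pi*u2),
+// with 6.2831853071795864 = 2*pi; the "u1 < 1e-30" retry guards against log(0).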
+inline uint64_t xorshift64(uint64_t* state) { + uint64_t x = *state; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + *state = x; + return x; +} +inline float uniform01_xs(uint64_t* state) { + return static_cast(xorshift64(state) >> 11) * + (1.0f / 9007199254740992.0f); +} +inline float randn_xs(uint64_t* state) { + float u1, u2; + do { + u1 = uniform01_xs(state); + } while (u1 < 1e-30f); + u2 = uniform01_xs(state); + return std::sqrt(-2.0f * std::log(u1)) * std::cos(6.2831853071795864f * u2); +} +} // namespace + +void VoxtralTTSRunner::reload_stateful_model() { + if (!model_data_path_.empty()) { + model_ = std::make_unique( + model_path_, model_data_path_, Module::LoadMode::Mmap); + } else { + model_ = std::make_unique(model_path_, Module::LoadMode::Mmap); + } + ET_CHECK_MSG(model_->load() == Error::Ok, "Failed to reload model."); + load_metadata(); +} + +void VoxtralTTSRunner::load_metadata() { + sample_rate_ = read_metadata_int(*model_, "sample_rate", 24000); + n_decoding_steps_ = read_metadata_int(*model_, "n_decoding_steps", 7); + int64_t alpha_x100 = read_metadata_int(*model_, "cfg_alpha_x100", 120); + cfg_alpha_ = static_cast(alpha_x100) / 100.0f; + n_acoustic_codebook_ = read_metadata_int(*model_, "n_acoustic_codebook", 36); + acoustic_levels_ = read_metadata_int(*model_, "acoustic_levels", 21); + n_special_tokens_ = read_metadata_int(*model_, "n_special_tokens", 2); + vocab_size_ = read_metadata_int(*model_, "vocab_size", 131072); + max_seq_len_ = read_metadata_int(*model_, "max_seq_len", 4096); + dim_ = read_metadata_int(*model_, "dim", 3072); + downsample_factor_ = read_metadata_int(*model_, "downsample_factor", 1920); + n_codebooks_ = read_metadata_int(*model_, "n_codebooks", 37); + end_audio_code_ = read_metadata_int(*model_, "end_audio_code", 1); + empty_audio_code_ = read_metadata_int(*model_, "empty_audio_code", 0); + audio_token_id_ = read_metadata_int(*model_, "audio_token_id", 24); + begin_audio_token_id_ = + read_metadata_int(*model_, "begin_audio_token_id", 25); + text_to_audio_token_id_ = + read_metadata_int(*model_, "text_to_audio_token_id", 36); + repeat_audio_text_token_id_ = + read_metadata_int(*model_, "repeat_audio_text_token_id", 35); + voice_embed_len_ = read_metadata_int(*model_, "voice_embed_len", 147); + // Whether the LM methods (text_decoder, semantic_head, predict_velocity, + // audio_token_embedding, token_embedding) take/return bf16. Set by the + // export script — fp32 mixed-precision exports report 0; quantized exports + // (--qlinear …) promote to bf16 and report 1. + lm_use_bf16_ = read_metadata_int(*model_, "lm_input_is_bf16", 0) != 0; + + is_streaming_ = read_metadata_int(*model_, "streaming", 0) != 0; + streaming_chunk_frames_ = + read_metadata_int(*model_, "streaming_chunk_frames", 25); + streaming_initial_chunk_ = + read_metadata_int(*model_, "streaming_initial_chunk", 5); + streaming_left_context_ = + read_metadata_int(*model_, "streaming_left_context", 25); + + max_codec_frames_ = read_metadata_int(*codec_, "max_codec_frames", 256); + codec_supports_exact_frames_ = + has_method(*codec_, "codec_supports_exact_frames") + ? 
(read_metadata_int(*codec_, "codec_supports_exact_frames", 0) != 0) + : false; + + std::cerr << "Model config: dim=" << dim_ + << " voice_embed_len=" << voice_embed_len_ + << " audio_tok=" << audio_token_id_ + << " begin_audio=" << begin_audio_token_id_ + << " max_seq=" << max_seq_len_ + << " codec_frames=" << max_codec_frames_ << std::endl; +} + +std::filesystem::path VoxtralTTSRunner::resolve_voice_path( + const std::string& voice_path) const { + const std::string requested = + voice_path.empty() ? "neutral_female" : voice_path; + std::filesystem::path candidate(requested); + if (std::filesystem::exists(candidate)) { + return candidate; + } + + const auto voice_dir = asset_root_dir_ / "voice_embedding"; + if (candidate.has_extension()) { + auto local_candidate = voice_dir / candidate.filename(); + if (std::filesystem::exists(local_candidate)) { + return local_candidate; + } + return candidate; + } + + for (const char* ext : {".pt", ".bin"}) { + auto local_candidate = voice_dir / (requested + ext); + if (std::filesystem::exists(local_candidate)) { + return local_candidate; + } + } + + return voice_dir / (requested + ".pt"); +} + +void VoxtralTTSRunner::load_voice_embedding(const std::string& voice_path) { + voice_embed_data_.clear(); + runtime_voice_embed_len_ = 0; + + const auto resolved_path = resolve_voice_path(voice_path); + if (!std::filesystem::exists(resolved_path)) { + if (voice_path.empty()) { + std::cerr << "No default voice embedding found at " << resolved_path + << ", continuing without voice conditioning." << std::endl; + return; + } + ET_CHECK_MSG( + false, + "Failed to open voice embedding: %s", + resolved_path.string().c_str()); + } + + bool ok = false; + if (resolved_path.extension() == ".pt") { + ok = load_pt_voice_tensor( + resolved_path, dim_, voice_embed_data_, runtime_voice_embed_len_); + } else { + int64_t expected_frames_hint = voice_embed_len_; + auto pt_peer = resolved_path; + pt_peer.replace_extension(".pt"); + if (std::filesystem::exists(pt_peer)) { + std::vector peer_voice_data; + int64_t peer_frames = 0; + if (load_pt_voice_tensor(pt_peer, dim_, peer_voice_data, peer_frames)) { + expected_frames_hint = peer_frames; + } + } + ok = load_bin_voice_tensor( + resolved_path, + dim_, + expected_frames_hint, + voice_embed_data_, + runtime_voice_embed_len_); + } + ET_CHECK_MSG( + ok, + "Failed to load voice embedding from %s", + resolved_path.string().c_str()); + + std::cerr << "Loaded voice embedding: " << runtime_voice_embed_len_ << " x " + << dim_ << " from " << resolved_path << std::endl; +} + +int64_t VoxtralTTSRunner::sample_semantic_code( + const float* logits, + int64_t vocab_size, + float temperature) { + if (temperature <= 0.0f) { + int64_t best = 0; + float best_val = logits[0]; + for (int64_t i = 1; i < vocab_size; ++i) { + if (logits[i] > best_val) { + best_val = logits[i]; + best = i; + } + } + return best; + } + float max_val = *std::max_element(logits, logits + vocab_size); + std::vector probs(vocab_size); + float sum = 0; + for (int64_t i = 0; i < vocab_size; ++i) { + probs[i] = std::exp((logits[i] - max_val) / temperature); + sum += probs[i]; + } + for (auto& p : probs) + p /= sum; + + std::discrete_distribution dist(probs.begin(), probs.end()); + return dist(rng_); +} + +void VoxtralTTSRunner::warmup() { + std::cerr << "Warming up..." 
<< std::endl; + int dim = static_cast(dim_); + int n_aco = static_cast(n_acoustic_codebook_); + int n_cb = static_cast(n_codebooks_); + int mcf = static_cast(max_codec_frames_); + + int64_t tok_data = 1; + auto tok_t = from_blob(&tok_data, {1, 1}, ScalarType::Long); + auto token_embed_result = + model_->execute("token_embedding", std::vector{*tok_t}); + ET_CHECK_MSG(token_embed_result.ok(), "token_embedding warmup failed"); + + std::vector audio_code_data(n_cb, 0); + auto audio_codes_t = + from_blob(audio_code_data.data(), {1, n_cb, 1}, ScalarType::Long); + auto audio_embed_result = model_->execute( + "audio_token_embedding", std::vector{*audio_codes_t}); + ET_CHECK_MSG(audio_embed_result.ok(), "audio_token_embedding warmup failed"); + + // For bf16 LMs (CUDA quantized exports), allocate bf16 zero buffers and + // call with ScalarType::BFloat16. fp32 LMs (default mixed-precision and + // CPU backends) take the simple fp32 path. + std::vector embed_data_fp32; + std::vector embed_data_bf16; + std::vector xt_data_fp32; + std::vector xt_data_bf16; + TensorPtr hid_t, hv_t, xt_t; + ScalarType float_type = + lm_use_bf16_ ? ScalarType::BFloat16 : ScalarType::Float; + if (lm_use_bf16_) { + embed_data_bf16.assign(dim, fp32_to_bf16(0.0f)); + xt_data_bf16.assign(n_aco, fp32_to_bf16(0.0f)); + hid_t = from_blob(embed_data_bf16.data(), {1, dim}, float_type); + hv_t = from_blob(embed_data_bf16.data(), {1, dim}, float_type); + xt_t = from_blob(xt_data_bf16.data(), {1, n_aco}, float_type); + } else { + embed_data_fp32.assign(dim, 0.0f); + xt_data_fp32.assign(n_aco, 0.0f); + hid_t = from_blob(embed_data_fp32.data(), {1, dim}, float_type); + hv_t = from_blob(embed_data_fp32.data(), {1, dim}, float_type); + xt_t = from_blob(xt_data_fp32.data(), {1, n_aco}, float_type); + } + + // Avoid warming the stateful decoder because the Module API does not expose + // a cache reset; a dummy prefill would pollute the first real synthesis. + auto semantic_result = + model_->execute("semantic_head", std::vector{*hid_t}); + ET_CHECK_MSG(semantic_result.ok(), "semantic_head warmup failed"); + + int64_t tidx_data = 0; + auto ti_t = from_blob(&tidx_data, {1}, ScalarType::Long); + auto velocity_result = model_->execute( + "predict_velocity", std::vector{*xt_t, *ti_t, *hv_t}); + ET_CHECK_MSG(velocity_result.ok(), "predict_velocity warmup failed"); + + // Codec is on the portable (CPU) backend — there's no Triton autotune to + // amortize, and one forward over max_codec_frames=256 takes ~60-120s on CPU. + // Skipping it cuts startup wait substantially; the first real codec call + // at end-of-synth pays the same cost it always would. + (void)mcf; + (void)n_cb; + + std::cerr << "Warmup complete." 
<< std::endl; +} + +std::vector VoxtralTTSRunner::tokenize(const std::string& text) { + auto encoded = tokenizer_->encode(text, /*bos=*/0, /*eos=*/0); + ET_CHECK_MSG(encoded.ok(), "Tokenizer encode failed"); + std::vector result; + result.reserve(encoded.get().size()); + for (auto id : encoded.get()) { + result.push_back(static_cast(id)); + } + return result; +} + +void VoxtralTTSRunner::build_prompt( + const std::string& text, + std::vector& token_ids, + int& voice_start, + int& voice_len) { + // Match mistral_common encode_speech_request(): + // [BOS] [BEGIN_AUDIO] [AUDIO]*N [TEXT_TO_AUDIO] {text_tokens} + // [AUDIO_TO_TEXT] [BEGIN_AUDIO] + auto text_tokens = tokenize(text); + + token_ids.clear(); + token_ids.push_back(1); // BOS + token_ids.push_back(begin_audio_token_id_); // [BEGIN_AUDIO] + + voice_start = static_cast(token_ids.size()); + voice_len = static_cast(runtime_voice_embed_len_); + for (int i = 0; i < voice_len; ++i) { + token_ids.push_back(audio_token_id_); // [AUDIO] placeholder + } + + token_ids.push_back(text_to_audio_token_id_); + for (auto t : text_tokens) { + token_ids.push_back(t); + } + token_ids.push_back(repeat_audio_text_token_id_); // [REPEAT_AUDIO_TEXT] + token_ids.push_back(begin_audio_token_id_); // [BEGIN_AUDIO] + + std::cerr << "Prompt: " << token_ids.size() + << " tokens (voice_start=" << voice_start + << " voice_len=" << voice_len + << " text_tokens=" << text_tokens.size() << ")" << std::endl; +} + +void VoxtralTTSRunner::synthesize_offline( + const std::string& text, + const std::string& voice_path, + const std::string& output_path, + float temperature, + int max_new_tokens) { + auto start = std::chrono::high_resolution_clock::now(); + int dim = static_cast(dim_); + int n_aco = static_cast(n_acoustic_codebook_); + int n_cb = static_cast(n_codebooks_); + const bool capture_trace = !trace_output_path_.empty(); + json trace; + + reload_stateful_model(); + rng_.seed(seed_); + dim = static_cast(dim_); + n_aco = static_cast(n_acoustic_codebook_); + n_cb = static_cast(n_codebooks_); + + load_voice_embedding(voice_path); + const auto resolved_voice_path = resolve_voice_path(voice_path); + + std::vector token_ids; + int voice_start, voice_len; + build_prompt(text, token_ids, voice_start, voice_len); + int prompt_len = static_cast(token_ids.size()); + if (capture_trace) { + trace = { + {"mode", "runner_exported"}, + {"text", text}, + {"voice_path", resolved_voice_path.string()}, + {"seed", seed_}, + {"prompt_token_ids", token_ids}, + {"voice_start", voice_start}, + {"voice_len", voice_len}, + {"seed_step_applied", false}, + {"frames", json::array()}, + }; + } + + // Embed all tokens + auto tok_t = from_blob(token_ids.data(), {1, prompt_len}, ScalarType::Long); + auto embed_result = + model_->execute("token_embedding", std::vector{*tok_t}); + ET_CHECK_MSG(embed_result.ok(), "token_embedding failed"); + auto embeds = embed_result.get()[0].toTensor(); + + // Read prompt embeddings as fp32 (the LM output is bf16 on CUDA, fp32 on CPU) + // so the voice-embedding splice (always fp32) lands in a known dtype. 
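+  // Example with the default metadata (dim = 3072, prompt prefix = BOS +
+  // [BEGIN_AUDIO], voice_start = 2, voice_len = 147): rows 2..148 of the
+  // (prompt_len x 3072) embedding matrix built below are overwritten by the
+  // fp32 voice frames, discarding the [AUDIO] placeholder embeddings there.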
+ std::vector prompt_embeds_fp32(static_cast(prompt_len) * dim); + read_float_tensor( + embeds, prompt_embeds_fp32.data(), prompt_embeds_fp32.size()); + + // Splice voice embedding into [AUDIO] positions + if (!voice_embed_data_.empty()) { + for (int i = 0; i < voice_len; ++i) { + int pos = voice_start + i; + std::memcpy( + prompt_embeds_fp32.data() + pos * dim, + voice_embed_data_.data() + i * dim, + dim * sizeof(float)); + } + std::cerr << "Voice embedding spliced at positions " << voice_start << ".." + << (voice_start + voice_len - 1) << std::endl; + } + + // Prefill decoder with combined embeddings. For CUDA, stage the fp32 buffer + // through a bf16 buffer that survives until execute() returns. + std::vector pos_vec(prompt_len); + std::iota(pos_vec.begin(), pos_vec.end(), 0); + auto pos_t = from_blob(pos_vec.data(), {prompt_len}, ScalarType::Long); + + std::vector prompt_embeds_bf16; + TensorPtr emb_t; + if (lm_use_bf16_) { + prompt_embeds_bf16.resize(prompt_embeds_fp32.size()); + write_float_to_bf16( + prompt_embeds_bf16.data(), + prompt_embeds_fp32.data(), + prompt_embeds_fp32.size()); + emb_t = from_blob( + prompt_embeds_bf16.data(), {1, prompt_len, dim}, ScalarType::BFloat16); + } else { + emb_t = from_blob( + prompt_embeds_fp32.data(), {1, prompt_len, dim}, ScalarType::Float); + } + auto dec_result = + model_->execute("text_decoder", std::vector{*emb_t, *pos_t}); + ET_CHECK_MSG(dec_result.ok(), "text_decoder prefill failed"); + + auto hidden_out = dec_result.get()[0].toTensor(); + // Read final-position hidden state into fp32 (handles bf16 → fp32 if needed). + std::vector hidden_state(dim); + if (hidden_out.scalar_type() == ScalarType::BFloat16) { + const uint16_t* h = + hidden_out.const_data_ptr() + (prompt_len - 1) * dim; + for (int i = 0; i < dim; ++i) + hidden_state[i] = bf16_to_fp32(h[i]); + } else { + std::memcpy( + hidden_state.data(), + hidden_out.const_data_ptr() + (prompt_len - 1) * dim, + static_cast(dim) * sizeof(float)); + } + + std::vector prefill_hidden(hidden_state); + + std::vector seed_token{audio_token_id_}; + auto seed_tok_t = from_blob(seed_token.data(), {1, 1}, ScalarType::Long); + auto seed_embed_result = + model_->execute("token_embedding", std::vector{*seed_tok_t}); + ET_CHECK_MSG(seed_embed_result.ok(), "token_embedding seed step failed"); + auto seed_embed = seed_embed_result.get()[0].toTensor(); + + int64_t seed_pos_val = prompt_len; + auto seed_pos_t = from_blob(&seed_pos_val, {1}, ScalarType::Long); + // seed_embed is already in the model's native dtype (bf16 on CUDA, fp32 on + // CPU); pass it through directly. 
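+  // Position bookkeeping: the prefill above filled decoder cache slots
+  // 0..prompt_len-1, this seed [AUDIO] step occupies slot prompt_len, and the
+  // autoregressive frames below start at cur_pos = prompt_len + 1.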
+ auto seed_emb_t = from_blob( + seed_embed.mutable_data_ptr(), {1, 1, dim}, seed_embed.scalar_type()); + auto seed_decode_result = model_->execute( + "text_decoder", std::vector{*seed_emb_t, *seed_pos_t}); + ET_CHECK_MSG(seed_decode_result.ok(), "text_decoder seed step failed"); + read_float_tensor( + seed_decode_result.get()[0].toTensor(), hidden_state.data(), dim); + if (capture_trace) { + trace["prefill_hidden"] = prefill_hidden; + trace["frame0_hidden"] = hidden_state; + trace["seed_hidden"] = hidden_state; + trace["seed_position"] = prompt_len; + trace["frame0_position"] = prompt_len; + trace["seed_step_applied"] = true; + } + + // Autoregressive decode + std::vector> frame_codes; + int64_t cur_pos = prompt_len + 1; + + std::vector timesteps(n_decoding_steps_ + 1); + for (int i = 0; i <= n_decoding_steps_; ++i) { + timesteps[i] = + static_cast(i) / static_cast(n_decoding_steps_); + } + + // Per-frame staging buffers for bf16 conversion (alloc once, reuse). + std::vector h_bf16; + std::vector sem_logits_fp32; + + for (int frame = 0; frame < max_new_tokens && cur_pos < max_seq_len_; + ++frame) { + TensorPtr h_t; + if (lm_use_bf16_) { + h_bf16.resize(dim); + write_float_to_bf16(h_bf16.data(), hidden_state.data(), dim); + h_t = from_blob(h_bf16.data(), {1, dim}, ScalarType::BFloat16); + } else { + h_t = from_blob(hidden_state.data(), {1, dim}, ScalarType::Float); + } + auto sem_r = model_->execute("semantic_head", std::vector{*h_t}); + ET_CHECK_MSG(sem_r.ok(), "semantic_head failed"); + + auto sem_t = sem_r.get()[0].toTensor(); + int64_t sem_vocab = sem_t.numel(); + sem_logits_fp32.resize(sem_vocab); + read_float_tensor(sem_t, sem_logits_fp32.data(), sem_vocab); + json semantic_topk = json::array(); + if (capture_trace && frame < 3) { + semantic_topk = topk_logits(sem_logits_fp32.data(), sem_vocab); + } + int64_t semantic_code = + sample_semantic_code(sem_logits_fp32.data(), sem_vocab, temperature); + + if (semantic_code == end_audio_code_) { + if (capture_trace && frame < 3) { + trace["frames"].push_back({ + {"frame", frame}, + {"hidden_norm_before_frame", + std::sqrt(std::inner_product( + hidden_state.begin(), + hidden_state.end(), + hidden_state.begin(), + 0.0f))}, + {"semantic_code", semantic_code}, + {"semantic_topk", semantic_topk}, + {"full_codes", json::array()}, + {"end_audio", true}, + }); + } + if (capture_trace) { + trace["end_audio_at_frame"] = frame; + } + std::cerr << "END_AUDIO at frame " << frame << std::endl; + break; + } + + // Flow matching ODE (7 steps with CFG) + // Use xorshift64 + Box-Muller (matches voxtral-tts.c) so x0 noise is + // bit-identical to the C reference under the same seed. + std::vector x(n_aco); + for (auto& v : x) { + v = randn_xs(&flow_rng_state_); + } + std::vector zeros(dim, 0.0f); + + { + // Pre-stage hidden_state (and zeros for uncond) into bf16 once per frame + // so the n_decoding_steps × 2 inner velocity calls reuse the buffers. 
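+      // Each Euler step below advances x by dt = 1 / n_decoding_steps (1/7 by
+      // default) along the classifier-free-guided velocity
+      //   v = cfg_alpha * v_cond + (1 - cfg_alpha) * v_uncond,
+      // so with cfg_alpha = 1.2 the update extrapolates slightly past the
+      // conditional prediction, away from the unconditional one.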
+ std::vector step_h_bf16, step_zeros_bf16, step_x_bf16; + if (lm_use_bf16_) { + step_h_bf16.resize(dim); + write_float_to_bf16(step_h_bf16.data(), hidden_state.data(), dim); + step_zeros_bf16.assign(dim, fp32_to_bf16(0.0f)); + step_x_bf16.resize(n_aco); + } + std::vector v_cond(n_aco); + std::vector v_uncond(n_aco); + + for (int step = 0; step < n_decoding_steps_; ++step) { + float dt = timesteps[step + 1] - timesteps[step]; + int64_t tidx_val = step; + + TensorPtr xt1, hc; + if (lm_use_bf16_) { + write_float_to_bf16(step_x_bf16.data(), x.data(), n_aco); + xt1 = from_blob(step_x_bf16.data(), {1, n_aco}, ScalarType::BFloat16); + hc = from_blob(step_h_bf16.data(), {1, dim}, ScalarType::BFloat16); + } else { + xt1 = from_blob(x.data(), {1, n_aco}, ScalarType::Float); + hc = from_blob(hidden_state.data(), {1, dim}, ScalarType::Float); + } + auto ti1 = from_blob(&tidx_val, {1}, ScalarType::Long); + auto vc = model_->execute( + "predict_velocity", std::vector{*xt1, *ti1, *hc}); + ET_CHECK_MSG(vc.ok(), "predict_velocity (cond) failed"); + read_float_tensor(vc.get()[0].toTensor(), v_cond.data(), n_aco); + + TensorPtr xt2, hu; + if (lm_use_bf16_) { + // step_x_bf16 already holds x; reuse. + xt2 = from_blob(step_x_bf16.data(), {1, n_aco}, ScalarType::BFloat16); + hu = + from_blob(step_zeros_bf16.data(), {1, dim}, ScalarType::BFloat16); + } else { + xt2 = from_blob(x.data(), {1, n_aco}, ScalarType::Float); + hu = from_blob(zeros.data(), {1, dim}, ScalarType::Float); + } + auto ti2 = from_blob(&tidx_val, {1}, ScalarType::Long); + auto vu = model_->execute( + "predict_velocity", std::vector{*xt2, *ti2, *hu}); + ET_CHECK_MSG(vu.ok(), "predict_velocity (uncond) failed"); + read_float_tensor(vu.get()[0].toTensor(), v_uncond.data(), n_aco); + + for (int j = 0; j < n_aco; ++j) { + float v = cfg_alpha_ * v_cond[j] + (1.0f - cfg_alpha_) * v_uncond[j]; + x[j] += v * dt; + } + } + } + + // Quantize acoustic codes + std::vector codes(n_codebooks_); + codes[0] = semantic_code; + float x_min = std::numeric_limits::infinity(); + float x_max = -std::numeric_limits::infinity(); + for (int j = 0; j < n_aco; ++j) { + float clamped = std::clamp(x[j], -1.0f, 1.0f); + x_min = std::min(x_min, clamped); + x_max = std::max(x_max, clamped); + float scaled = + ((clamped + 1.0f) / 2.0f) * static_cast(acoustic_levels_ - 1); + codes[j + 1] = + static_cast(std::round(scaled)) + n_special_tokens_; + } + frame_codes.push_back(codes); + // Optional per-frame code dump for parity vs voxtral-tts.c reference. + if (const char* dump_path = std::getenv("VOXTRAL_DUMP_CODES")) { + std::ofstream dump( + dump_path, frame == 0 ? std::ios::trunc : std::ios::app); + if (dump) { + dump << "frame=" << frame << " sem=" << semantic_code << " codes="; + for (size_t k = 0; k < codes.size(); ++k) { + if (k) + dump << ","; + dump << codes[k]; + } + dump << "\n"; + } + } + if (capture_trace && frame == 0) { + trace["frame0_full_codes"] = codes; + } + if (capture_trace && frame < 3) { + trace["frames"].push_back({ + {"frame", frame}, + {"hidden_norm_before_frame", + std::sqrt(std::inner_product( + hidden_state.begin(), + hidden_state.end(), + hidden_state.begin(), + 0.0f))}, + {"semantic_code", semantic_code}, + {"semantic_topk", semantic_topk}, + {"full_codes", codes}, + {"x_min", x_min}, + {"x_max", x_max}, + }); + } + + // Feed the generated multi-codebook frame back through the learned + // audio-token embedding path instead of the generic [AUDIO] placeholder. 
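+    // With the default acoustic_levels = 21 and n_special_tokens = 2, the
+    // quantization above maps each clamped x in [-1, 1] to a code in [2, 22]
+    // (e.g. x = 0 -> round(0.5 * 20) + 2 = 12); codes[0] keeps the raw
+    // semantic code, giving the 37-entry frame fed back below.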
+ auto next_codes = from_blob(codes.data(), {1, n_cb, 1}, ScalarType::Long); + auto ne = model_->execute( + "audio_token_embedding", std::vector{*next_codes}); + ET_CHECK_MSG(ne.ok(), "audio_token_embedding (next) failed"); + auto next_embeds = ne.get()[0].toTensor(); + if (capture_trace && frame == 0) { + std::vector first_audio_embed(dim); + read_float_tensor(next_embeds, first_audio_embed.data(), dim); + trace["frame0_audio_embed"] = first_audio_embed; + } + + // Pass next_embeds through directly — its dtype already matches the + // text_decoder's expected input dtype (bf16 for CUDA, fp32 for CPU). + int64_t next_pos_val = cur_pos; + auto np = from_blob(&next_pos_val, {1}, ScalarType::Long); + auto next_emb = from_blob( + next_embeds.mutable_data_ptr(), {1, 1, dim}, next_embeds.scalar_type()); + auto nd = + model_->execute("text_decoder", std::vector{*next_emb, *np}); + ET_CHECK_MSG(nd.ok(), "text_decoder (next) failed"); + read_float_tensor(nd.get()[0].toTensor(), hidden_state.data(), dim); + if (capture_trace && frame == 0) { + trace["frame1_position"] = cur_pos; + trace["frame1_hidden"] = hidden_state; + } + cur_pos++; + + if ((frame + 1) % 25 == 0) { + float audio_sec = static_cast((frame + 1) * downsample_factor_) / + static_cast(sample_rate_); + std::cerr << " Frame " << (frame + 1) << " (" << audio_sec << "s audio)" + << std::endl; + } + } + + auto gen_end = std::chrono::high_resolution_clock::now(); + + if (frame_codes.empty()) { + if (capture_trace) { + trace["generated_frames"] = 0; + trace["waveform"] = waveform_stats({}); + write_trace_json(trace_output_path_, trace); + std::cerr << "Wrote trace JSON: " << trace_output_path_ << std::endl; + } + std::cerr << "No audio frames generated." << std::endl; + return; + } + + int64_t total_frames = static_cast(frame_codes.size()); + float audio_duration = static_cast(total_frames * downsample_factor_) / + static_cast(sample_rate_); + auto gen_ms = + std::chrono::duration_cast(gen_end - start) + .count(); + + std::cerr << "Generated " << total_frames << " frames (" << audio_duration + << "s audio) in " << gen_ms << "ms" << std::endl; + std::cerr << "RTF: " + << (static_cast(gen_ms) / 1000.0f) / audio_duration + << std::endl; + + std::vector decoded_samples; + decode_codes_to_wav( + frame_codes, output_path, capture_trace ? 
&decoded_samples : nullptr); + if (capture_trace) { + trace["generated_frames"] = total_frames; + trace["waveform"] = waveform_stats(decoded_samples); + write_trace_json(trace_output_path_, trace); + std::cerr << "Wrote trace JSON: " << trace_output_path_ << std::endl; + } + + auto total_end = std::chrono::high_resolution_clock::now(); + auto total_ms = + std::chrono::duration_cast(total_end - start) + .count(); + std::cerr << "Total time: " << total_ms << "ms" << std::endl; +} + +void VoxtralTTSRunner::decode_codes_to_wav( + const std::vector>& frame_codes, + const std::string& output_path, + std::vector* out_samples) { + int64_t n_frames = static_cast(frame_codes.size()); + + std::vector all_samples; + for (int64_t s = 0; s < n_frames; s += max_codec_frames_) { + int64_t e = std::min(s + max_codec_frames_, n_frames); + std::vector chunk_samples; + decode_code_window(frame_codes, s, e, chunk_samples); + all_samples.insert( + all_samples.end(), chunk_samples.begin(), chunk_samples.end()); + } + + WavWriter wav(output_path, static_cast(sample_rate_)); + if (!wav.IsOpen()) { + std::cerr << "Failed to open output: " << output_path << std::endl; + return; + } + wav.Write(all_samples.data(), all_samples.size()); + wav.Close(); + if (out_samples != nullptr) { + *out_samples = all_samples; + } + std::cerr << "Wrote " << all_samples.size() << " samples to " << output_path + << std::endl; +} + +void VoxtralTTSRunner::decode_code_window( + const std::vector>& frame_codes, + int64_t start_frame, + int64_t end_frame, + std::vector& out_samples) { + int64_t window_frames = end_frame - start_frame; + ET_CHECK_MSG(window_frames > 0, "codec decode window must be non-empty"); + int n_cb = static_cast(n_codebooks_); + int mcf = static_cast(max_codec_frames_); + ET_CHECK_MSG( + window_frames <= max_codec_frames_, + "codec decode window exceeds exported maximum"); + + auto build_code_tensor = [&](int64_t target_frames) { + std::vector code_data( + static_cast(n_cb) * static_cast(target_frames), 0); + for (int64_t f = 0; f < target_frames; ++f) { + // Pad beyond window_frames by repeating the last valid frame so the + // codec's transformer attention sees a smooth extension instead of the + // FSQ=-1.0 cliff that zero-padding produces. + int64_t src = f < window_frames ? (start_frame + f) + : (start_frame + window_frames - 1); + for (int64_t c = 0; c < n_codebooks_; ++c) { + code_data[c * target_frames + f] = frame_codes[src][c]; + } + } + return code_data; + }; + + auto copy_waveform = [&](const auto& exec_result) { + auto waveform = exec_result.get()[0].toTensor(); + int64_t valid_samples = window_frames * downsample_factor_; + int64_t total_samples = waveform.numel(); + valid_samples = std::min(valid_samples, total_samples); + out_samples.resize(static_cast(valid_samples)); + // Codec output is fp32 on portable, bf16 on CUDA — handle both. 
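+    // valid_samples trims the codec's (possibly padded) output to the real
+    // window: each frame decodes to downsample_factor samples, so e.g. a
+    // 25-frame window yields 25 * 1920 = 48000 samples, i.e. 2 s at 24 kHz.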
+ read_float_tensor(waveform, out_samples.data(), out_samples.size()); + }; + + const bool try_exact = + codec_supports_exact_frames_ || window_frames == max_codec_frames_; + if (try_exact) { + auto code_data = build_code_tensor(window_frames); + auto codes_t = from_blob( + code_data.data(), + {1, n_cb, static_cast(window_frames)}, + ScalarType::Long); + auto exact_result = + codec_->execute("forward", std::vector{*codes_t}); + if (exact_result.ok()) { + copy_waveform(exact_result); + return; + } + } + + auto padded_code_data = build_code_tensor(mcf); + auto padded_codes_t = + from_blob(padded_code_data.data(), {1, n_cb, mcf}, ScalarType::Long); + auto padded_result = + codec_->execute("forward", std::vector{*padded_codes_t}); + ET_CHECK_MSG(padded_result.ok(), "codec decode failed"); + copy_waveform(padded_result); +} + +void VoxtralTTSRunner::synthesize_streaming( + const std::string& text, + const std::string& voice_path, + const std::string& output_path, + AudioChunkCallback callback, + float temperature, + int max_new_tokens) { + auto start_time = std::chrono::high_resolution_clock::now(); + int dim = static_cast(dim_); + int n_aco = static_cast(n_acoustic_codebook_); + int n_cb = static_cast(n_codebooks_); + + reload_stateful_model(); + rng_.seed(seed_); + dim = static_cast(dim_); + n_aco = static_cast(n_acoustic_codebook_); + n_cb = static_cast(n_codebooks_); + + load_voice_embedding(voice_path); + + std::vector token_ids; + int voice_start, voice_len; + build_prompt(text, token_ids, voice_start, voice_len); + int prompt_len = static_cast(token_ids.size()); + + // Embed + splice voice + auto tok_t = from_blob(token_ids.data(), {1, prompt_len}, ScalarType::Long); + auto embed_result = + model_->execute("token_embedding", std::vector{*tok_t}); + ET_CHECK_MSG(embed_result.ok(), "token_embedding failed"); + auto embeds = embed_result.get()[0].toTensor(); + + // Read prompt embeddings into fp32 so the (always fp32) voice splice lands + // in a known dtype even when the model emits bf16. 
+ std::vector prompt_embeds_fp32(static_cast(prompt_len) * dim); + read_float_tensor( + embeds, prompt_embeds_fp32.data(), prompt_embeds_fp32.size()); + + if (!voice_embed_data_.empty()) { + for (int i = 0; i < voice_len; ++i) { + std::memcpy( + prompt_embeds_fp32.data() + (voice_start + i) * dim, + voice_embed_data_.data() + i * dim, + dim * sizeof(float)); + } + } + + // Prefill + std::vector pos_vec(prompt_len); + std::iota(pos_vec.begin(), pos_vec.end(), 0); + auto pos_t = from_blob(pos_vec.data(), {prompt_len}, ScalarType::Long); + + std::vector prompt_embeds_bf16; + TensorPtr emb_t; + if (lm_use_bf16_) { + prompt_embeds_bf16.resize(prompt_embeds_fp32.size()); + write_float_to_bf16( + prompt_embeds_bf16.data(), + prompt_embeds_fp32.data(), + prompt_embeds_fp32.size()); + emb_t = from_blob( + prompt_embeds_bf16.data(), {1, prompt_len, dim}, ScalarType::BFloat16); + } else { + emb_t = from_blob( + prompt_embeds_fp32.data(), {1, prompt_len, dim}, ScalarType::Float); + } + auto dec_result = + model_->execute("text_decoder", std::vector{*emb_t, *pos_t}); + ET_CHECK_MSG(dec_result.ok(), "text_decoder prefill failed"); + + auto hidden_out = dec_result.get()[0].toTensor(); + std::vector hidden_state(dim); + if (hidden_out.scalar_type() == ScalarType::BFloat16) { + const uint16_t* h = + hidden_out.const_data_ptr() + (prompt_len - 1) * dim; + for (int i = 0; i < dim; ++i) + hidden_state[i] = bf16_to_fp32(h[i]); + } else { + std::memcpy( + hidden_state.data(), + hidden_out.const_data_ptr() + (prompt_len - 1) * dim, + static_cast(dim) * sizeof(float)); + } + + std::vector seed_token{audio_token_id_}; + auto seed_tok_t = from_blob(seed_token.data(), {1, 1}, ScalarType::Long); + auto seed_embed_result = + model_->execute("token_embedding", std::vector{*seed_tok_t}); + ET_CHECK_MSG(seed_embed_result.ok(), "token_embedding seed step failed"); + auto seed_embed = seed_embed_result.get()[0].toTensor(); + + int64_t seed_pos_val = prompt_len; + auto seed_pos_t = from_blob(&seed_pos_val, {1}, ScalarType::Long); + auto seed_emb_t = from_blob( + seed_embed.mutable_data_ptr(), {1, 1, dim}, seed_embed.scalar_type()); + auto seed_decode_result = model_->execute( + "text_decoder", std::vector{*seed_emb_t, *seed_pos_t}); + ET_CHECK_MSG(seed_decode_result.ok(), "text_decoder seed step failed"); + read_float_tensor( + seed_decode_result.get()[0].toTensor(), hidden_state.data(), dim); + + std::vector> frame_codes; + int64_t cur_pos = prompt_len + 1; + int64_t emitted_frames = 0; + // Re-seed xorshift64 RNG for flow-matching noise (matches C reference). + flow_rng_state_ = + seed_ ? static_cast(seed_) : 0x12345678ABCDEF01ULL; + + std::vector timesteps(n_decoding_steps_ + 1); + for (int i = 0; i <= n_decoding_steps_; ++i) { + timesteps[i] = + static_cast(i) / static_cast(n_decoding_steps_); + } + + WavWriter wav(output_path, static_cast(sample_rate_)); + ET_CHECK_MSG(wav.IsOpen(), "Failed to open WAV output"); + + auto emit_ready_audio = [&]() { + int64_t total = static_cast(frame_codes.size()); + int64_t pending = total - emitted_frames; + int64_t chunk_threshold = (emitted_frames == 0) ? 
streaming_initial_chunk_ + : streaming_chunk_frames_; + if (pending < chunk_threshold) + return; + + int64_t decode_start = + std::max(int64_t(0), emitted_frames - streaming_left_context_); + int64_t crop_frames = emitted_frames - decode_start; + + std::vector chunk_samples; + decode_code_window(frame_codes, decode_start, total, chunk_samples); + + int64_t crop_samples = crop_frames * downsample_factor_; + if (crop_samples < static_cast(chunk_samples.size())) { + float* new_start = chunk_samples.data() + crop_samples; + std::size_t new_count = chunk_samples.size() - crop_samples; + wav.Write(new_start, new_count); + if (callback) + callback(new_start, new_count); + } + emitted_frames = total; + }; + + std::vector stream_h_bf16; + std::vector stream_sem_fp32; + + for (int frame = 0; frame < max_new_tokens && cur_pos < max_seq_len_; + ++frame) { + TensorPtr h_t; + if (lm_use_bf16_) { + stream_h_bf16.resize(dim); + write_float_to_bf16(stream_h_bf16.data(), hidden_state.data(), dim); + h_t = from_blob(stream_h_bf16.data(), {1, dim}, ScalarType::BFloat16); + } else { + h_t = from_blob(hidden_state.data(), {1, dim}, ScalarType::Float); + } + auto sem_r = model_->execute("semantic_head", std::vector{*h_t}); + ET_CHECK_MSG(sem_r.ok(), "semantic_head failed"); + + auto sem_t = sem_r.get()[0].toTensor(); + int64_t sem_vocab = sem_t.numel(); + stream_sem_fp32.resize(sem_vocab); + read_float_tensor(sem_t, stream_sem_fp32.data(), sem_vocab); + int64_t semantic_code = + sample_semantic_code(stream_sem_fp32.data(), sem_vocab, temperature); + + if (semantic_code == end_audio_code_) { + std::cerr << "END_AUDIO at frame " << frame << std::endl; + break; + } + + // Use xorshift64 + Box-Muller RNG matching voxtral-tts.c + std::vector x(n_aco); + for (auto& v : x) { + v = randn_xs(&flow_rng_state_); + } + std::vector zeros(dim, 0.0f); + + { + // Pre-stage hidden_state (and zeros for uncond) into bf16 once per frame + // so the n_decoding_steps × 2 inner velocity calls reuse the buffers. + std::vector step_h_bf16, step_zeros_bf16, step_x_bf16; + if (lm_use_bf16_) { + step_h_bf16.resize(dim); + write_float_to_bf16(step_h_bf16.data(), hidden_state.data(), dim); + step_zeros_bf16.assign(dim, fp32_to_bf16(0.0f)); + step_x_bf16.resize(n_aco); + } + std::vector v_cond(n_aco); + std::vector v_uncond(n_aco); + + for (int step = 0; step < n_decoding_steps_; ++step) { + float dt = timesteps[step + 1] - timesteps[step]; + int64_t tidx_val = step; + + TensorPtr xt1, hc; + if (lm_use_bf16_) { + write_float_to_bf16(step_x_bf16.data(), x.data(), n_aco); + xt1 = from_blob(step_x_bf16.data(), {1, n_aco}, ScalarType::BFloat16); + hc = from_blob(step_h_bf16.data(), {1, dim}, ScalarType::BFloat16); + } else { + xt1 = from_blob(x.data(), {1, n_aco}, ScalarType::Float); + hc = from_blob(hidden_state.data(), {1, dim}, ScalarType::Float); + } + auto ti1 = from_blob(&tidx_val, {1}, ScalarType::Long); + auto vc = model_->execute( + "predict_velocity", std::vector{*xt1, *ti1, *hc}); + ET_CHECK_MSG(vc.ok(), "predict_velocity (cond) failed"); + read_float_tensor(vc.get()[0].toTensor(), v_cond.data(), n_aco); + + TensorPtr xt2, hu; + if (lm_use_bf16_) { + // step_x_bf16 already holds x; reuse. 
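+          // This second velocity call reuses the same x but swaps the hidden
+          // state for the all-zero "null" conditioning vector: the
+          // unconditional branch of classifier-free guidance.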
+ xt2 = from_blob(step_x_bf16.data(), {1, n_aco}, ScalarType::BFloat16); + hu = + from_blob(step_zeros_bf16.data(), {1, dim}, ScalarType::BFloat16); + } else { + xt2 = from_blob(x.data(), {1, n_aco}, ScalarType::Float); + hu = from_blob(zeros.data(), {1, dim}, ScalarType::Float); + } + auto ti2 = from_blob(&tidx_val, {1}, ScalarType::Long); + auto vu = model_->execute( + "predict_velocity", std::vector{*xt2, *ti2, *hu}); + ET_CHECK_MSG(vu.ok(), "predict_velocity (uncond) failed"); + read_float_tensor(vu.get()[0].toTensor(), v_uncond.data(), n_aco); + + for (int j = 0; j < n_aco; ++j) { + float v = cfg_alpha_ * v_cond[j] + (1.0f - cfg_alpha_) * v_uncond[j]; + x[j] += v * dt; + } + } + } + + std::vector codes(n_codebooks_); + codes[0] = semantic_code; + for (int j = 0; j < n_aco; ++j) { + float clamped = std::clamp(x[j], -1.0f, 1.0f); + float scaled = + ((clamped + 1.0f) / 2.0f) * static_cast(acoustic_levels_ - 1); + codes[j + 1] = + static_cast(std::round(scaled)) + n_special_tokens_; + } + frame_codes.push_back(codes); + emit_ready_audio(); + + auto next_codes = from_blob(codes.data(), {1, n_cb, 1}, ScalarType::Long); + auto ne = model_->execute( + "audio_token_embedding", std::vector{*next_codes}); + ET_CHECK_MSG(ne.ok(), "audio_token_embedding (next) failed"); + auto next_embeds = ne.get()[0].toTensor(); + + int64_t next_pos_val = cur_pos; + auto np = from_blob(&next_pos_val, {1}, ScalarType::Long); + auto next_emb = from_blob( + next_embeds.mutable_data_ptr(), {1, 1, dim}, next_embeds.scalar_type()); + auto nd = + model_->execute("text_decoder", std::vector{*next_emb, *np}); + ET_CHECK_MSG(nd.ok(), "text_decoder (next) failed"); + read_float_tensor(nd.get()[0].toTensor(), hidden_state.data(), dim); + cur_pos++; + } + + // Flush remaining + if (emitted_frames < static_cast(frame_codes.size())) { + int64_t decode_start = + std::max(int64_t(0), emitted_frames - streaming_left_context_); + int64_t decode_end = static_cast(frame_codes.size()); + int64_t crop_frames = emitted_frames - decode_start; + + std::vector chunk_samples; + decode_code_window(frame_codes, decode_start, decode_end, chunk_samples); + + int64_t crop_samples = crop_frames * downsample_factor_; + if (crop_samples < static_cast(chunk_samples.size())) { + float* new_start = chunk_samples.data() + crop_samples; + std::size_t new_count = chunk_samples.size() - crop_samples; + wav.Write(new_start, new_count); + if (callback) + callback(new_start, new_count); + } + } + + wav.Close(); + + auto end_time = std::chrono::high_resolution_clock::now(); + auto total_ms = std::chrono::duration_cast( + end_time - start_time) + .count(); + int64_t total_frames = static_cast(frame_codes.size()); + float audio_duration = static_cast(total_frames * downsample_factor_) / + static_cast(sample_rate_); + std::cerr << "Streaming: " << total_frames << " frames (" << audio_duration + << "s) in " << total_ms << "ms, RTF=" + << (static_cast(total_ms) / 1000.0f) / audio_duration + << std::endl; +} + +} // namespace voxtral_tts diff --git a/examples/models/voxtral_tts/voxtral_tts_runner.h b/examples/models/voxtral_tts/voxtral_tts_runner.h new file mode 100644 index 00000000000..947a29e2773 --- /dev/null +++ b/examples/models/voxtral_tts/voxtral_tts_runner.h @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace voxtral_tts { + +using AudioChunkCallback = + std::function; + +class VoxtralTTSRunner { + public: + VoxtralTTSRunner( + const std::string& model_path, + const std::string& codec_path, + const std::string& tokenizer_path, + const std::string& model_data_path = "", + const std::string& codec_data_path = ""); + + void set_trace_output_path(const std::string& trace_output_path); + void set_seed(uint32_t seed); + + void synthesize_offline( + const std::string& text, + const std::string& voice_path, + const std::string& output_path, + float temperature = 0.0f, + int max_new_tokens = 2048); + + void synthesize_streaming( + const std::string& text, + const std::string& voice_path, + const std::string& output_path, + AudioChunkCallback callback = nullptr, + float temperature = 0.0f, + int max_new_tokens = 2048); + + private: + void load_metadata(); + void reload_stateful_model(); + void warmup(); + + std::vector tokenize(const std::string& text); + + void decode_codes_to_wav( + const std::vector>& frame_codes, + const std::string& output_path, + std::vector* out_samples = nullptr); + + void decode_code_window( + const std::vector>& frame_codes, + int64_t start_frame, + int64_t end_frame, + std::vector& out_samples); + + void build_prompt( + const std::string& text, + std::vector& token_ids, + int& voice_start, + int& voice_len); + + std::filesystem::path resolve_voice_path(const std::string& voice_path) const; + void load_voice_embedding(const std::string& voice_path); + + int64_t sample_semantic_code( + const float* logits, + int64_t vocab_size, + float temperature); + + std::unique_ptr<::executorch::extension::Module> model_; + std::unique_ptr<::executorch::extension::Module> codec_; + std::unique_ptr tokenizer_; + std::mt19937 rng_; // used for semantic sampling (temperature > 0) + uint64_t flow_rng_state_; // xorshift64 state for flow-matching x0 noise + // (matches voxtral-tts.c) + uint32_t seed_ = 42; + + // Voice embedding loaded from .pt or raw .bin assets. + std::vector voice_embed_data_; + + // Config from metadata + int64_t sample_rate_ = 24000; + int64_t n_decoding_steps_ = 7; + float cfg_alpha_ = 1.2f; + int64_t n_acoustic_codebook_ = 36; + int64_t acoustic_levels_ = 21; + int64_t n_special_tokens_ = 2; + int64_t vocab_size_ = 131072; + int64_t max_seq_len_ = 4096; + int64_t dim_ = 3072; + int64_t downsample_factor_ = 1920; + int64_t n_codebooks_ = 37; + int64_t end_audio_code_ = 1; + int64_t empty_audio_code_ = 0; + int64_t max_codec_frames_ = 256; + bool codec_supports_exact_frames_ = false; + int64_t audio_token_id_ = 24; + int64_t begin_audio_token_id_ = 25; + int64_t text_to_audio_token_id_ = 36; + int64_t repeat_audio_text_token_id_ = 35; + int64_t voice_embed_len_ = 147; + int64_t runtime_voice_embed_len_ = 0; + + // Streaming + bool is_streaming_ = false; + int64_t streaming_chunk_frames_ = 25; + int64_t streaming_initial_chunk_ = 5; + int64_t streaming_left_context_ = 25; + + std::string trace_output_path_; + std::filesystem::path asset_root_dir_; + std::string model_path_; + std::string model_data_path_; // .ptd alongside model.pte (CUDA backend) + std::string + codec_data_path_; // .ptd alongside codec_decoder.pte (CUDA backend) + bool lm_use_bf16_ = + false; // True when CUDA AOTI .ptd is loaded — LM methods need bf16 IO. 
+}; + +} // namespace voxtral_tts diff --git a/examples/models/voxtral_tts/wav_writer.cpp b/examples/models/voxtral_tts/wav_writer.cpp new file mode 100644 index 00000000000..1541737fd9c --- /dev/null +++ b/examples/models/voxtral_tts/wav_writer.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "wav_writer.h" + +#include + +namespace voxtral_tts { +namespace { + +void write_u16(std::ofstream& file, std::uint16_t value) { + file.write(reinterpret_cast(&value), sizeof(value)); +} + +void write_u32(std::ofstream& file, std::uint32_t value) { + file.write(reinterpret_cast(&value), sizeof(value)); +} + +} // namespace + +WavWriter::WavWriter(const std::string& path, int sample_rate, int num_channels) + : file_(path, std::ios::binary), + sample_rate_(sample_rate), + num_channels_(num_channels) { + if (file_.is_open()) { + WriteHeaderPlaceholder(); + } +} + +WavWriter::~WavWriter() { + Close(); +} + +bool WavWriter::IsOpen() const { + return file_.is_open() && !closed_; +} + +bool WavWriter::Write(const float* samples, std::size_t frame_count) { + if (!IsOpen() || samples == nullptr) { + return false; + } + + const std::size_t sample_count = + frame_count * static_cast(num_channels_); + for (std::size_t i = 0; i < sample_count; ++i) { + const float clipped = std::clamp(samples[i], -1.0f, 1.0f); + auto pcm = static_cast(clipped * 32767.0f); + file_.write(reinterpret_cast(&pcm), sizeof(pcm)); + } + + data_bytes_ += + static_cast(sample_count * sizeof(std::int16_t)); + return file_.good(); +} + +bool WavWriter::Close() { + if (!file_.is_open() || closed_) { + return true; + } + closed_ = true; + const bool ok = FinalizeHeader(); + file_.close(); + return ok; +} + +void WavWriter::WriteHeaderPlaceholder() { + const std::uint16_t bits_per_sample = 16; + const std::uint32_t byte_rate = static_cast( + sample_rate_ * num_channels_ * bits_per_sample / 8); + const std::uint16_t block_align = + static_cast(num_channels_ * bits_per_sample / 8); + + file_.write("RIFF", 4); + write_u32(file_, 0); + file_.write("WAVE", 4); + file_.write("fmt ", 4); + write_u32(file_, 16); + write_u16(file_, 1); // PCM + write_u16(file_, static_cast(num_channels_)); + write_u32(file_, static_cast(sample_rate_)); + write_u32(file_, byte_rate); + write_u16(file_, block_align); + write_u16(file_, bits_per_sample); + file_.write("data", 4); + write_u32(file_, 0); +} + +bool WavWriter::FinalizeHeader() { + if (!file_.good()) { + return false; + } + const std::uint32_t riff_size = 36 + data_bytes_; + file_.seekp(4, std::ios::beg); + write_u32(file_, riff_size); + file_.seekp(40, std::ios::beg); + write_u32(file_, data_bytes_); + file_.seekp(0, std::ios::end); + return file_.good(); +} + +} // namespace voxtral_tts diff --git a/examples/models/voxtral_tts/wav_writer.h b/examples/models/voxtral_tts/wav_writer.h new file mode 100644 index 00000000000..719661ba876 --- /dev/null +++ b/examples/models/voxtral_tts/wav_writer.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace voxtral_tts { + +class WavWriter { + public: + WavWriter(const std::string& path, int sample_rate, int num_channels = 1); + ~WavWriter(); + + WavWriter(const WavWriter&) = delete; + WavWriter& operator=(const WavWriter&) = delete; + + bool Write(const float* samples, std::size_t frame_count); + bool Close(); + bool IsOpen() const; + + private: + void WriteHeaderPlaceholder(); + bool FinalizeHeader(); + + std::ofstream file_; + int sample_rate_; + int num_channels_; + std::uint32_t data_bytes_ = 0; + bool closed_ = false; +}; + +} // namespace voxtral_tts