Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions livekit-plugins/livekit-plugins-xai/livekit/plugins/xai/tts.py
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import weakref
from dataclasses import dataclass, replace
from typing import Literal
from urllib.parse import urlencode

import aiohttp
Expand Down Expand Up @@ -51,6 +52,9 @@ class _TTSOptions:
voice: GrokVoices | str
language: TTSLanguages | str
tokenizer: tokenize.WordTokenizer
optimize_streaming_latency: NotGivenOr[Literal[0, 1, 2]]
speed: NotGivenOr[float]
text_normalization: NotGivenOr[bool]


class TTS(tts.TTS):
Expand All @@ -60,6 +64,9 @@ def __init__(
api_key: NotGivenOr[str] = NOT_GIVEN,
voice: GrokVoices | str = DEFAULT_VOICE,
language: TTSLanguages | str = "auto",
optimize_streaming_latency: NotGivenOr[Literal[0, 1, 2]] = NOT_GIVEN,
speed: NotGivenOr[float] = NOT_GIVEN,
text_normalization: NotGivenOr[bool] = NOT_GIVEN,
tokenizer: tokenize.WordTokenizer | None = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
Expand All @@ -71,6 +78,9 @@ def __init__(
Args:
voice (str, optional): The voice ID for the desired voice. Defaults to "ara".
language (TTSLanguages | str, optional): Language code for synthesis (e.g., "en", "fr", "ja"). Defaults to "auto".
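optimize_streaming_latency (0 | 1 | 2, optional): Latency optimization level for the xAI TTS websocket. When not given, the parameter is omitted from the websocket URL.
speed (float, optional): Speaking-rate multiplier for the generated audio. When not given, the parameter is omitted from the websocket URL.
text_normalization (bool, optional): Whether to normalize text before synthesis; sent as "true" or "false". When not given, the parameter is omitted from the websocket URL.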
api_key (str | None, optional): The xAI API key. If not provided, it will be read from the xAI environment variable.
http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
""" # noqa: E501
Expand All @@ -93,6 +103,9 @@ def __init__(
voice=voice,
language=language,
tokenizer=tokenizer,
optimize_streaming_latency=optimize_streaming_latency,
speed=speed,
text_normalization=text_normalization,
)

self._session = http_session
Expand All @@ -106,13 +119,22 @@ def model(self) -> str:
def provider(self) -> str:
return "xAI"

async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
params = {
"voice": self._opts.voice,
"language": self._opts.language,
async def _connect_ws(
self, timeout: float, opts: _TTSOptions
) -> aiohttp.ClientWebSocketResponse:
params: dict[str, str | int | float] = {
"voice": opts.voice,
"language": opts.language,
"codec": "pcm",
"sample_rate": SAMPLE_RATE,
}
if is_given(opts.optimize_streaming_latency):
params["optimize_streaming_latency"] = opts.optimize_streaming_latency
if is_given(opts.speed):
params["speed"] = opts.speed
if is_given(opts.text_normalization):
params["text_normalization"] = str(opts.text_normalization).lower()

url = f"{XAI_WEBSOCKET_URL}?{urlencode(params)}"
try:
ws = await asyncio.wait_for(
Expand Down Expand Up @@ -143,16 +165,28 @@ def update_options(
*,
voice: str | None = None,
language: TTSLanguages | str | None = None,
optimize_streaming_latency: NotGivenOr[Literal[0, 1, 2]] = NOT_GIVEN,
speed: NotGivenOr[float] = NOT_GIVEN,
text_normalization: NotGivenOr[bool] = NOT_GIVEN,
) -> None:
"""
Update the Text-to-Speech (TTS) configuration options.

Args:
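voice (str, optional): The voice ID for the desired voice. The current voice is kept when None or an empty string is passed.
language (TTSLanguages | str, optional): Language code for synthesis (e.g., "en", "fr", "ja"). The current language is kept when None or an empty string is passed.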
optimize_streaming_latency (0 | 1 | 2, optional): Latency optimization level for the xAI TTS websocket. Left unchanged when not given.
speed (float, optional): Speaking-rate multiplier for the generated audio. Left unchanged when not given.
text_normalization (bool, optional): Whether to normalize text before synthesis. Left unchanged when not given.
""" # noqa: E501
self._opts.voice = voice or self._opts.voice
self._opts.language = language or self._opts.language
if is_given(optimize_streaming_latency):
self._opts.optimize_streaming_latency = optimize_streaming_latency
if is_given(speed):
self._opts.speed = speed
if is_given(text_normalization):
self._opts.text_normalization = text_normalization

def synthesize(
self,
Expand Down Expand Up @@ -290,7 +324,7 @@ async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
else:
logger.warning("Unexpected xAI message %s", data)

ws = await self._tts._connect_ws(self._conn_options.timeout)
ws = await self._tts._connect_ws(self._conn_options.timeout, self._opts)
tasks = [
asyncio.create_task(_send_task(ws)),
asyncio.create_task(_recv_task(ws)),
Expand Down
Loading