Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 229 additions & 6 deletions livekit-plugins/livekit-plugins-rime/livekit/plugins/rime/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import os
import json
import base64
from dataclasses import dataclass, replace
from livekit.agents.voice.io import TimedString

import aiohttp

from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)
from livekit.agents.types import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
NotGivenOr,
)
from livekit.agents.utils import is_given

from typing import Literal
from .langs import TTSLangs
from .log import logger
from .models import ArcanaVoices, TTSModels
from urllib.parse import urlencode

Check failure on line 44 in livekit-plugins/livekit-plugins-rime/livekit/plugins/rime/tts.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

livekit-plugins/livekit-plugins-rime/livekit/plugins/rime/tts.py:15:1: I001 Import block is un-sorted or un-formatted

# arcana can take as long as 80% of the total duration of the audio it's synthesizing.
ARCANA_MODEL_TIMEOUT = 60 * 4
MISTV2_MODEL_TIMEOUT = 30
RIME_BASE_URL = "https://users.rime.ai/v1/rime-tts"
RIME_BASE_URL = "https://users.rime.ai/v1/rime-tts" # http
RIME_WS_JSON_URL = "wss://users.rime.ai/ws2" # ws_json
RIME_WS_TEXT_URL = "wss://users.rime.ai/ws" # ws_text


@dataclass
class _TTSOptions:
model: TTSModels | str
speaker: str
segment: NotGivenOr[str] = NOT_GIVEN
arcana_options: _ArcanaOptions | None = None
mistv2_options: _Mistv2Options | None = None

Expand All @@ -71,6 +78,10 @@
reduce_latency: NotGivenOr[bool] = NOT_GIVEN
pause_between_brackets: NotGivenOr[bool] = NOT_GIVEN
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN
# websocket specific
no_text_normalization: NotGivenOr[bool] = NOT_GIVEN
inline_speed_alpha: NotGivenOr[str] = NOT_GIVEN
save_oovs: NotGivenOr[bool] = NOT_GIVEN


NUM_CHANNELS = 1
Expand All @@ -81,9 +92,13 @@
self,
*,
base_url: str = RIME_BASE_URL,
ws_text_url: str = RIME_WS_TEXT_URL,
ws_json_url: str = RIME_WS_JSON_URL,
protocol: Literal["http", "ws_json", "ws_text"] = "http",
model: TTSModels | str = "arcana",
speaker: NotGivenOr[ArcanaVoices | str] = NOT_GIVEN,
lang: TTSLangs | str = "eng",
segment: NotGivenOr[str] = NOT_GIVEN,
# Arcana options
repetition_penalty: NotGivenOr[float] = NOT_GIVEN,
temperature: NotGivenOr[float] = NOT_GIVEN,
Expand All @@ -95,12 +110,15 @@
reduce_latency: NotGivenOr[bool] = NOT_GIVEN,
pause_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
no_text_normalization: NotGivenOr[bool] = NOT_GIVEN,
inline_speed_alpha: NotGivenOr[str] = NOT_GIVEN,
save_oovs: NotGivenOr[bool] = NOT_GIVEN,
api_key: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
streaming=protocol != "http",
),
sample_rate=sample_rate,
num_channels=NUM_CHANNELS,
Expand All @@ -120,6 +138,7 @@
self._opts = _TTSOptions(
model=model,
speaker=speaker,
segment=segment,
)
if model == "arcana":
self._opts.arcana_options = _ArcanaOptions(
Expand All @@ -138,9 +157,15 @@
reduce_latency=reduce_latency,
pause_between_brackets=pause_between_brackets,
phonemize_between_brackets=phonemize_between_brackets,
no_text_normalization=no_text_normalization,
inline_speed_alpha=inline_speed_alpha,
save_oovs=save_oovs,
)
self._session = http_session
self._base_url = base_url
self._ws_text_url = ws_text_url
self._ws_json_url = ws_json_url
self._protocol = protocol

self._total_timeout = ARCANA_MODEL_TIMEOUT if model == "arcana" else MISTV2_MODEL_TIMEOUT

Expand All @@ -159,10 +184,18 @@
return self._session

def synthesize(
self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
self,
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> ChunkedStream:
return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)

def stream(
self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> JSONSynthesizeStream:
return JSONSynthesizeStream(tts=self, conn_options=conn_options)

def update_options(
self,
*,
Expand All @@ -181,12 +214,23 @@
pause_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
base_url: NotGivenOr[str] = NOT_GIVEN,
ws_text_url: NotGivenOr[str] = NOT_GIVEN,
ws_json_url: NotGivenOr[str] = NOT_GIVEN,
segment: NotGivenOr[str] = NOT_GIVEN,
no_text_normalization: NotGivenOr[bool] = NOT_GIVEN,
save_oovs: NotGivenOr[bool] = NOT_GIVEN,
inline_speed_alpha: NotGivenOr[str] = NOT_GIVEN,
) -> None:
if is_given(base_url):
self._base_url = base_url
if is_given(ws_text_url):
self._ws_text_url = ws_text_url
if is_given(ws_json_url):
self._ws_json_url = ws_json_url
if is_given(segment):
self._opts.segment = segment
if is_given(model):
self._opts.model = model

if model == "arcana" and self._opts.arcana_options is None:
self._opts.arcana_options = _ArcanaOptions()
elif model == "mistv2" and self._opts.mistv2_options is None:
Expand Down Expand Up @@ -222,6 +266,12 @@
self._opts.mistv2_options.pause_between_brackets = pause_between_brackets
if is_given(phonemize_between_brackets):
self._opts.mistv2_options.phonemize_between_brackets = phonemize_between_brackets
if is_given(no_text_normalization):
self._opts.mistv2_options.no_text_normalization = no_text_normalization
if is_given(inline_speed_alpha):
self._opts.mistv2_options.inline_speed_alpha = inline_speed_alpha
if is_given(save_oovs):
self._opts.mistv2_options.save_oovs = save_oovs


class ChunkedStream(tts.ChunkedStream):
Expand Down Expand Up @@ -280,7 +330,8 @@
},
json=payload,
timeout=aiohttp.ClientTimeout(
total=self._tts._total_timeout, sock_connect=self._conn_options.timeout
total=self._tts._total_timeout,
sock_connect=self._conn_options.timeout,
),
) as resp:
resp.raise_for_status()
Expand Down Expand Up @@ -308,3 +359,175 @@
) from None
except Exception as e:
raise APIConnectionError() from e


class JSONSynthesizeStream(tts.SynthesizeStream):
def __init__(self, tts: TTS, conn_options: APIConnectOptions) -> None:
super().__init__(tts=tts, conn_options=conn_options)
self._tts: TTS = tts
self._opts = replace(tts._opts)
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._input_complete = asyncio.Event()
self._model_timeout = tts._total_timeout
if self._opts.model == "arcana":
raise ValueError(
"The Arcana model is not supported for JSON WebSocket streaming. Please switch to the 'mistv2' model."

)

def _build_ws_url(self) -> str:
params = {
"modelId": self._opts.model,
"speaker": self._opts.speaker,
"audioFormat": "pcm",
}
if is_given(self._opts.segment):
params["segment"] = self._opts.segment
elif self._opts.model == "mistv2":
mistv2_opts = self._opts.mistv2_options
assert mistv2_opts is not None
if is_given(mistv2_opts.lang):
params["lang"] = mistv2_opts.lang
if is_given(mistv2_opts.sample_rate):
params["samplingRate"] = mistv2_opts.sample_rate
if is_given(mistv2_opts.speed_alpha):
params["speedAlpha"] = mistv2_opts.speed_alpha
if is_given(mistv2_opts.reduce_latency):
params["reduceLatency"] = mistv2_opts.reduce_latency
if is_given(mistv2_opts.pause_between_brackets):
params["pauseBetweenBrackets"] = mistv2_opts.pause_between_brackets
if is_given(mistv2_opts.phonemize_between_brackets):
params["phonemizeBetweenBrackets"] = mistv2_opts.phonemize_between_brackets
if is_given(mistv2_opts.no_text_normalization):
params["noTextNormalization"] = mistv2_opts.no_text_normalization
if is_given(mistv2_opts.inline_speed_alpha):
params["inlineSpeedAlpha"] = mistv2_opts.inline_speed_alpha
if is_given(mistv2_opts.save_oovs):
params["saveOovs"] = mistv2_opts.save_oovs
return f"{self._tts._ws_json_url}?{urlencode(params)}"

async def clear_buffer(self) -> None:
"""Send clear operation to discard buffered text"""
if self._ws and not self._ws.closed:
await self._ws.send_str(json.dumps({"operation": "clear"}))

async def aclose(self) -> None:
"""Close the stream and send EOS if needed"""
if self._ws and not self._ws.closed:
await self._ws.send_str(json.dumps({"operation": "eos"}))
await super().aclose()

async def _send_task(self, ws: aiohttp.ClientWebSocketResponse) -> None:
try:
async for input_data in self._input_ch:
if isinstance(input_data, str):
await ws.send_str(json.dumps({"text": input_data}))
elif isinstance(input_data, self._FlushSentinel):
await ws.send_str(json.dumps({"operation": "flush"}))
except Exception as e:
logger.error("Rime WebSocket send task failed: %s", e)
raise APIConnectionError(f"Send task failed: {e}") from e
finally:
self._input_complete.set()

async def _recv_task(
self, ws: aiohttp.ClientWebSocketResponse, output_emitter: tts.AudioEmitter
) -> None:
segment_started = False

while True:
try:
# Use timeout to detect completion - 2 seconds after input is complete
if self._input_complete.is_set():
timeout = self._model_timeout
else:
timeout = self._conn_options.timeout

msg = await asyncio.wait_for(ws.receive(), timeout=timeout)
except asyncio.TimeoutError:
# If input is complete and we get timeout, synthesis is done
if self._input_complete.is_set():
if segment_started:
output_emitter.end_segment()
output_emitter.end_input()
break
continue

if msg.type in (
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSING,
):
if segment_started:
output_emitter.end_segment()
output_emitter.end_input()
return

if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("Unexpected Rime message type: %s", msg.type)
continue

try:
data = json.loads(msg.data)
except json.JSONDecodeError:
logger.warning("Invalid JSON from Rime: %s", msg.data)
continue

if data.get("type") == "chunk":
if not segment_started:
segment_id = data.get("contextId") or utils.shortuuid()
output_emitter.start_segment(segment_id=segment_id)
segment_started = True
audio_data = base64.b64decode(data["data"])
output_emitter.push(audio_data)
elif data.get("type") == "timestamps":
word_timestamps = data.get("word_timestamps", {})
words = word_timestamps.get("words", [])
starts = word_timestamps.get("start", [])
ends = word_timestamps.get("end", [])

timed_words = []
for word, start, end in zip(words, starts, ends):
timed_words.append(TimedString(text=word, start_time=start, end_time=end))
if timed_words:
output_emitter.push_timed_transcript(timed_words)
elif data.get("type") == "error":
logger.error(f"Rime error: {data.get('message')}")
if segment_started:
output_emitter.end_segment()
output_emitter.end_input()
raise APIStatusError(f"Rime error: {data.get('message')}")

async def _run(self, output_emitter: tts.AudioEmitter) -> None:
request_id = utils.shortuuid()
output_emitter.initialize(
request_id=request_id,
sample_rate=self._tts.sample_rate,
num_channels=NUM_CHANNELS,
mime_type="audio/pcm",
stream=True,
)
ws_url = self._build_ws_url()

send_task = None
recv_task = None

try:
async with self._tts._ensure_session().ws_connect(
ws_url,
headers={"Authorization": f"Bearer {self._tts._api_key}"},
timeout=aiohttp.ClientTimeout(total=self._tts._total_timeout),
) as ws:
self._ws = ws
send_task = asyncio.create_task(self._send_task(ws))
recv_task = asyncio.create_task(self._recv_task(ws, output_emitter))
await asyncio.gather(send_task, recv_task)
except asyncio.TimeoutError:
raise APITimeoutError() from None
except aiohttp.ClientError as e:
raise APIConnectionError() from e
except Exception as e:
raise APIConnectionError() from e
finally:
if send_task is not None or recv_task is not None:
await utils.aio.gracefully_cancel(send_task, recv_task)
Loading