-
Notifications
You must be signed in to change notification settings - Fork 2.8k
[inworld] use audio/pcm to leverage the fast AudioBytesStream path #4803
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |||||||||||||||||||
| import json | ||||||||||||||||||||
| import os | ||||||||||||||||||||
| import time | ||||||||||||||||||||
| import uuid | ||||||||||||||||||||
| import weakref | ||||||||||||||||||||
| from collections.abc import Callable | ||||||||||||||||||||
| from dataclasses import dataclass, field, replace | ||||||||||||||||||||
|
|
@@ -45,11 +46,14 @@ | |||||||||||||||||||
| from livekit.agents.voice.io import TimedString | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from .log import logger | ||||||||||||||||||||
| from .version import __version__ | ||||||||||||||||||||
|
|
||||||||||||||||||||
| USER_AGENT = f"livekit-agents-py/{__version__}" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| DEFAULT_BIT_RATE = 64000 | ||||||||||||||||||||
| DEFAULT_ENCODING = "OGG_OPUS" | ||||||||||||||||||||
| DEFAULT_MODEL = "inworld-tts-1" | ||||||||||||||||||||
| DEFAULT_SAMPLE_RATE = 48000 | ||||||||||||||||||||
| DEFAULT_ENCODING = "LINEAR16" | ||||||||||||||||||||
| DEFAULT_MODEL = "inworld-tts-1.5-max" | ||||||||||||||||||||
| DEFAULT_SAMPLE_RATE = 24000 | ||||||||||||||||||||
| DEFAULT_URL = "https://api.inworld.ai/" | ||||||||||||||||||||
| DEFAULT_WS_URL = "wss://api.inworld.ai/" | ||||||||||||||||||||
| DEFAULT_VOICE = "Ashley" | ||||||||||||||||||||
|
|
@@ -80,7 +84,13 @@ class _TTSOptions: | |||||||||||||||||||
|
|
||||||||||||||||||||
| @property | ||||||||||||||||||||
| def mime_type(self) -> str: | ||||||||||||||||||||
| if self.encoding == "MP3": | ||||||||||||||||||||
| if self.encoding == "LINEAR16": | ||||||||||||||||||||
| # Use audio/pcm so the emitter takes the fast synchronous | ||||||||||||||||||||
| # AudioByteStream path instead of the async AudioStreamDecoder. | ||||||||||||||||||||
| # WAV headers from the server are stripped before pushing to the | ||||||||||||||||||||
| # emitter (see _strip_wav_header). | ||||||||||||||||||||
| return "audio/pcm" | ||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if it's wav file sent back.. you should just pass back this would be preferred rather than having multiple wav handling in the code base
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe that's the existing behavior. When encoding is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tinalenguyen would you be able to advise here? I tested with my own websocket benchmark script at https://github.com/inworld-ai/inworld-api-examples/tree/ian/livekit-integrations/integrations/livekit/python/benchmarks with 100+ iterations summarized by AI, the reason is that
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
could you share the benchmark scripts that you had ran? if AudioStreamDecoder is slow, we should optimize that instead. I still maintain that we should not be duplicating decoding logic within plugin code
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi David, can you access this script https://github.com/inworld-ai/inworld-api-examples/tree/ian/livekit-integrations/integrations/livekit/python/benchmarks ? I have the instructions to run in the README, you should be able to checkout this feature branch in the submodule to see the difference. We might also consider a new Audio encoding format
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @davidzhao @tinalenguyen, curious if you were able to reproduce and if there is any plan for optimization of AudioStreamDecoder? I do think that would be ideal to speed up every other encoding formats besides WAV |
||||||||||||||||||||
| elif self.encoding == "MP3": | ||||||||||||||||||||
| return "audio/mpeg" | ||||||||||||||||||||
| elif self.encoding == "OGG_OPUS": | ||||||||||||||||||||
| return "audio/ogg" | ||||||||||||||||||||
|
|
@@ -230,10 +240,19 @@ async def connect(self) -> None: | |||||||||||||||||||
| return | ||||||||||||||||||||
|
|
||||||||||||||||||||
| url = urljoin(self._ws_url, "/tts/v1/voice:streamBidirectional") | ||||||||||||||||||||
| request_id = str(uuid.uuid4()) | ||||||||||||||||||||
| self._ws = await self._session.ws_connect( | ||||||||||||||||||||
| url, headers={"Authorization": self._authorization} | ||||||||||||||||||||
| url, | ||||||||||||||||||||
| headers={ | ||||||||||||||||||||
| "Authorization": self._authorization, | ||||||||||||||||||||
| "X-User-Agent": USER_AGENT, | ||||||||||||||||||||
| "X-Request-Id": request_id, | ||||||||||||||||||||
| }, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| logger.debug( | ||||||||||||||||||||
| "Established Inworld TTS WebSocket connection (shared)", | ||||||||||||||||||||
| extra={"request_id": request_id}, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| logger.debug("Established Inworld TTS WebSocket connection (shared)") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| self._send_task = asyncio.create_task(self._send_loop()) | ||||||||||||||||||||
| self._recv_task = asyncio.create_task(self._recv_loop()) | ||||||||||||||||||||
|
|
@@ -481,7 +500,7 @@ async def _recv_loop(self) -> None: | |||||||||||||||||||
| ctx.emitter.push_timed_transcript(ts) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if audio_content := audio_chunk.get("audioContent"): | ||||||||||||||||||||
| ctx.emitter.push(base64.b64decode(audio_content)) | ||||||||||||||||||||
| ctx.emitter.push(_strip_wav_header(base64.b64decode(audio_content))) | ||||||||||||||||||||
| continue | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if "flushCompleted" in result: | ||||||||||||||||||||
|
|
@@ -997,7 +1016,11 @@ async def list_voices(self, language: str | None = None) -> list[dict[str, Any]] | |||||||||||||||||||
|
|
||||||||||||||||||||
| async with self._ensure_session().get( | ||||||||||||||||||||
| url, | ||||||||||||||||||||
| headers={"Authorization": self._authorization}, | ||||||||||||||||||||
| headers={ | ||||||||||||||||||||
| "Authorization": self._authorization, | ||||||||||||||||||||
| "X-User-Agent": USER_AGENT, | ||||||||||||||||||||
| "X-Request-Id": str(uuid.uuid4()), | ||||||||||||||||||||
| }, | ||||||||||||||||||||
| params=params, | ||||||||||||||||||||
| ) as resp: | ||||||||||||||||||||
| if not resp.ok: | ||||||||||||||||||||
|
|
@@ -1040,10 +1063,13 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None: | |||||||||||||||||||
| if utils.is_given(self._opts.text_normalization): | ||||||||||||||||||||
| body_params["applyTextNormalization"] = self._opts.text_normalization | ||||||||||||||||||||
|
|
||||||||||||||||||||
| x_request_id = str(uuid.uuid4()) | ||||||||||||||||||||
| async with self._tts._ensure_session().post( | ||||||||||||||||||||
| urljoin(self._tts._base_url, "/tts/v1/voice:stream"), | ||||||||||||||||||||
| headers={ | ||||||||||||||||||||
| "Authorization": self._tts._authorization, | ||||||||||||||||||||
| "X-User-Agent": USER_AGENT, | ||||||||||||||||||||
| "X-Request-Id": x_request_id, | ||||||||||||||||||||
| }, | ||||||||||||||||||||
| json=body_params, | ||||||||||||||||||||
| timeout=aiohttp.ClientTimeout(sock_connect=self._conn_options.timeout), | ||||||||||||||||||||
|
|
@@ -1079,20 +1105,20 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None: | |||||||||||||||||||
| output_emitter.push_timed_transcript(timed_strings) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if audio_content := result.get("audioContent"): | ||||||||||||||||||||
| output_emitter.push(base64.b64decode(audio_content)) | ||||||||||||||||||||
| output_emitter.push(_strip_wav_header(base64.b64decode(audio_content))) | ||||||||||||||||||||
| output_emitter.flush() | ||||||||||||||||||||
| elif error := data.get("error"): | ||||||||||||||||||||
| raise APIStatusError( | ||||||||||||||||||||
| message=error.get("message"), | ||||||||||||||||||||
| status_code=error.get("code"), | ||||||||||||||||||||
| request_id=request_id, | ||||||||||||||||||||
| request_id=x_request_id, | ||||||||||||||||||||
| body=None, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| except asyncio.TimeoutError: | ||||||||||||||||||||
| raise APITimeoutError() from None | ||||||||||||||||||||
| except aiohttp.ClientResponseError as e: | ||||||||||||||||||||
| raise APIStatusError( | ||||||||||||||||||||
| message=e.message, status_code=e.status, request_id=None, body=None | ||||||||||||||||||||
| message=e.message, status_code=e.status, request_id=x_request_id, body=None | ||||||||||||||||||||
| ) from None | ||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||
| raise APIConnectionError() from e | ||||||||||||||||||||
|
|
@@ -1166,6 +1192,18 @@ async def _send_task() -> None: | |||||||||||||||||||
| output_emitter.end_input() | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def _strip_wav_header(data: bytes) -> bytes: | ||||||||||||||||||||
| """Strip WAV header from audio data, returning raw PCM. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Inworld returns LINEAR16 audio wrapped in WAV containers. The emitter's | ||||||||||||||||||||
| AudioByteStream fast-path requires raw PCM, so we strip the 44-byte | ||||||||||||||||||||
| standard WAV header (RIFF + fmt + data chunk headers) when present. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| if len(data) > 44 and data[:4] == b"RIFF": | ||||||||||||||||||||
| return data[44:] | ||||||||||||||||||||
| return data | ||||||||||||||||||||
ianbbqzy marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def _parse_timestamp_info( | ||||||||||||||||||||
| timestamp_info: dict[str, Any], cumulative_time: float = 0.0 | ||||||||||||||||||||
| ) -> list[TimedString]: | ||||||||||||||||||||
|
|
||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.