Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class STTOptions:
sample_rate: int
min_confidence_threshold: float
keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN
speech_start_timeout: NotGivenOr[float] = NOT_GIVEN
speech_end_timeout: NotGivenOr[float] = NOT_GIVEN

@property
def version(self) -> int:
Expand Down Expand Up @@ -133,6 +135,8 @@ def __init__(
credentials_info: NotGivenOr[dict] = NOT_GIVEN,
credentials_file: NotGivenOr[str] = NOT_GIVEN,
keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
speech_start_timeout: NotGivenOr[float] = NOT_GIVEN,
speech_end_timeout: NotGivenOr[float] = NOT_GIVEN,
use_streaming: NotGivenOr[bool] = NOT_GIVEN,
):
"""
Expand All @@ -159,6 +163,8 @@ def __init__(
credentials_info(dict): the credentials info to use for recognition (default: None)
credentials_file(str): the credentials file to use for recognition (default: None)
keywords(List[tuple[str, float]]): list of keywords to recognize (default: None)
speech_start_timeout(float): maximum seconds to wait for speech to begin before timeout (default: None)
speech_end_timeout(float): seconds of silence before marking utterance as complete (default: None)
use_streaming(bool): whether to use streaming for recognition (default: True)
"""
if not is_given(use_streaming):
Expand Down Expand Up @@ -201,6 +207,8 @@ def __init__(
sample_rate=sample_rate,
min_confidence_threshold=min_confidence_threshold,
keywords=keywords,
speech_start_timeout=speech_start_timeout,
speech_end_timeout=speech_end_timeout,
)
self._streams = weakref.WeakSet[SpeechStream]()
self._pool = utils.ConnectionPool[SpeechAsyncClientV2 | SpeechAsyncClientV1](
Expand Down Expand Up @@ -379,6 +387,8 @@ def update_options(
model: NotGivenOr[SpeechModels] = NOT_GIVEN,
location: NotGivenOr[str] = NOT_GIVEN,
keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
speech_start_timeout: NotGivenOr[float] = NOT_GIVEN,
speech_end_timeout: NotGivenOr[float] = NOT_GIVEN,
) -> None:
if is_given(languages):
if isinstance(languages, str):
Expand All @@ -404,6 +414,10 @@ def update_options(
self._pool.invalidate()
if is_given(keywords):
self._config.keywords = keywords
if is_given(speech_start_timeout):
self._config.speech_start_timeout = speech_start_timeout
if is_given(speech_end_timeout):
self._config.speech_end_timeout = speech_end_timeout

for stream in self._streams:
stream.update_options(
Expand All @@ -414,6 +428,8 @@ def update_options(
spoken_punctuation=spoken_punctuation,
model=model,
keywords=keywords,
speech_start_timeout=speech_start_timeout,
speech_end_timeout=speech_end_timeout,
)

async def aclose(self) -> None:
Expand Down Expand Up @@ -450,6 +466,8 @@ def update_options(
model: NotGivenOr[SpeechModels] = NOT_GIVEN,
min_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
speech_start_timeout: NotGivenOr[float] = NOT_GIVEN,
speech_end_timeout: NotGivenOr[float] = NOT_GIVEN,
) -> None:
if is_given(languages):
if isinstance(languages, str):
Expand All @@ -472,13 +490,36 @@ def update_options(
self._config.min_confidence_threshold = min_confidence_threshold
if is_given(keywords):
self._config.keywords = keywords
if is_given(speech_start_timeout):
self._config.speech_start_timeout = speech_start_timeout
if is_given(speech_end_timeout):
self._config.speech_end_timeout = speech_end_timeout

self._reconnect_event.set()

def _build_streaming_config(
self,
) -> cloud_speech_v2.StreamingRecognitionConfig | cloud_speech_v1.StreamingRecognitionConfig:
if self._config.version == 2:
# Build voice activity timeout if either timeout is specified
voice_activity_timeout = None
if is_given(self._config.speech_start_timeout) or is_given(
self._config.speech_end_timeout
):
voice_activity_timeout = (
cloud_speech_v2.StreamingRecognitionFeatures.VoiceActivityTimeout()
)
if is_given(self._config.speech_start_timeout):
voice_activity_timeout.speech_start_timeout = Duration(
seconds=int(self._config.speech_start_timeout),
nanos=int((self._config.speech_start_timeout % 1) * 1e9),
)
if is_given(self._config.speech_end_timeout):
voice_activity_timeout.speech_end_timeout = Duration(
seconds=int(self._config.speech_end_timeout),
nanos=int((self._config.speech_end_timeout % 1) * 1e9),
)

return cloud_speech_v2.StreamingRecognitionConfig(
config=cloud_speech_v2.RecognitionConfig(
explicit_decoding_config=cloud_speech_v2.ExplicitDecodingConfig(
Expand All @@ -499,6 +540,7 @@ def _build_streaming_config(
streaming_features=cloud_speech_v2.StreamingRecognitionFeatures(
interim_results=self._config.interim_results,
enable_voice_activity_events=self._config.enable_voice_activity_events,
voice_activity_timeout=voice_activity_timeout,
),
)

Expand Down
98 changes: 98 additions & 0 deletions tests/test_plugin_google_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,101 @@ async def test_recognize_response_to_speech_event_words():
],
)
]


async def test_voice_activity_timeout_defaults():
    """Both voice-activity timeouts stay NOT_GIVEN when STT is built with no arguments."""
    from livekit.agents.types import NOT_GIVEN
    from livekit.plugins.google import STT

    # Inspect the options object that a default-constructed STT stores.
    options = STT()._config
    assert options.speech_start_timeout is NOT_GIVEN
    assert options.speech_end_timeout is NOT_GIVEN


async def test_voice_activity_timeout_set():
    """Timeouts passed to the STT constructor are stored on its config unchanged."""
    from livekit.plugins.google import STT

    recognizer = STT(speech_start_timeout=10.0, speech_end_timeout=2.5)

    options = recognizer._config
    assert options.speech_start_timeout == 10.0
    assert options.speech_end_timeout == 2.5


async def test_voice_activity_timeout_fractional_seconds():
    """Non-integer timeout values are stored on the config exactly as given.

    Only the stored option values are checked here; nothing beyond
    ``_config`` is inspected.
    """
    from livekit.plugins.google import STT

    recognizer = STT(speech_start_timeout=5.5, speech_end_timeout=1.25)

    options = recognizer._config
    assert options.speech_start_timeout == 5.5
    assert options.speech_end_timeout == 1.25


async def test_voice_activity_timeout_speech_start_only():
    """Supplying only speech_start_timeout leaves speech_end_timeout as NOT_GIVEN."""
    from livekit.agents.types import NOT_GIVEN
    from livekit.plugins.google import STT

    options = STT(speech_start_timeout=15.0)._config
    assert options.speech_start_timeout == 15.0
    assert options.speech_end_timeout is NOT_GIVEN


async def test_voice_activity_timeout_speech_end_only():
    """Supplying only speech_end_timeout leaves speech_start_timeout as NOT_GIVEN."""
    from livekit.agents.types import NOT_GIVEN
    from livekit.plugins.google import STT

    options = STT(speech_end_timeout=3.0)._config
    assert options.speech_end_timeout == 3.0
    assert options.speech_start_timeout is NOT_GIVEN


async def test_voice_activity_timeout_v2_model():
    """The config reports version 2 for "chirp_3" and version 1 for "default"."""
    from livekit.plugins.google import STT

    # (model name, expected config version), checked in the same order as before.
    expectations = (("chirp_3", 2), ("default", 1))
    for model_name, expected_version in expectations:
        recognizer = STT(model=model_name)
        assert recognizer._config.version == expected_version


async def test_voice_activity_timeout_update():
    """update_options replaces both timeouts that were set at construction."""
    from livekit.plugins.google import STT

    recognizer = STT(speech_start_timeout=10.0, speech_end_timeout=2.0)

    # Overwrite both values in a single call.
    recognizer.update_options(speech_start_timeout=15.0, speech_end_timeout=3.0)

    options = recognizer._config
    assert options.speech_start_timeout == 15.0
    assert options.speech_end_timeout == 3.0


async def test_voice_activity_timeout_partial_update():
    """Updating one timeout leaves the other timeout at its previous value."""
    from livekit.plugins.google import STT

    recognizer = STT(speech_start_timeout=10.0, speech_end_timeout=2.0)

    # First change only the start timeout; the end timeout must be untouched.
    recognizer.update_options(speech_start_timeout=20.0)
    assert recognizer._config.speech_start_timeout == 20.0
    assert recognizer._config.speech_end_timeout == 2.0

    # Then change only the end timeout; the start timeout keeps its updated value.
    recognizer.update_options(speech_end_timeout=5.0)
    assert recognizer._config.speech_start_timeout == 20.0
    assert recognizer._config.speech_end_timeout == 5.0
Loading