Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions livekit-agents/livekit/agents/inference/eot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def stream(
async def unlikely_threshold(self, language: LanguageCode | None) -> float | None:
return self._opts.thresholds.lookup(language)

async def backchannel_threshold(self, language: LanguageCode | None) -> float | None:
return self._opts.thresholds.lookup_backchannel(language)

async def supports_language(self, language: LanguageCode | None) -> bool:
return self._opts.thresholds.supports(language)

Expand Down Expand Up @@ -140,6 +143,9 @@ def is_fallback(self) -> bool:
async def unlikely_threshold(self, language: LanguageCode | None) -> float | None:
return self._opts.thresholds.lookup(language)

async def backchannel_threshold(self, language: LanguageCode | None) -> float | None:
return self._opts.thresholds.lookup_backchannel(language)

async def supports_language(self, language: LanguageCode | None) -> bool:
return self._opts.thresholds.supports(language)

Expand Down Expand Up @@ -259,6 +265,7 @@ def _resolve_prediction(
*,
inference_duration: float | None = None,
detection_delay: float | None = None,
backchannel_probability: float | None = None,
) -> None:
"""Accept a prediction from a transport. Stale response is ignored."""
if request_id != self._request_id:
Expand All @@ -274,6 +281,7 @@ def _resolve_prediction(
end_of_turn_probability=probability,
detection_delay=detection_delay,
inference_duration=inference_duration,
backchannel_probability=backchannel_probability,
)
)

Expand Down
19 changes: 16 additions & 3 deletions livekit-agents/livekit/agents/inference/eot/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
*,
version: NotGivenOr[TurnDetectorVersions] = NOT_GIVEN,
unlikely_threshold: NotGivenOr[float | dict[LanguageCode | str, float]] = NOT_GIVEN,
backchannel_threshold: NotGivenOr[float | dict[LanguageCode | str, float]] = NOT_GIVEN,
base_url: NotGivenOr[str] = NOT_GIVEN,
api_key: NotGivenOr[str] = NOT_GIVEN,
api_secret: NotGivenOr[str] = NOT_GIVEN,
Expand Down Expand Up @@ -100,7 +101,7 @@ def __init__(

opts = TurnDetectorOptions(
sample_rate=sample_rate,
thresholds=ThresholdOptions(resolved_model, unlikely_threshold),
thresholds=ThresholdOptions(resolved_model, unlikely_threshold, backchannel_threshold),
)
super().__init__(opts=opts)

Expand All @@ -115,20 +116,32 @@ def model(self) -> TurnDetectorModels:
return self._model

def _warn_threshold_override(self) -> None:
if is_given(overrides := self._opts.thresholds.overrides):
thresholds = self._opts.thresholds
if is_given(overrides := thresholds.overrides):
logger.warning(
"a non-default turn detection threshold was provided "
"(unlikely_threshold=%s); the server provides calibrated defaults and "
"overriding them may be suboptimal",
overrides,
)
if is_given(bc_overrides := thresholds.backchannel_overrides):
logger.warning(
"a non-default backchannel threshold was provided "
"(backchannel_threshold=%s); the server provides calibrated defaults and "
"overriding them may be suboptimal",
bc_overrides,
)

def update_options(
self,
*,
unlikely_threshold: NotGivenOr[float | dict[LanguageCode | str, float]] = NOT_GIVEN,
backchannel_threshold: NotGivenOr[float | dict[LanguageCode | str, float]] = NOT_GIVEN,
) -> None:
self._opts.thresholds.update_overrides(unlikely_threshold)
if is_given(unlikely_threshold):
self._opts.thresholds.update_overrides(unlikely_threshold)
if is_given(backchannel_threshold):
self._opts.thresholds.update_backchannel_overrides(backchannel_threshold)
Comment thread
chenghao-mou marked this conversation as resolved.
self._warn_threshold_override()

def stream(
Expand Down
93 changes: 76 additions & 17 deletions livekit-agents/livekit/agents/inference/eot/languages.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ def __init__(
self,
model: TurnDetectorModels,
overrides: NotGivenOr[float | dict[LanguageCode | str, float]] = NOT_GIVEN,
backchannel_overrides: NotGivenOr[float | dict[LanguageCode | str, float]] = NOT_GIVEN,
) -> None:
self._model = model
self._overrides = _normalize_overrides(overrides)
self._bc_overrides = _normalize_overrides(backchannel_overrides)

# server/shipped defaults
self._server_thresholds: dict[str, float] | None = None
Expand All @@ -54,9 +56,16 @@ def __init__(
self._server_thresholds = dict(LOCAL_LANGUAGES)
self._server_default = LOCAL_LANGUAGES["en"]

# materialized values
# backchannel server defaults: cloud-only (the local mini model produces no
# backchannel probability), arrive via ``SessionCreated``.
self._server_bc_thresholds: dict[str, float] | None = None
self._server_bc_default: float | None = None

# materialized values (server defaults layered with user overrides)
self._thresholds: dict[str, float] = {}
self._default: float | None = None
self._bc_thresholds: dict[str, float] = {}
self._bc_default: float | None = None

self._resolve()

Expand All @@ -68,6 +77,10 @@ def model(self) -> TurnDetectorModels:
def overrides(self) -> NotGivenOr[float | dict[str, float]]:
return self._overrides

@property
def backchannel_overrides(self) -> NotGivenOr[float | dict[str, float]]:
return self._bc_overrides

@property
def thresholds(self) -> dict[str, float]:
return self._thresholds
Expand All @@ -80,6 +93,13 @@ def lookup(self, language: LanguageCode | None) -> float | None:
lang_key = language.language if language else "en"
return self._thresholds.get(lang_key, self.default_threshold)

def lookup_backchannel(self, language: LanguageCode | None) -> float | None:
if not self._bc_thresholds and not self._bc_default:
return None
lang_key = language.language if language else "en"
threshold = self._bc_thresholds.get(lang_key, self._bc_default)
return threshold if threshold and threshold > 0 else None

def supports(self, language: LanguageCode | None) -> bool:
pending = self._model == "turn-detector-v1" and self._server_thresholds is None
return pending or self.lookup(language) is not None
Expand All @@ -90,7 +110,19 @@ def update_overrides(
self._overrides = _normalize_overrides(overrides)
self._resolve()

def _update_defaults(self, server_thresholds: dict[str, float], server_default: float) -> None:
def update_backchannel_overrides(
self, overrides: NotGivenOr[float | dict[LanguageCode | str, float]]
) -> None:
self._bc_overrides = _normalize_overrides(overrides)
self._resolve()

def _update_defaults(
self,
server_thresholds: dict[str, float],
server_default: float,
backchannel_thresholds: dict[str, float] | None = None,
backchannel_default: float = 0.0,
) -> None:
if not server_thresholds or server_default <= 0:
raise APIError(
"turn detector session created without usable default thresholds",
Expand All @@ -103,6 +135,17 @@ def _update_defaults(self, server_thresholds: dict[str, float], server_default:
}
self._server_default = round(server_default, 4)

# backchannel defaults are optional; an absent/empty map keeps backchannel disabled
self._server_bc_thresholds = (
{
LanguageCode(lang).language: round(value, 4)
for lang, value in backchannel_thresholds.items()
}
if backchannel_thresholds
else None
)
self._server_bc_default = round(backchannel_default, 4) if backchannel_default > 0 else None

self._resolve()

def _to_local_fallback(self) -> None:
Expand All @@ -121,28 +164,44 @@ def _to_local_fallback(self) -> None:
self._model = "turn-detector-v1-mini"
self._server_thresholds = dict(LOCAL_LANGUAGES)
self._server_default = LOCAL_LANGUAGES["en"]
# the mini model produces no backchannel probability
self._server_bc_thresholds = None
self._server_bc_default = None
self._resolve()

if rescaled is not None:
self._thresholds = rescaled
self._default = self.lookup(LanguageCode("en"))

def _resolve(self) -> None:
scalar_override = is_given(self._overrides) and not isinstance(self._overrides, dict)
if self._server_thresholds is None or self._server_default is None:
# cloud defaults not received yet; only a scalar override resolves up front
self._thresholds = {}
self._default = float(cast(float, self._overrides)) if scalar_override else None
return

if not is_given(self._overrides):
self._thresholds, self._default = dict(self._server_thresholds), self._server_default
return
self._thresholds, self._default = self._resolve_layer(
self._server_thresholds, self._server_default, self._overrides
)
self._bc_thresholds, self._bc_default = self._resolve_layer(
self._server_bc_thresholds, self._server_bc_default, self._bc_overrides
)

@staticmethod
def _resolve_layer(
server_thresholds: dict[str, float] | None,
server_default: float | None,
overrides: NotGivenOr[float | dict[str, float]],
) -> tuple[dict[str, float], float | None]:
"""Layer a user override onto the server defaults.

A scalar override replaces the whole map (every language resolves through
it); a dict override is merged over the server map. Before server defaults
arrive, only a scalar override resolves up front.
"""
scalar_override = is_given(overrides) and not isinstance(overrides, dict)
if server_thresholds is None or server_default is None:
return {}, (float(cast(float, overrides)) if scalar_override else None)

if not is_given(overrides):
return dict(server_thresholds), server_default

if scalar_override:
self._thresholds, self._default = {}, float(cast(float, self._overrides))
return
return {}, float(cast(float, overrides))

override = cast("dict[str, float]", self._overrides)
self._thresholds = {**self._server_thresholds, **override}
self._default = self._server_default
override = cast("dict[str, float]", overrides)
return {**server_thresholds, **override}, server_default
6 changes: 5 additions & 1 deletion livekit-agents/livekit/agents/inference/eot/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def _process_message(self, msg: ServerMessage) -> None:
prediction.probability,
detection_delay=detection_delay_ms / 1000.0,
inference_duration=inference_duration_ms / 1000.0,
backchannel_probability=prediction.backchannel_probability,
)

client_e2e_ms = inference_stats.client_e2e_latency.ToMilliseconds()
Expand All @@ -239,7 +240,10 @@ def _process_message(self, msg: ServerMessage) -> None:
created = msg.session_created
thresholds = stream._opts.thresholds
thresholds._update_defaults(
dict(created.default_thresholds), created.default_threshold
dict(created.default_thresholds),
created.default_threshold,
dict(created.default_backchannel_thresholds),
created.default_backchannel_threshold,
Comment thread
chenghao-mou marked this conversation as resolved.
)
logger.debug(
"audio turn detector initialized",
Expand Down
6 changes: 6 additions & 0 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
SpeechCreatedEvent,
UserInputTranscribedEvent,
UserTurnExceededEvent,
_AgentBackchannelOpportunityEvent,
)
from .generation import (
ToolExecutionOutput,
Expand Down Expand Up @@ -2071,6 +2072,11 @@ def on_eot_prediction(self, ev: EotPredictionEvent) -> None:
if (host := self._session._session_host) is not None:
host._on_eot_prediction(ev)

def on_agent_backchannel_opportunity(self, ev: _AgentBackchannelOpportunityEvent) -> None:
# TODO: consume the backchannel opportunity internally (e.g. trigger a
# backchannel phrase). Kept internal for now — not surfaced as a public event.
pass

def on_end_of_turn(self, info: _EndOfTurnInfo) -> bool:
# IMPORTANT: This method is sync to avoid it being cancelled by the AudioRecognition
# We explicitly create a new task here
Expand Down
35 changes: 34 additions & 1 deletion livekit-agents/livekit/agents/voice/audio_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
from . import io
from ._utils import _set_participant_attributes
from .endpointing import BaseEndpointing
from .events import EotPredictionEvent, UserTurnExceededEvent
from .events import (
EotPredictionEvent,
UserTurnExceededEvent,
_AgentBackchannelOpportunityEvent,
)
from .turn import (
TurnDetectionEvent,
TurnDetectionMode as TurnDetectionMode,
Expand Down Expand Up @@ -87,6 +91,7 @@ def on_interim_transcript(self, ev: stt.SpeechEvent, *, speaking: bool | None) -
def on_final_transcript(self, ev: stt.SpeechEvent, *, speaking: bool | None = None) -> None: ...
def on_end_of_turn(self, info: _EndOfTurnInfo) -> bool: ...
def on_eot_prediction(self, ev: EotPredictionEvent) -> None: ...
def on_agent_backchannel_opportunity(self, ev: _AgentBackchannelOpportunityEvent) -> None: ...
def on_preemptive_generation(self, info: _PreemptiveGenerationInfo) -> None: ...
def on_user_turn_exceeded(self, ev: UserTurnExceededEvent) -> None: ...
def retrieve_chat_ctx(self) -> llm.ChatContext: ...
Expand Down Expand Up @@ -1291,6 +1296,7 @@ async def _bounce_eou_task(

end_of_turn_probability: float | None = None
unlikely_threshold: float | None = None
backchannel_threshold: float | None = None

if turn_detector is not None:
if not await turn_detector.supports_language(self._last_language):
Expand All @@ -1317,6 +1323,11 @@ async def _bounce_eou_task(
unlikely_threshold = await turn_detector.unlikely_threshold(
self._last_language
)
backchannel_threshold = (
await turn_detector.backchannel_threshold(
self._last_language
)
)
else:
logger.warning(
"eot prediction timed out, committing without a prediction",
Expand Down Expand Up @@ -1415,6 +1426,28 @@ async def _bounce_eou_task(
delay=delay,
)
)
# surface the backchannel opportunity whenever it clears its
# threshold, regardless of end-of-turn; AgentActivity decides
# whether to acknowledge mid-turn or let it lead the reply
backchannel_probability = (
prediction_event.backchannel_probability
if prediction_event is not None
else None
)
if (
backchannel_probability is not None
and backchannel_threshold is not None
and backchannel_probability >= backchannel_threshold
):
self._hooks.on_agent_backchannel_opportunity(
_AgentBackchannelOpportunityEvent(
probability=backchannel_probability,
threshold=backchannel_threshold,
end_of_turn_probability=end_of_turn_probability,
end_of_turn_threshold=unlikely_threshold,
language=self._last_language,
)
)
Comment thread
chenghao-mou marked this conversation as resolved.
if (
prediction_event is not None
and prediction_event.detection_delay is not None
Expand Down
21 changes: 21 additions & 0 deletions livekit-agents/livekit/agents/voice/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,27 @@ class EotPredictionEvent(BaseModel):
created_at: float = Field(default_factory=time.time)


class _AgentBackchannelOpportunityEvent(BaseModel):
"""Internal: a window in which the agent could backchannel (a short
acknowledgment such as "mm-hmm"), as predicted by the turn detector. Passed to
``AgentActivity`` only — not surfaced as a public ``AgentSession`` event yet.

``AgentActivity`` owns the decision of what to do with it. The end-of-turn margin
(``end_of_turn_threshold - end_of_turn_probability``) gives a progressive risk axis:
a large positive margin means the user is clearly still going, so riskier
backchannels (yeah/okay/right) are safe; a small margin (or a negative one, where
``end_of_turn_probability >= end_of_turn_threshold`` and a reply is imminent) calls
for safe, less ambiguous ones (hmm/uh-huh) that won't collide with the reply."""

type: Literal["agent_backchannel_opportunity"] = "agent_backchannel_opportunity"
probability: float
threshold: float
end_of_turn_probability: float
end_of_turn_threshold: float
language: str | None = None
created_at: float = Field(default_factory=time.time)


class AgentFalseInterruptionEvent(BaseModel):
type: Literal["agent_false_interruption"] = "agent_false_interruption"
resumed: bool
Expand Down
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/voice/turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class TurnDetectionEvent:
"""Latest input audio creation time -> prediction receive time."""
inference_duration: float | None = None
"""Server-side model inference time."""
backchannel_probability: float | None = None
"""How appropriate it is for the agent to backchannel at this pause.
``None`` when the detector does not produce one (e.g. the local mini model)."""


class _TurnDetector(Protocol):
Expand Down Expand Up @@ -60,6 +63,7 @@ def provider(self) -> str: ...
def is_fallback(self) -> bool: ...

async def unlikely_threshold(self, language: LanguageCode | None) -> float | None: ...
async def backchannel_threshold(self, language: LanguageCode | None) -> float | None: ...
async def supports_language(self, language: LanguageCode | None) -> bool: ...

def predict(self) -> asyncio.Future[TurnDetectionEvent]: ...
Expand Down
Loading
Loading