From 13d8e0683ffa08bcfc9a381eb9604adf5abdbe0b Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 05:48:48 -0700 Subject: [PATCH 01/12] chore: ignore local venv and planning docs --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index d2e4ca8..3e30921 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ __pycache__/ dist/ poetry.toml +.venv/ +docs/superpowers/ From 551fb6376af1db94504afd9d60cf393f13637164 Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 05:51:59 -0700 Subject: [PATCH 02/12] test: golden-master snapshots for vendor config collapse Co-Authored-By: Claude Sonnet 4.6 --- tests/custom/test_vendor_collapse_golden.py | 107 ++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 tests/custom/test_vendor_collapse_golden.py diff --git a/tests/custom/test_vendor_collapse_golden.py b/tests/custom/test_vendor_collapse_golden.py new file mode 100644 index 0000000..2d3f42d --- /dev/null +++ b/tests/custom/test_vendor_collapse_golden.py @@ -0,0 +1,107 @@ +"""Golden-master snapshots for the trickiest vendor transformations. + +These freeze to_config() output BEFORE the collapse refactor and must stay green +throughout it. The broader tests/custom/test_*_vendors.py suite guards the rote +classes; this file targets the classes with custom __init__, sample_rate logic, +cross-vendor inheritance, or aliases. +""" +from agora_agent import ( + DeepgramSTT, + ElevenLabsTTS, + GoogleTTS, + OpenAITTS, + OpenAI, + Groq, + CustomLLM, + VertexAILLM, +) +from agora_agent.agentkit.vendors.cn import AliyunLLM, FengmingSTT, SenseTimeAvatar +from agora_agent.agentkit.vendors.avatar import HeyGenAvatar + + +def test_deepgram_stt_golden() -> None: + cfg = DeepgramSTT(model="nova-3", language="en-US", smart_format=True).to_config() + assert cfg == { + "vendor": "deepgram", + "params": {"model": "nova-3", "language": "en-US", "smart_format": True}, + } + + +def test_elevenlabs_sample_rate_field_golden() -> None: + tts = ElevenLabsTTS( + key="k", model_id="eleven_flash_v2_5", voice_id="v", + base_url="wss://api.elevenlabs.io/v1", sample_rate=24000, + ) + assert tts.sample_rate == 24000 + assert tts.to_config()["params"]["sample_rate"] == 24000 + + +def test_google_tts_sample_rate_hertz_golden() -> None: + # GoogleTTS uses `key` for the credentials JSON string (not project_id/location/adc_credentials_string). + tts = GoogleTTS( + key="{}", voice_name="en-US-Neural2-A", language_code="en-US", sample_rate_hertz=16000, + ) + assert tts.sample_rate == 16000 + assert tts.to_config()["params"]["AudioConfig"]["sample_rate_hertz"] == 16000 + + +def test_openai_tts_fixed_sample_rate_golden() -> None: + assert OpenAITTS(voice="alloy").sample_rate == 24000 + + +def test_openai_llm_golden() -> None: + cfg = OpenAI(model="gpt-4o-mini").to_config() + assert cfg["style"] == "openai" + assert cfg["params"]["model"] == "gpt-4o-mini" + assert cfg["url"] == "https://api.openai.com/v1/chat/completions" + + +def test_groq_golden() -> None: + cfg = Groq(api_key="k", model="llama-3.3-70b-versatile", + base_url="https://api.groq.com/openai/v1/chat/completions").to_config() + assert cfg["url"] == "https://api.groq.com/openai/v1/chat/completions" + assert cfg["style"] == "openai" + assert cfg["params"]["model"] == "llama-3.3-70b-versatile" + + +def test_custom_llm_golden() -> None: + cfg = CustomLLM(api_key="k", model="m", base_url="https://x/chat").to_config() + assert cfg["vendor"] == "custom" + assert cfg["url"] == "https://x/chat" + + +def test_vertexai_llm_golden() -> None: + cfg = VertexAILLM(api_key="tok", project_id="proj", location="us-central1", + model="gemini-1.5-pro").to_config() + assert cfg["api_key"] == "tok" + assert "us-central1-aiplatform.googleapis.com" in cfg["url"] + + +def test_aliyun_llm_pins_vendor_golden() -> None: + cfg = AliyunLLM(api_key="k", model="qwen-max", + base_url="https://dashscope.example/chat").to_config() + assert cfg["vendor"] == "aliyun" + assert cfg["api_key"] == "k" + + +def test_sensetime_avatar_camelcase_golden() -> None: + # SenseTimeAvatarOptions uses alias "appId" for the app_id field; + # pydantic v2 requires the alias keyword in the constructor. + cfg = SenseTimeAvatar(agora_uid="2", appId="app", app_key="key").to_config() + assert cfg["vendor"] == "sensetime" + assert cfg["params"]["appId"] == "app" + + +def test_heygen_avatar_golden() -> None: + import warnings + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + cfg = HeyGenAvatar(api_key="k", quality="high", agora_uid="2").to_config() + assert cfg["vendor"] == "heygen" + + +def test_fengming_rejects_kwargs() -> None: + import pytest + with pytest.raises(TypeError): + FengmingSTT(unexpected="x") + assert FengmingSTT().to_config() == {"vendor": "fengming"} From 2b49847d87c0ac2af8b1aea9fabea227f64360c5 Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 05:56:24 -0700 Subject: [PATCH 03/12] refactor: collapse STT vendor configs into pydantic models --- src/agora_agent/agentkit/vendors/base.py | 3 +- src/agora_agent/agentkit/vendors/cn.py | 122 +++++------- src/agora_agent/agentkit/vendors/stt.py | 205 ++++++++------------ tests/custom/test_vendor_collapse_golden.py | 3 +- 4 files changed, 133 insertions(+), 200 deletions(-) diff --git a/src/agora_agent/agentkit/vendors/base.py b/src/agora_agent/agentkit/vendors/base.py index f4c4ce0..2888d29 100644 --- a/src/agora_agent/agentkit/vendors/base.py +++ b/src/agora_agent/agentkit/vendors/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Optional +from pydantic import BaseModel from typing_extensions import Literal # Supported sample rates across all TTS providers. @@ -49,7 +50,7 @@ def sample_rate(self) -> Optional[int]: """The configured sample rate in Hz, or ``None`` if not explicitly set.""" -class BaseSTT(ABC): +class BaseSTT(BaseModel, ABC): """Abstract base class for all STT vendor implementations. Subclasses must implement :meth:`to_config` to return a dict that maps to diff --git a/src/agora_agent/agentkit/vendors/cn.py b/src/agora_agent/agentkit/vendors/cn.py index e5dbca3..95949e3 100644 --- a/src/agora_agent/agentkit/vendors/cn.py +++ b/src/agora_agent/agentkit/vendors/cn.py @@ -10,7 +10,7 @@ from .tts import BaseTTS as _BaseTTSCompat -class TencentSTTOptions(BaseModel): +class TencentSTT(_BaseSTTCompat): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Tencent ASR secret key") @@ -20,36 +20,28 @@ class TencentSTTOptions(BaseModel): voice_id: str = Field(..., description="Tencent ASR voice id") additional_params: Optional[Dict[str, Any]] = Field(default=None) - -class TencentSTT(_BaseSTTCompat): - def __init__(self, **kwargs: Any): - self.options = TencentSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update( { - "key": self.options.key, - "app_id": self.options.app_id, - "secret": self.options.secret, - "engine_model_type": self.options.engine_model_type, - "voice_id": self.options.voice_id, + "key": self.key, + "app_id": self.app_id, + "secret": self.secret, + "engine_model_type": self.engine_model_type, + "voice_id": self.voice_id, } ) return {"vendor": "tencent", "params": params} class FengmingSTT(_BaseSTTCompat): - def __init__(self, **kwargs: Any): - if kwargs: - unexpected = ", ".join(sorted(kwargs)) - raise TypeError(f"FengmingSTT does not accept parameters: {unexpected}") + model_config = ConfigDict(extra="forbid") def to_config(self) -> Dict[str, Any]: return {"vendor": "fengming"} -class XfyunSTTOptions(BaseModel): +class XfyunSTT(_BaseSTTCompat): model_config = ConfigDict(extra="forbid") api_key: Optional[str] = Field(default=None, description="Xfyun ASR API key") @@ -58,28 +50,23 @@ class XfyunSTTOptions(BaseModel): language: Optional[str] = Field(default=None, description="Xfyun ASR language") additional_params: Optional[Dict[str, Any]] = Field(default=None) - -class XfyunSTT(_BaseSTTCompat): - def __init__(self, **kwargs: Any): - self.options = XfyunSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - if self.options.api_key is not None: - params["api_key"] = self.options.api_key - if self.options.app_id is not None: - params["app_id"] = self.options.app_id - if self.options.api_secret is not None: - params["api_secret"] = self.options.api_secret - if self.options.language is not None: - params["language"] = self.options.language + params: Dict[str, Any] = dict(self.additional_params or {}) + if self.api_key is not None: + params["api_key"] = self.api_key + if self.app_id is not None: + params["app_id"] = self.app_id + if self.api_secret is not None: + params["api_secret"] = self.api_secret + if self.language is not None: + params["language"] = self.language return { "vendor": "xfyun", "params": params, } -class XfyunBigModelSTTOptions(BaseModel): +class XfyunBigModelSTT(_BaseSTTCompat): model_config = ConfigDict(extra="forbid") api_key: Optional[str] = Field(default=None, description="Xfyun BigModel ASR API key") @@ -89,30 +76,25 @@ class XfyunBigModelSTTOptions(BaseModel): language: Optional[str] = Field(default=None, description="Xfyun BigModel ASR language") additional_params: Optional[Dict[str, Any]] = Field(default=None) - -class XfyunBigModelSTT(_BaseSTTCompat): - def __init__(self, **kwargs: Any): - self.options = XfyunBigModelSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - if self.options.api_key is not None: - params["api_key"] = self.options.api_key - if self.options.app_id is not None: - params["app_id"] = self.options.app_id - if self.options.api_secret is not None: - params["api_secret"] = self.options.api_secret - if self.options.language_name is not None: - params["language_name"] = self.options.language_name - if self.options.language is not None: - params["language"] = self.options.language + params: Dict[str, Any] = dict(self.additional_params or {}) + if self.api_key is not None: + params["api_key"] = self.api_key + if self.app_id is not None: + params["app_id"] = self.app_id + if self.api_secret is not None: + params["api_secret"] = self.api_secret + if self.language_name is not None: + params["language_name"] = self.language_name + if self.language is not None: + params["language"] = self.language return { "vendor": "xfyun_bigmodel", "params": params, } -class XfyunDialectSTTOptions(BaseModel): +class XfyunDialectSTT(_BaseSTTCompat): model_config = ConfigDict(extra="forbid") app_id: Optional[str] = Field(default=None, description="Xfyun Dialect ASR app id") @@ -121,28 +103,23 @@ class XfyunDialectSTTOptions(BaseModel): language: Optional[str] = Field(default=None, description="Xfyun Dialect ASR language") additional_params: Optional[Dict[str, Any]] = Field(default=None) - -class XfyunDialectSTT(_BaseSTTCompat): - def __init__(self, **kwargs: Any): - self.options = XfyunDialectSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - if self.options.app_id is not None: - params["app_id"] = self.options.app_id - if self.options.access_key_id is not None: - params["access_key_id"] = self.options.access_key_id - if self.options.access_key_secret is not None: - params["access_key_secret"] = self.options.access_key_secret - if self.options.language is not None: - params["language"] = self.options.language + params: Dict[str, Any] = dict(self.additional_params or {}) + if self.app_id is not None: + params["app_id"] = self.app_id + if self.access_key_id is not None: + params["access_key_id"] = self.access_key_id + if self.access_key_secret is not None: + params["access_key_secret"] = self.access_key_secret + if self.language is not None: + params["language"] = self.language return { "vendor": "xfyun_dialect", "params": params, } -class MicrosoftSTTOptions(BaseModel): +class MicrosoftSTT(_BaseSTTCompat): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Azure subscription key") @@ -151,20 +128,15 @@ class MicrosoftSTTOptions(BaseModel): phrase_list: Optional[List[str]] = Field(default=None, description="Microsoft ASR phrase list") additional_params: Optional[Dict[str, Any]] = Field(default=None) - -class MicrosoftSTT(_BaseSTTCompat): - def __init__(self, **kwargs: Any): - self.options = MicrosoftSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update({ - "key": self.options.key, - "region": self.options.region, - "language": self.options.language, + "key": self.key, + "region": self.region, + "language": self.language, }) - if self.options.phrase_list is not None: - params["phrase_list"] = self.options.phrase_list + if self.phrase_list is not None: + params["phrase_list"] = self.phrase_list return { "vendor": "microsoft", "params": params, diff --git a/src/agora_agent/agentkit/vendors/stt.py b/src/agora_agent/agentkit/vendors/stt.py index aa651cb..4a79fb6 100644 --- a/src/agora_agent/agentkit/vendors/stt.py +++ b/src/agora_agent/agentkit/vendors/stt.py @@ -1,13 +1,13 @@ from typing import Any, Dict, Optional -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field, model_validator from .base import BaseSTT _DEEPGRAM_MANAGED_MODELS = {"nova-2", "nova-3"} -class SpeechmaticsSTTOptions(BaseModel): +class SpeechmaticsSTT(BaseSTT): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Speechmatics API key") @@ -16,20 +16,16 @@ class SpeechmaticsSTTOptions(BaseModel): uri: Optional[str] = Field(default=None, description="Speechmatics streaming WebSocket URL") additional_params: Optional[Dict[str, Any]] = Field(default=None) -class SpeechmaticsSTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = SpeechmaticsSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update({ - "api_key": self.options.api_key, - "language": self.options.language, + "api_key": self.api_key, + "language": self.language, }) - if self.options.model is not None: - params["model"] = self.options.model - if self.options.uri is not None: - params["uri"] = self.options.uri + if self.model is not None: + params["model"] = self.model + if self.uri is not None: + params["uri"] = self.uri config: Dict[str, Any] = { "vendor": "speechmatics", @@ -38,7 +34,7 @@ def to_config(self) -> Dict[str, Any]: return config -class DeepgramSTTOptions(BaseModel): +class DeepgramSTT(BaseSTT): model_config = ConfigDict(extra="forbid") api_key: Optional[str] = Field(default=None, description="Deepgram API key") @@ -50,30 +46,26 @@ class DeepgramSTTOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None) @model_validator(mode="after") - def _validate_managed_model(self) -> "DeepgramSTTOptions": + def _validate_managed_model(self) -> "DeepgramSTT": if self.api_key is None and (self.model is None or self.model.strip().lower() not in _DEEPGRAM_MANAGED_MODELS): raise ValueError("DeepgramSTT requires api_key unless using a supported Agora-managed model") return self -class DeepgramSTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = DeepgramSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - - if self.options.api_key is not None: - params["key"] = self.options.api_key - if self.options.model is not None: - params["model"] = self.options.model - if self.options.language is not None: - params["language"] = self.options.language - if self.options.smart_format is not None: - params["smart_format"] = self.options.smart_format - if self.options.punctuation is not None: - params["punctuation"] = self.options.punctuation - if self.options.keyterm is not None: - params["keyterm"] = self.options.keyterm + params: Dict[str, Any] = dict(self.additional_params or {}) + + if self.api_key is not None: + params["key"] = self.api_key + if self.model is not None: + params["model"] = self.model + if self.language is not None: + params["language"] = self.language + if self.smart_format is not None: + params["smart_format"] = self.smart_format + if self.punctuation is not None: + params["punctuation"] = self.punctuation + if self.keyterm is not None: + params["keyterm"] = self.keyterm config: Dict[str, Any] = { "vendor": "deepgram", "params": params, @@ -81,7 +73,7 @@ def to_config(self) -> Dict[str, Any]: return config -class MicrosoftSTTOptions(BaseModel): +class MicrosoftSTT(BaseSTT): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Azure subscription key") @@ -89,18 +81,14 @@ class MicrosoftSTTOptions(BaseModel): language: str = Field(..., description="Language code (e.g., en-US)") additional_params: Optional[Dict[str, Any]] = Field(default=None) -class MicrosoftSTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = MicrosoftSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update({ - "key": self.options.key, - "region": self.options.region, + "key": self.key, + "region": self.region, }) - if self.options.language is not None: - params["language"] = self.options.language + if self.language is not None: + params["language"] = self.language config: Dict[str, Any] = { "vendor": "microsoft", @@ -109,7 +97,7 @@ def to_config(self) -> Dict[str, Any]: return config -class OpenAISTTOptions(BaseModel): +class OpenAISTT(BaseSTT): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="OpenAI API key") @@ -119,22 +107,18 @@ class OpenAISTTOptions(BaseModel): input_audio_transcription: Optional[Dict[str, Any]] = Field(default=None, description="OpenAI transcription settings") additional_params: Optional[Dict[str, Any]] = Field(default=None) -class OpenAISTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = OpenAISTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - params["api_key"] = self.options.api_key + params: Dict[str, Any] = dict(self.additional_params or {}) + params["api_key"] = self.api_key transcription: Dict[str, Any] = {"model": "gpt-4o-mini-transcribe"} - transcription.update(self.options.input_audio_transcription or {}) - if self.options.model is not None: - transcription["model"] = self.options.model - if self.options.prompt is not None: - transcription["prompt"] = self.options.prompt - if self.options.language is not None: - transcription["language"] = self.options.language + transcription.update(self.input_audio_transcription or {}) + if self.model is not None: + transcription["model"] = self.model + if self.prompt is not None: + transcription["prompt"] = self.prompt + if self.language is not None: + transcription["language"] = self.language if not transcription.get("model"): raise ValueError("OpenAISTT: input_audio_transcription.model is required") if not transcription.get("prompt"): @@ -150,7 +134,7 @@ def to_config(self) -> Dict[str, Any]: return config -class GoogleSTTOptions(BaseModel): +class GoogleSTT(BaseSTT): model_config = ConfigDict(extra="forbid") project_id: str = Field(..., description="Google Cloud project ID") @@ -160,22 +144,18 @@ class GoogleSTTOptions(BaseModel): model: Optional[str] = Field(default=None, description="Recognition model") additional_params: Optional[Dict[str, Any]] = Field(default=None) -class GoogleSTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = GoogleSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update({ - "project_id": self.options.project_id, - "location": self.options.location, - "adc_credentials_string": self.options.adc_credentials_string, + "project_id": self.project_id, + "location": self.location, + "adc_credentials_string": self.adc_credentials_string, }) - if self.options.language is not None: - params["language"] = self.options.language - if self.options.model is not None: - params["model"] = self.options.model + if self.language is not None: + params["language"] = self.language + if self.model is not None: + params["model"] = self.model config: Dict[str, Any] = { "vendor": "google", @@ -184,7 +164,7 @@ def to_config(self) -> Dict[str, Any]: return config -class AmazonSTTOptions(BaseModel): +class AmazonSTT(BaseSTT): model_config = ConfigDict(extra="forbid") access_key: str = Field(..., description="AWS Access Key ID") @@ -193,19 +173,15 @@ class AmazonSTTOptions(BaseModel): language: str = Field(..., description="Language code") additional_params: Optional[Dict[str, Any]] = Field(default=None) -class AmazonSTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = AmazonSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update({ - "access_key_id": self.options.access_key, - "secret_access_key": self.options.secret_key, - "region": self.options.region, + "access_key_id": self.access_key, + "secret_access_key": self.secret_key, + "region": self.region, }) - if self.options.language is not None: - params["language_code"] = self.options.language + if self.language is not None: + params["language_code"] = self.language config: Dict[str, Any] = { "vendor": "amazon", @@ -214,7 +190,7 @@ def to_config(self) -> Dict[str, Any]: return config -class AssemblyAISTTOptions(BaseModel): +class AssemblyAISTT(BaseSTT): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="AssemblyAI API key") @@ -222,17 +198,13 @@ class AssemblyAISTTOptions(BaseModel): uri: Optional[str] = Field(default=None, description="AssemblyAI streaming WebSocket URL") additional_params: Optional[Dict[str, Any]] = Field(default=None) -class AssemblyAISTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = AssemblyAISTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - params["api_key"] = self.options.api_key - if self.options.language is not None: - params["language"] = self.options.language - if self.options.uri is not None: - params["uri"] = self.options.uri + params: Dict[str, Any] = dict(self.additional_params or {}) + params["api_key"] = self.api_key + if self.language is not None: + params["language"] = self.language + if self.uri is not None: + params["uri"] = self.uri config: Dict[str, Any] = { "vendor": "assemblyai", @@ -241,23 +213,19 @@ def to_config(self) -> Dict[str, Any]: return config -class AresSTTOptions(BaseModel): +class AresSTT(BaseSTT): model_config = ConfigDict(extra="forbid") additional_params: Optional[Dict[str, Any]] = Field(default=None) -class AresSTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = AresSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: config: Dict[str, Any] = {"vendor": "ares"} - if self.options.additional_params: - config["params"] = self.options.additional_params + if self.additional_params: + config["params"] = self.additional_params return config -class SarvamSTTOptions(BaseModel): +class SarvamSTT(BaseSTT): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Sarvam API key") @@ -265,18 +233,14 @@ class SarvamSTTOptions(BaseModel): model: Optional[str] = Field(default=None, description="Model name") additional_params: Optional[Dict[str, Any]] = Field(default=None) -class SarvamSTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = SarvamSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update({ - "api_key": self.options.api_key, - "language": self.options.language, + "api_key": self.api_key, + "language": self.language, }) - if self.options.model is not None: - params["model"] = self.options.model + if self.model is not None: + params["model"] = self.model config: Dict[str, Any] = { "vendor": "sarvam", @@ -285,7 +249,7 @@ def to_config(self) -> Dict[str, Any]: return config -class XaiSTTOptions(BaseModel): +class XaiSTT(BaseSTT): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="xAI API key") @@ -294,20 +258,15 @@ class XaiSTTOptions(BaseModel): language: Optional[str] = Field(default=None, description="Language code for speech recognition") additional_params: Optional[Dict[str, Any]] = Field(default=None) - -class XaiSTT(BaseSTT): - def __init__(self, **kwargs: Any): - self.options = XaiSTTOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - params["api_key"] = self.options.api_key - if self.options.base_url is not None: - params["base_url"] = self.options.base_url - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate - if self.options.language is not None: - params["language"] = self.options.language + params: Dict[str, Any] = dict(self.additional_params or {}) + params["api_key"] = self.api_key + if self.base_url is not None: + params["base_url"] = self.base_url + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate + if self.language is not None: + params["language"] = self.language config: Dict[str, Any] = { "vendor": "xai", diff --git a/tests/custom/test_vendor_collapse_golden.py b/tests/custom/test_vendor_collapse_golden.py index 2d3f42d..254648a 100644 --- a/tests/custom/test_vendor_collapse_golden.py +++ b/tests/custom/test_vendor_collapse_golden.py @@ -102,6 +102,7 @@ def test_heygen_avatar_golden() -> None: def test_fengming_rejects_kwargs() -> None: import pytest - with pytest.raises(TypeError): + from pydantic import ValidationError + with pytest.raises(ValidationError): FengmingSTT(unexpected="x") assert FengmingSTT().to_config() == {"vendor": "fengming"} From 8d10148630a7885dc335092feb1f0c12df98672e Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 06:08:33 -0700 Subject: [PATCH 04/12] refactor: collapse LLM vendor configs; standalone copies for OpenAI-compatible LLMs --- src/agora_agent/agentkit/vendors/base.py | 2 +- src/agora_agent/agentkit/vendors/cn.py | 344 +++++++++++- src/agora_agent/agentkit/vendors/llm.py | 683 ++++++++++++++--------- 3 files changed, 748 insertions(+), 281 deletions(-) diff --git a/src/agora_agent/agentkit/vendors/base.py b/src/agora_agent/agentkit/vendors/base.py index 2888d29..f2df4eb 100644 --- a/src/agora_agent/agentkit/vendors/base.py +++ b/src/agora_agent/agentkit/vendors/base.py @@ -16,7 +16,7 @@ GoogleTTSSampleRate = Literal[8000, 16000, 22050, 24000, 44100, 48000] -class BaseLLM(ABC): +class BaseLLM(BaseModel, ABC): """Abstract base class for all LLM vendor implementations. Subclasses must implement :meth:`to_config` to return a dict that maps to diff --git a/src/agora_agent/agentkit/vendors/cn.py b/src/agora_agent/agentkit/vendors/cn.py index 95949e3..695ad8f 100644 --- a/src/agora_agent/agentkit/vendors/cn.py +++ b/src/agora_agent/agentkit/vendors/cn.py @@ -5,7 +5,13 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator from .avatar import BaseAvatar -from .llm import OpenAI +from .base import BaseLLM +from .llm import ( + _OPENAI_MANAGED_MODELS, + LlmGreetingConfigs, + _dump_optional_model, + _ensure_mcp_transport, +) from .stt import BaseSTT as _BaseSTTCompat from .tts import BaseTTS as _BaseTTSCompat @@ -497,28 +503,332 @@ def to_config(self) -> Dict[str, Any]: return result -class AliyunLLM(OpenAI): - def __init__(self, **kwargs: Any): - kwargs["vendor"] = "aliyun" - super().__init__(**kwargs) +class AliyunLLM(BaseLLM): + model_config = ConfigDict(extra="forbid") + api_key: Optional[str] = Field(default=None, description="OpenAI API key") + model: str = Field(..., description="Model name") + base_url: Optional[str] = Field(default=None, description="Custom base URL") + temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(default=None, gt=0) + system_messages: Optional[List[Dict[str, Any]]] = Field(default=None) + greeting_message: Optional[str] = Field(default=None) + greeting_audio_url: Optional[str] = Field(default=None) + failure_message: Optional[str] = Field(default=None) + input_modalities: Optional[List[str]] = Field(default=None) + params: Optional[Dict[str, Any]] = Field(default=None) + headers: Optional[Dict[str, str]] = Field(default=None) + output_modalities: Optional[List[str]] = Field(default=None) + greeting_configs: Optional[LlmGreetingConfigs] = Field(default=None) + template_variables: Optional[Dict[str, str]] = Field(default=None) + vendor: Optional[str] = Field(default="aliyun") + mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) + max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") -class BytedanceLLM(OpenAI): - def __init__(self, **kwargs: Any): - kwargs["vendor"] = "bytedance" - super().__init__(**kwargs) + @model_validator(mode="after") + def _validate_byok_params(self) -> "AliyunLLM": + if not self.model: + raise ValueError("OpenAI requires model") + if self.api_key is not None and self.base_url is None: + raise ValueError("OpenAI requires base_url when api_key is set") + if self.api_key is None and self.base_url is not None: + raise ValueError("OpenAI base_url is only valid when api_key is set") + if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: + raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + if self.api_key is None and self.vendor is not None: + raise ValueError("OpenAI Agora-managed mode does not allow vendor") + return self + def to_config(self) -> Dict[str, Any]: + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} -class DeepSeekLLM(OpenAI): - def __init__(self, **kwargs: Any): - kwargs["vendor"] = "deepseek" - super().__init__(**kwargs) + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + config: Dict[str, Any] = { + "url": self.base_url or "https://api.openai.com/v1/chat/completions", + "params": params, + "style": "openai", + "input_modalities": self.input_modalities or ["text"], + } + if self.api_key is not None: + config["api_key"] = self.api_key + if self.headers is not None: + config["headers"] = self.headers + + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history + + return config + + +class BytedanceLLM(BaseLLM): + model_config = ConfigDict(extra="forbid") -class TencentLLM(OpenAI): - def __init__(self, **kwargs: Any): - kwargs["vendor"] = "tencent" - super().__init__(**kwargs) + api_key: Optional[str] = Field(default=None, description="OpenAI API key") + model: str = Field(..., description="Model name") + base_url: Optional[str] = Field(default=None, description="Custom base URL") + temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(default=None, gt=0) + system_messages: Optional[List[Dict[str, Any]]] = Field(default=None) + greeting_message: Optional[str] = Field(default=None) + greeting_audio_url: Optional[str] = Field(default=None) + failure_message: Optional[str] = Field(default=None) + input_modalities: Optional[List[str]] = Field(default=None) + params: Optional[Dict[str, Any]] = Field(default=None) + headers: Optional[Dict[str, str]] = Field(default=None) + output_modalities: Optional[List[str]] = Field(default=None) + greeting_configs: Optional[LlmGreetingConfigs] = Field(default=None) + template_variables: Optional[Dict[str, str]] = Field(default=None) + vendor: Optional[str] = Field(default="bytedance") + mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) + max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") + + @model_validator(mode="after") + def _validate_byok_params(self) -> "BytedanceLLM": + if not self.model: + raise ValueError("OpenAI requires model") + if self.api_key is not None and self.base_url is None: + raise ValueError("OpenAI requires base_url when api_key is set") + if self.api_key is None and self.base_url is not None: + raise ValueError("OpenAI base_url is only valid when api_key is set") + if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: + raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + if self.api_key is None and self.vendor is not None: + raise ValueError("OpenAI Agora-managed mode does not allow vendor") + return self + + def to_config(self) -> Dict[str, Any]: + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + + config: Dict[str, Any] = { + "url": self.base_url or "https://api.openai.com/v1/chat/completions", + "params": params, + "style": "openai", + "input_modalities": self.input_modalities or ["text"], + } + if self.api_key is not None: + config["api_key"] = self.api_key + if self.headers is not None: + config["headers"] = self.headers + + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history + + return config + + +class DeepSeekLLM(BaseLLM): + model_config = ConfigDict(extra="forbid") + + api_key: Optional[str] = Field(default=None, description="OpenAI API key") + model: str = Field(..., description="Model name") + base_url: Optional[str] = Field(default=None, description="Custom base URL") + temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(default=None, gt=0) + system_messages: Optional[List[Dict[str, Any]]] = Field(default=None) + greeting_message: Optional[str] = Field(default=None) + greeting_audio_url: Optional[str] = Field(default=None) + failure_message: Optional[str] = Field(default=None) + input_modalities: Optional[List[str]] = Field(default=None) + params: Optional[Dict[str, Any]] = Field(default=None) + headers: Optional[Dict[str, str]] = Field(default=None) + output_modalities: Optional[List[str]] = Field(default=None) + greeting_configs: Optional[LlmGreetingConfigs] = Field(default=None) + template_variables: Optional[Dict[str, str]] = Field(default=None) + vendor: Optional[str] = Field(default="deepseek") + mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) + max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") + + @model_validator(mode="after") + def _validate_byok_params(self) -> "DeepSeekLLM": + if not self.model: + raise ValueError("OpenAI requires model") + if self.api_key is not None and self.base_url is None: + raise ValueError("OpenAI requires base_url when api_key is set") + if self.api_key is None and self.base_url is not None: + raise ValueError("OpenAI base_url is only valid when api_key is set") + if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: + raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + if self.api_key is None and self.vendor is not None: + raise ValueError("OpenAI Agora-managed mode does not allow vendor") + return self + + def to_config(self) -> Dict[str, Any]: + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + + config: Dict[str, Any] = { + "url": self.base_url or "https://api.openai.com/v1/chat/completions", + "params": params, + "style": "openai", + "input_modalities": self.input_modalities or ["text"], + } + if self.api_key is not None: + config["api_key"] = self.api_key + if self.headers is not None: + config["headers"] = self.headers + + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history + + return config + + +class TencentLLM(BaseLLM): + model_config = ConfigDict(extra="forbid") + + api_key: Optional[str] = Field(default=None, description="OpenAI API key") + model: str = Field(..., description="Model name") + base_url: Optional[str] = Field(default=None, description="Custom base URL") + temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(default=None, gt=0) + system_messages: Optional[List[Dict[str, Any]]] = Field(default=None) + greeting_message: Optional[str] = Field(default=None) + greeting_audio_url: Optional[str] = Field(default=None) + failure_message: Optional[str] = Field(default=None) + input_modalities: Optional[List[str]] = Field(default=None) + params: Optional[Dict[str, Any]] = Field(default=None) + headers: Optional[Dict[str, str]] = Field(default=None) + output_modalities: Optional[List[str]] = Field(default=None) + greeting_configs: Optional[LlmGreetingConfigs] = Field(default=None) + template_variables: Optional[Dict[str, str]] = Field(default=None) + vendor: Optional[str] = Field(default="tencent") + mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) + max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") + + @model_validator(mode="after") + def _validate_byok_params(self) -> "TencentLLM": + if not self.model: + raise ValueError("OpenAI requires model") + if self.api_key is not None and self.base_url is None: + raise ValueError("OpenAI requires base_url when api_key is set") + if self.api_key is None and self.base_url is not None: + raise ValueError("OpenAI base_url is only valid when api_key is set") + if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: + raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + if self.api_key is None and self.vendor is not None: + raise ValueError("OpenAI Agora-managed mode does not allow vendor") + return self + + def to_config(self) -> Dict[str, Any]: + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + + config: Dict[str, Any] = { + "url": self.base_url or "https://api.openai.com/v1/chat/completions", + "params": params, + "style": "openai", + "input_modalities": self.input_modalities or ["text"], + } + if self.api_key is not None: + config["api_key"] = self.api_key + if self.headers is not None: + config["headers"] = self.headers + + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history + + return config class SenseTimeAvatarOptions(BaseModel): diff --git a/src/agora_agent/agentkit/vendors/llm.py b/src/agora_agent/agentkit/vendors/llm.py index 3e31992..1b4a9ad 100644 --- a/src/agora_agent/agentkit/vendors/llm.py +++ b/src/agora_agent/agentkit/vendors/llm.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field, model_validator from .base import BaseLLM @@ -26,7 +26,8 @@ def _dump_optional_model(value: Any) -> Any: return value.dict(exclude_none=True) return value -class OpenAIOptions(BaseModel): + +class OpenAI(BaseLLM): model_config = ConfigDict(extra="forbid") api_key: Optional[str] = Field(default=None, description="OpenAI API key") @@ -50,7 +51,7 @@ class OpenAIOptions(BaseModel): max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") @model_validator(mode="after") - def _validate_byok_params(self) -> "OpenAIOptions": + def _validate_byok_params(self) -> "OpenAI": if not self.model: raise ValueError("OpenAI requires model") if self.api_key is not None and self.base_url is None: @@ -63,59 +64,55 @@ def _validate_byok_params(self) -> "OpenAIOptions": raise ValueError("OpenAI Agora-managed mode does not allow vendor") return self -class OpenAI(BaseLLM): - def __init__(self, **kwargs: Any): - self.options = OpenAIOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: # model is the default; explicit params entries extend/override it. # This matches the TS SDK behaviour: { model, ...params }. - params: Dict[str, Any] = {"model": self.options.model, **(self.options.params or {})} + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} # Named fields take precedence over anything in the generic params dict. - if self.options.max_tokens is not None: - params["max_tokens"] = self.options.max_tokens - if self.options.temperature is not None: - params["temperature"] = self.options.temperature - if self.options.top_p is not None: - params["top_p"] = self.options.top_p + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p config: Dict[str, Any] = { - "url": self.options.base_url or "https://api.openai.com/v1/chat/completions", + "url": self.base_url or "https://api.openai.com/v1/chat/completions", "params": params, "style": "openai", - "input_modalities": self.options.input_modalities or ["text"], + "input_modalities": self.input_modalities or ["text"], } - if self.options.api_key is not None: - config["api_key"] = self.options.api_key - if self.options.headers is not None: - config["headers"] = self.options.headers - - if self.options.system_messages is not None: - config["system_messages"] = self.options.system_messages - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.greeting_audio_url is not None: - config["greeting_audio_url"] = self.options.greeting_audio_url - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.greeting_configs is not None: - config["greeting_configs"] = _dump_optional_model(self.options.greeting_configs) - if self.options.template_variables is not None: - config["template_variables"] = self.options.template_variables - if self.options.vendor is not None: - config["vendor"] = self.options.vendor - if self.options.mcp_servers is not None: - config["mcp_servers"] = _ensure_mcp_transport(self.options.mcp_servers) - if self.options.max_history is not None: - config["max_history"] = self.options.max_history + if self.api_key is not None: + config["api_key"] = self.api_key + if self.headers is not None: + config["headers"] = self.headers + + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history return config -class AzureOpenAIOptions(BaseModel): +class AzureOpenAI(BaseLLM): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Azure OpenAI API key") @@ -140,60 +137,56 @@ class AzureOpenAIOptions(BaseModel): mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") -class AzureOpenAI(BaseLLM): - def __init__(self, **kwargs: Any): - self.options = AzureOpenAIOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: url = ( - f"{self.options.endpoint}/openai/deployments/" - f"{self.options.deployment_name}/chat/completions" - f"?api-version={self.options.api_version}" + f"{self.endpoint}/openai/deployments/" + f"{self.deployment_name}/chat/completions" + f"?api-version={self.api_version}" ) config: Dict[str, Any] = { "url": url, - "api_key": self.options.api_key, - "vendor": self.options.vendor or "azure", + "api_key": self.api_key, + "vendor": self.vendor or "azure", "style": "openai", - "input_modalities": self.options.input_modalities or ["text"], + "input_modalities": self.input_modalities or ["text"], } # Named fields take precedence over anything in the generic params dict. - params: Dict[str, Any] = {"model": self.options.model, **(self.options.params or {})} - if self.options.temperature is not None: - params["temperature"] = self.options.temperature - if self.options.top_p is not None: - params["top_p"] = self.options.top_p - if self.options.max_tokens is not None: - params["max_tokens"] = self.options.max_tokens + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens if params: config["params"] = params - if self.options.headers is not None: - config["headers"] = self.options.headers - - if self.options.system_messages is not None: - config["system_messages"] = self.options.system_messages - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.greeting_audio_url is not None: - config["greeting_audio_url"] = self.options.greeting_audio_url - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.greeting_configs is not None: - config["greeting_configs"] = _dump_optional_model(self.options.greeting_configs) - if self.options.template_variables is not None: - config["template_variables"] = self.options.template_variables - if self.options.mcp_servers is not None: - config["mcp_servers"] = _ensure_mcp_transport(self.options.mcp_servers) - if self.options.max_history is not None: - config["max_history"] = self.options.max_history + if self.headers is not None: + config["headers"] = self.headers + + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history return config -class AnthropicOptions(BaseModel): +class Anthropic(BaseLLM): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Anthropic API key") @@ -216,54 +209,50 @@ class AnthropicOptions(BaseModel): mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") -class Anthropic(BaseLLM): - def __init__(self, **kwargs: Any): - self.options = AnthropicOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: # Named fields take precedence over anything in the generic params dict. - params: Dict[str, Any] = {"model": self.options.model, **(self.options.params or {})} - if self.options.max_tokens is not None: - params["max_tokens"] = self.options.max_tokens - if self.options.temperature is not None: - params["temperature"] = self.options.temperature - if self.options.top_p is not None: - params["top_p"] = self.options.top_p + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p config: Dict[str, Any] = { - "url": self.options.url, - "api_key": self.options.api_key, + "url": self.url, + "api_key": self.api_key, "params": params, - "headers": self.options.headers, + "headers": self.headers, "style": "anthropic", - "input_modalities": self.options.input_modalities or ["text"], + "input_modalities": self.input_modalities or ["text"], } - if self.options.system_messages is not None: - config["system_messages"] = self.options.system_messages - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.greeting_audio_url is not None: - config["greeting_audio_url"] = self.options.greeting_audio_url - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.greeting_configs is not None: - config["greeting_configs"] = _dump_optional_model(self.options.greeting_configs) - if self.options.template_variables is not None: - config["template_variables"] = self.options.template_variables - if self.options.vendor is not None: - config["vendor"] = self.options.vendor - if self.options.mcp_servers is not None: - config["mcp_servers"] = _ensure_mcp_transport(self.options.mcp_servers) - if self.options.max_history is not None: - config["max_history"] = self.options.max_history + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history return config -class GeminiOptions(BaseModel): +class Gemini(BaseLLM): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Google AI API key") @@ -287,121 +276,299 @@ class GeminiOptions(BaseModel): mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") -class Gemini(BaseLLM): - def __init__(self, **kwargs: Any): - self.options = GeminiOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: # Named fields take precedence over anything in the generic params dict. - params: Dict[str, Any] = {"model": self.options.model, **(self.options.params or {})} - if self.options.temperature is not None: - params["temperature"] = self.options.temperature - if self.options.top_p is not None: - params["top_p"] = self.options.top_p - if self.options.top_k is not None: - params["top_k"] = self.options.top_k - if self.options.max_output_tokens is not None: - params["max_output_tokens"] = self.options.max_output_tokens + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + if self.top_k is not None: + params["top_k"] = self.top_k + if self.max_output_tokens is not None: + params["max_output_tokens"] = self.max_output_tokens config: Dict[str, Any] = { - "url": self.options.url or ( + "url": self.url or ( f"https://generativelanguage.googleapis.com/v1beta/models/" - f"{self.options.model}:streamGenerateContent?alt=sse&key={self.options.api_key}" + f"{self.model}:streamGenerateContent?alt=sse&key={self.api_key}" ), "params": params, "style": "gemini", - "input_modalities": self.options.input_modalities or ["text"], + "input_modalities": self.input_modalities or ["text"], } - if self.options.system_messages is not None: - config["system_messages"] = self.options.system_messages - if self.options.headers is not None: - config["headers"] = self.options.headers - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.greeting_audio_url is not None: - config["greeting_audio_url"] = self.options.greeting_audio_url - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.greeting_configs is not None: - config["greeting_configs"] = _dump_optional_model(self.options.greeting_configs) - if self.options.template_variables is not None: - config["template_variables"] = self.options.template_variables - if self.options.vendor is not None: - config["vendor"] = self.options.vendor - if self.options.mcp_servers is not None: - config["mcp_servers"] = _ensure_mcp_transport(self.options.mcp_servers) - if self.options.max_history is not None: - config["max_history"] = self.options.max_history + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.headers is not None: + config["headers"] = self.headers + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history return config -class GroqOptions(OpenAIOptions): +class Groq(BaseLLM): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Groq API key") model: str = Field(..., description="Model name") base_url: str = Field(..., description="Groq-compatible endpoint") + temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(default=None, gt=0) + system_messages: Optional[List[Dict[str, Any]]] = Field(default=None) + greeting_message: Optional[str] = Field(default=None) + greeting_audio_url: Optional[str] = Field(default=None) + failure_message: Optional[str] = Field(default=None) + input_modalities: Optional[List[str]] = Field(default=None) + params: Optional[Dict[str, Any]] = Field(default=None) + headers: Optional[Dict[str, str]] = Field(default=None) + output_modalities: Optional[List[str]] = Field(default=None) + greeting_configs: Optional[LlmGreetingConfigs] = Field(default=None) + template_variables: Optional[Dict[str, str]] = Field(default=None) + vendor: Optional[str] = Field(default=None) + mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) + max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") - -class Groq(BaseLLM): - def __init__(self, **kwargs: Any): - self.options = GroqOptions(**kwargs) + @model_validator(mode="after") + def _validate_byok_params(self) -> "Groq": + if not self.model: + raise ValueError("OpenAI requires model") + if self.api_key is not None and self.base_url is None: + raise ValueError("OpenAI requires base_url when api_key is set") + if self.api_key is None and self.base_url is not None: + raise ValueError("OpenAI base_url is only valid when api_key is set") + if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: + raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + if self.api_key is None and self.vendor is not None: + raise ValueError("OpenAI Agora-managed mode does not allow vendor") + return self def to_config(self) -> Dict[str, Any]: - config = OpenAI(**_dump_optional_model(self.options)).to_config() - config["url"] = self.options.base_url + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + + config: Dict[str, Any] = { + "url": self.base_url or "https://api.openai.com/v1/chat/completions", + "params": params, + "style": "openai", + "input_modalities": self.input_modalities or ["text"], + } + if self.api_key is not None: + config["api_key"] = self.api_key + if self.headers is not None: + config["headers"] = self.headers + + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history + + config["url"] = self.base_url return config -class CustomLLMOptions(OpenAIOptions): +class CustomLLM(BaseLLM): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Custom LLM API key") + model: str = Field(..., description="Model name") base_url: str = Field(..., description="OpenAI-compatible chat completions endpoint") + temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(default=None, gt=0) + system_messages: Optional[List[Dict[str, Any]]] = Field(default=None) + greeting_message: Optional[str] = Field(default=None) + greeting_audio_url: Optional[str] = Field(default=None) + failure_message: Optional[str] = Field(default=None) + input_modalities: Optional[List[str]] = Field(default=None) + params: Optional[Dict[str, Any]] = Field(default=None) + headers: Optional[Dict[str, str]] = Field(default=None) + output_modalities: Optional[List[str]] = Field(default=None) + greeting_configs: Optional[LlmGreetingConfigs] = Field(default=None) + template_variables: Optional[Dict[str, str]] = Field(default=None) + vendor: Optional[str] = Field(default=None) + mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) + max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") - -class CustomLLM(BaseLLM): - def __init__(self, **kwargs: Any): - self.options = CustomLLMOptions(**kwargs) + @model_validator(mode="after") + def _validate_byok_params(self) -> "CustomLLM": + if not self.model: + raise ValueError("OpenAI requires model") + if self.api_key is not None and self.base_url is None: + raise ValueError("OpenAI requires base_url when api_key is set") + if self.api_key is None and self.base_url is not None: + raise ValueError("OpenAI base_url is only valid when api_key is set") + if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: + raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + if self.api_key is None and self.vendor is not None: + raise ValueError("OpenAI Agora-managed mode does not allow vendor") + return self def to_config(self) -> Dict[str, Any]: - config = OpenAI(**_dump_optional_model(self.options)).to_config() - config["vendor"] = self.options.vendor or "custom" + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + + config: Dict[str, Any] = { + "url": self.base_url or "https://api.openai.com/v1/chat/completions", + "params": params, + "style": "openai", + "input_modalities": self.input_modalities or ["text"], + } + if self.api_key is not None: + config["api_key"] = self.api_key + if self.headers is not None: + config["headers"] = self.headers + + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history + + config["vendor"] = self.vendor or "custom" return config -class VertexAILLMOptions(GeminiOptions): +class VertexAILLM(BaseLLM): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Vertex AI access token or API key") project_id: str = Field(..., description="Google Cloud project ID") location: str = Field(..., description="Google Cloud location") + model: str = Field(..., description="Model name") + url: Optional[str] = Field(default=None, description="Custom API endpoint URL") + temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + top_k: Optional[int] = Field(default=None, gt=0) + max_output_tokens: Optional[int] = Field(default=None, gt=0) + system_messages: Optional[List[Dict[str, Any]]] = Field(default=None) + greeting_message: Optional[str] = Field(default=None) + greeting_audio_url: Optional[str] = Field(default=None) + failure_message: Optional[str] = Field(default=None) + input_modalities: Optional[List[str]] = Field(default=None) + params: Optional[Dict[str, Any]] = Field(default=None) + headers: Optional[Dict[str, str]] = Field(default=None) + output_modalities: Optional[List[str]] = Field(default=None) + greeting_configs: Optional[LlmGreetingConfigs] = Field(default=None) + template_variables: Optional[Dict[str, str]] = Field(default=None) + vendor: Optional[str] = Field(default=None) + mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) + max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") + def to_config(self) -> Dict[str, Any]: + # Named fields take precedence over anything in the generic params dict. + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p + if self.top_k is not None: + params["top_k"] = self.top_k + if self.max_output_tokens is not None: + params["max_output_tokens"] = self.max_output_tokens + + url = self.url or ( + f"https://{self.location}-aiplatform.googleapis.com/v1/projects/" + f"{self.project_id}/locations/{self.location}/" + f"publishers/google/models/{self.model}:streamGenerateContent?alt=sse" + ) -class VertexAILLM(BaseLLM): - def __init__(self, **kwargs: Any): - self.options = VertexAILLMOptions(**kwargs) + config: Dict[str, Any] = { + "url": url, + "params": params, + "style": "gemini", + "input_modalities": self.input_modalities or ["text"], + } - def to_config(self) -> Dict[str, Any]: - options = _dump_optional_model(self.options) - options.pop("project_id", None) - options.pop("location", None) - if not options.get("url"): - options["url"] = ( - f"https://{self.options.location}-aiplatform.googleapis.com/v1/projects/" - f"{self.options.project_id}/locations/{self.options.location}/" - f"publishers/google/models/{self.options.model}:streamGenerateContent?alt=sse" - ) - config = Gemini(**options).to_config() - config["api_key"] = self.options.api_key + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.headers is not None: + config["headers"] = self.headers + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history + + config["api_key"] = self.api_key return config -class AmazonBedrockOptions(BaseModel): +class AmazonBedrock(BaseLLM): model_config = ConfigDict(extra="forbid") access_key: str = Field(..., description="AWS access key ID") @@ -426,56 +593,51 @@ class AmazonBedrockOptions(BaseModel): mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) max_history: Optional[int] = Field(default=None, gt=0, description="Maximum number of conversation history messages to cache") - -class AmazonBedrock(BaseLLM): - def __init__(self, **kwargs: Any): - self.options = AmazonBedrockOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.params or {}) - if self.options.max_tokens is not None: - params["max_tokens"] = self.options.max_tokens - if self.options.temperature is not None: - params["temperature"] = self.options.temperature - if self.options.top_p is not None: - params["top_p"] = self.options.top_p + params: Dict[str, Any] = dict(self.params or {}) + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.temperature is not None: + params["temperature"] = self.temperature + if self.top_p is not None: + params["top_p"] = self.top_p config: Dict[str, Any] = { - "url": self.options.url or f"https://bedrock-runtime.{self.options.region}.amazonaws.com/model/{self.options.model}/converse-stream", - "access_key": self.options.access_key, - "secret_key": self.options.secret_key, - "region": self.options.region, - "model": self.options.model, + "url": self.url or f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model}/converse-stream", + "access_key": self.access_key, + "secret_key": self.secret_key, + "region": self.region, + "model": self.model, "params": params, "style": "bedrock", - "input_modalities": self.options.input_modalities or ["text"], + "input_modalities": self.input_modalities or ["text"], } - if self.options.system_messages is not None: - config["system_messages"] = self.options.system_messages - if self.options.headers is not None: - config["headers"] = self.options.headers - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.greeting_audio_url is not None: - config["greeting_audio_url"] = self.options.greeting_audio_url - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.greeting_configs is not None: - config["greeting_configs"] = _dump_optional_model(self.options.greeting_configs) - if self.options.template_variables is not None: - config["template_variables"] = self.options.template_variables - if self.options.vendor is not None: - config["vendor"] = self.options.vendor - if self.options.mcp_servers is not None: - config["mcp_servers"] = _ensure_mcp_transport(self.options.mcp_servers) - if self.options.max_history is not None: - config["max_history"] = self.options.max_history + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.headers is not None: + config["headers"] = self.headers + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history return config -class DifyOptions(BaseModel): +class Dify(BaseLLM): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Dify API key") @@ -497,45 +659,40 @@ class DifyOptions(BaseModel): mcp_servers: Optional[List[Dict[str, Any]]] = Field(default=None) max_history: Optional[int] = Field(default=None, gt=0) - -class Dify(BaseLLM): - def __init__(self, **kwargs: Any): - self.options = DifyOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = {"model": self.options.model, **(self.options.params or {})} - if self.options.user is not None: - params["user"] = self.options.user - if self.options.conversation_id is not None: - params["conversation_id"] = self.options.conversation_id + params: Dict[str, Any] = {"model": self.model, **(self.params or {})} + if self.user is not None: + params["user"] = self.user + if self.conversation_id is not None: + params["conversation_id"] = self.conversation_id config: Dict[str, Any] = { - "url": self.options.url, - "api_key": self.options.api_key, + "url": self.url, + "api_key": self.api_key, "params": params, "style": "dify", - "input_modalities": self.options.input_modalities or ["text"], + "input_modalities": self.input_modalities or ["text"], } - if self.options.headers is not None: - config["headers"] = self.options.headers - if self.options.system_messages is not None: - config["system_messages"] = self.options.system_messages - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.greeting_audio_url is not None: - config["greeting_audio_url"] = self.options.greeting_audio_url - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.greeting_configs is not None: - config["greeting_configs"] = _dump_optional_model(self.options.greeting_configs) - if self.options.template_variables is not None: - config["template_variables"] = self.options.template_variables - if self.options.vendor is not None: - config["vendor"] = self.options.vendor - if self.options.mcp_servers is not None: - config["mcp_servers"] = _ensure_mcp_transport(self.options.mcp_servers) - if self.options.max_history is not None: - config["max_history"] = self.options.max_history + if self.headers is not None: + config["headers"] = self.headers + if self.system_messages is not None: + config["system_messages"] = self.system_messages + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.greeting_audio_url is not None: + config["greeting_audio_url"] = self.greeting_audio_url + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.greeting_configs is not None: + config["greeting_configs"] = _dump_optional_model(self.greeting_configs) + if self.template_variables is not None: + config["template_variables"] = self.template_variables + if self.vendor is not None: + config["vendor"] = self.vendor + if self.mcp_servers is not None: + config["mcp_servers"] = _ensure_mcp_transport(self.mcp_servers) + if self.max_history is not None: + config["max_history"] = self.max_history return config From 6e54199c179952be6f09dd50e7d2b648baab102d Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 06:18:34 -0700 Subject: [PATCH 05/12] refactor: tidy LLM standalone copies (Groq url, dead branches, error messages) --- src/agora_agent/agentkit/vendors/cn.py | 40 ++++++++++++------------- src/agora_agent/agentkit/vendors/llm.py | 23 ++------------ 2 files changed, 23 insertions(+), 40 deletions(-) diff --git a/src/agora_agent/agentkit/vendors/cn.py b/src/agora_agent/agentkit/vendors/cn.py index 695ad8f..b46bfd1 100644 --- a/src/agora_agent/agentkit/vendors/cn.py +++ b/src/agora_agent/agentkit/vendors/cn.py @@ -529,15 +529,15 @@ class AliyunLLM(BaseLLM): @model_validator(mode="after") def _validate_byok_params(self) -> "AliyunLLM": if not self.model: - raise ValueError("OpenAI requires model") + raise ValueError("AliyunLLM requires model") if self.api_key is not None and self.base_url is None: - raise ValueError("OpenAI requires base_url when api_key is set") + raise ValueError("AliyunLLM requires base_url when api_key is set") if self.api_key is None and self.base_url is not None: - raise ValueError("OpenAI base_url is only valid when api_key is set") + raise ValueError("AliyunLLM base_url is only valid when api_key is set") if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: - raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + raise ValueError("AliyunLLM requires api_key unless using a supported Agora-managed model") if self.api_key is None and self.vendor is not None: - raise ValueError("OpenAI Agora-managed mode does not allow vendor") + raise ValueError("AliyunLLM Agora-managed mode does not allow vendor") return self def to_config(self) -> Dict[str, Any]: @@ -611,15 +611,15 @@ class BytedanceLLM(BaseLLM): @model_validator(mode="after") def _validate_byok_params(self) -> "BytedanceLLM": if not self.model: - raise ValueError("OpenAI requires model") + raise ValueError("BytedanceLLM requires model") if self.api_key is not None and self.base_url is None: - raise ValueError("OpenAI requires base_url when api_key is set") + raise ValueError("BytedanceLLM requires base_url when api_key is set") if self.api_key is None and self.base_url is not None: - raise ValueError("OpenAI base_url is only valid when api_key is set") + raise ValueError("BytedanceLLM base_url is only valid when api_key is set") if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: - raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + raise ValueError("BytedanceLLM requires api_key unless using a supported Agora-managed model") if self.api_key is None and self.vendor is not None: - raise ValueError("OpenAI Agora-managed mode does not allow vendor") + raise ValueError("BytedanceLLM Agora-managed mode does not allow vendor") return self def to_config(self) -> Dict[str, Any]: @@ -693,15 +693,15 @@ class DeepSeekLLM(BaseLLM): @model_validator(mode="after") def _validate_byok_params(self) -> "DeepSeekLLM": if not self.model: - raise ValueError("OpenAI requires model") + raise ValueError("DeepSeekLLM requires model") if self.api_key is not None and self.base_url is None: - raise ValueError("OpenAI requires base_url when api_key is set") + raise ValueError("DeepSeekLLM requires base_url when api_key is set") if self.api_key is None and self.base_url is not None: - raise ValueError("OpenAI base_url is only valid when api_key is set") + raise ValueError("DeepSeekLLM base_url is only valid when api_key is set") if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: - raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + raise ValueError("DeepSeekLLM requires api_key unless using a supported Agora-managed model") if self.api_key is None and self.vendor is not None: - raise ValueError("OpenAI Agora-managed mode does not allow vendor") + raise ValueError("DeepSeekLLM Agora-managed mode does not allow vendor") return self def to_config(self) -> Dict[str, Any]: @@ -775,15 +775,15 @@ class TencentLLM(BaseLLM): @model_validator(mode="after") def _validate_byok_params(self) -> "TencentLLM": if not self.model: - raise ValueError("OpenAI requires model") + raise ValueError("TencentLLM requires model") if self.api_key is not None and self.base_url is None: - raise ValueError("OpenAI requires base_url when api_key is set") + raise ValueError("TencentLLM requires base_url when api_key is set") if self.api_key is None and self.base_url is not None: - raise ValueError("OpenAI base_url is only valid when api_key is set") + raise ValueError("TencentLLM base_url is only valid when api_key is set") if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: - raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") + raise ValueError("TencentLLM requires api_key unless using a supported Agora-managed model") if self.api_key is None and self.vendor is not None: - raise ValueError("OpenAI Agora-managed mode does not allow vendor") + raise ValueError("TencentLLM Agora-managed mode does not allow vendor") return self def to_config(self) -> Dict[str, Any]: diff --git a/src/agora_agent/agentkit/vendors/llm.py b/src/agora_agent/agentkit/vendors/llm.py index 1b4a9ad..4e35491 100644 --- a/src/agora_agent/agentkit/vendors/llm.py +++ b/src/agora_agent/agentkit/vendors/llm.py @@ -350,15 +350,7 @@ class Groq(BaseLLM): @model_validator(mode="after") def _validate_byok_params(self) -> "Groq": if not self.model: - raise ValueError("OpenAI requires model") - if self.api_key is not None and self.base_url is None: - raise ValueError("OpenAI requires base_url when api_key is set") - if self.api_key is None and self.base_url is not None: - raise ValueError("OpenAI base_url is only valid when api_key is set") - if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: - raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") - if self.api_key is None and self.vendor is not None: - raise ValueError("OpenAI Agora-managed mode does not allow vendor") + raise ValueError("Groq requires model") return self def to_config(self) -> Dict[str, Any]: @@ -372,7 +364,7 @@ def to_config(self) -> Dict[str, Any]: params["top_p"] = self.top_p config: Dict[str, Any] = { - "url": self.base_url or "https://api.openai.com/v1/chat/completions", + "url": self.base_url, "params": params, "style": "openai", "input_modalities": self.input_modalities or ["text"], @@ -403,7 +395,6 @@ def to_config(self) -> Dict[str, Any]: if self.max_history is not None: config["max_history"] = self.max_history - config["url"] = self.base_url return config @@ -433,15 +424,7 @@ class CustomLLM(BaseLLM): @model_validator(mode="after") def _validate_byok_params(self) -> "CustomLLM": if not self.model: - raise ValueError("OpenAI requires model") - if self.api_key is not None and self.base_url is None: - raise ValueError("OpenAI requires base_url when api_key is set") - if self.api_key is None and self.base_url is not None: - raise ValueError("OpenAI base_url is only valid when api_key is set") - if self.api_key is None and self.model.strip().lower() not in _OPENAI_MANAGED_MODELS: - raise ValueError("OpenAI requires api_key unless using a supported Agora-managed model") - if self.api_key is None and self.vendor is not None: - raise ValueError("OpenAI Agora-managed mode does not allow vendor") + raise ValueError("CustomLLM requires model") return self def to_config(self) -> Dict[str, Any]: From b9d0feca59678a9add91329cb7ce846db5db9c7b Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 06:26:43 -0700 Subject: [PATCH 06/12] refactor: collapse TTS vendor configs; add resolved_sample_rate accessor Co-Authored-By: Claude Opus 4.8 --- src/agora_agent/agentkit/agent.py | 2 +- src/agora_agent/agentkit/vendors/base.py | 20 +- src/agora_agent/agentkit/vendors/cn.py | 270 +++++------- src/agora_agent/agentkit/vendors/tts.py | 525 +++++++++-------------- 4 files changed, 333 insertions(+), 484 deletions(-) diff --git a/src/agora_agent/agentkit/agent.py b/src/agora_agent/agentkit/agent.py index 9733ebd..e0080fd 100644 --- a/src/agora_agent/agentkit/agent.py +++ b/src/agora_agent/agentkit/agent.py @@ -473,7 +473,7 @@ def with_llm(self, vendor: BaseLLM) -> "Agent": return new_agent def with_tts(self, vendor: BaseTTS) -> "Agent": - sample_rate = vendor.sample_rate + sample_rate = vendor.resolved_sample_rate if ( self._avatar_required_sample_rate not in (None, 0) and sample_rate is not None diff --git a/src/agora_agent/agentkit/vendors/base.py b/src/agora_agent/agentkit/vendors/base.py index f2df4eb..b072ef6 100644 --- a/src/agora_agent/agentkit/vendors/base.py +++ b/src/agora_agent/agentkit/vendors/base.py @@ -28,16 +28,16 @@ def to_config(self) -> Dict[str, Any]: """Serialize the LLM configuration to a dict for the REST API.""" -class BaseTTS(ABC): +class BaseTTS(BaseModel, ABC): """Abstract base class for all TTS vendor implementations. - Subclasses must implement :meth:`to_config` and :attr:`sample_rate`. + Subclasses must implement :meth:`to_config`. - ``sample_rate`` is used by :class:`~agora_agent.agentkit.AgentSession` to - validate TTS/avatar compatibility at runtime (avatars require a specific - sample rate). Subclasses should return ``None`` when the user has not - explicitly configured a sample rate, which will cause a warning at session - start time rather than a hard error. + :attr:`resolved_sample_rate` is used by + :class:`~agora_agent.agentkit.AgentSession` to validate TTS/avatar + compatibility at runtime (avatars require a specific sample rate). It + returns ``None`` when the user has not explicitly configured a sample rate, + which will cause a warning at session start time rather than a hard error. """ @abstractmethod @@ -45,9 +45,9 @@ def to_config(self) -> Dict[str, Any]: """Serialize the TTS configuration to a dict for the REST API.""" @property - @abstractmethod - def sample_rate(self) -> Optional[int]: - """The configured sample rate in Hz, or ``None`` if not explicitly set.""" + def resolved_sample_rate(self) -> Optional[int]: + """The effective configured sample rate in Hz, or None if not set.""" + return getattr(self, "sample_rate", None) class BaseSTT(BaseModel, ABC): diff --git a/src/agora_agent/agentkit/vendors/cn.py b/src/agora_agent/agentkit/vendors/cn.py index b46bfd1..80ff398 100644 --- a/src/agora_agent/agentkit/vendors/cn.py +++ b/src/agora_agent/agentkit/vendors/cn.py @@ -149,7 +149,7 @@ def to_config(self) -> Dict[str, Any]: } -class TencentTTSOptions(BaseModel): +class TencentTTS(_BaseTTSCompat): model_config = ConfigDict(extra="forbid") app_id: str = Field(..., description="Tencent TTS app id") @@ -163,14 +163,9 @@ class TencentTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Tencent TTS params") skip_patterns: Optional[List[int]] = Field(default=None) - -class TencentTTS(_BaseTTSCompat): - def __init__(self, **kwargs: Any): - self.options = TencentTTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: - audio_setting = (self.options.additional_params or {}).get("audio_setting") + audio_setting = (self.additional_params or {}).get("audio_setting") if isinstance(audio_setting, dict): sample_rate = audio_setting.get("sample_rate") if isinstance(sample_rate, int): @@ -178,34 +173,34 @@ def sample_rate(self) -> Optional[int]: return None def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update( { - "app_id": self.options.app_id, - "secret_id": self.options.secret_id, - "secret_key": self.options.secret_key, - "voice_type": self.options.voice_type, + "app_id": self.app_id, + "secret_id": self.secret_id, + "secret_key": self.secret_key, + "voice_type": self.voice_type, } ) - if self.options.volume is not None: - params["volume"] = self.options.volume - if self.options.speed is not None: - params["speed"] = self.options.speed - if self.options.emotion_category is not None: - params["emotion_category"] = self.options.emotion_category - if self.options.emotion_intensity is not None: - params["emotion_intensity"] = self.options.emotion_intensity + if self.volume is not None: + params["volume"] = self.volume + if self.speed is not None: + params["speed"] = self.speed + if self.emotion_category is not None: + params["emotion_category"] = self.emotion_category + if self.emotion_intensity is not None: + params["emotion_intensity"] = self.emotion_intensity result: Dict[str, Any] = { "vendor": "tencent", "params": params, } - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class BytedanceTTSOptions(BaseModel): +class BytedanceTTS(_BaseTTSCompat): model_config = ConfigDict(extra="forbid") token: str = Field(..., description="Bytedance TTS auth token") @@ -219,14 +214,9 @@ class BytedanceTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Bytedance TTS params") skip_patterns: Optional[List[int]] = Field(default=None) - -class BytedanceTTS(_BaseTTSCompat): - def __init__(self, **kwargs: Any): - self.options = BytedanceTTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: - audio_setting = (self.options.additional_params or {}).get("audio_setting") + audio_setting = (self.additional_params or {}).get("audio_setting") if isinstance(audio_setting, dict): sample_rate = audio_setting.get("sample_rate") if isinstance(sample_rate, int): @@ -234,34 +224,34 @@ def sample_rate(self) -> Optional[int]: return None def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update( { - "token": self.options.token, - "app_id": self.options.app_id, - "cluster": self.options.cluster, - "voice_type": self.options.voice_type, + "token": self.token, + "app_id": self.app_id, + "cluster": self.cluster, + "voice_type": self.voice_type, } ) - if self.options.speed_ratio is not None: - params["speed_ratio"] = self.options.speed_ratio - if self.options.volume_ratio is not None: - params["volume_ratio"] = self.options.volume_ratio - if self.options.pitch_ratio is not None: - params["pitch_ratio"] = self.options.pitch_ratio - if self.options.emotion is not None: - params["emotion"] = self.options.emotion + if self.speed_ratio is not None: + params["speed_ratio"] = self.speed_ratio + if self.volume_ratio is not None: + params["volume_ratio"] = self.volume_ratio + if self.pitch_ratio is not None: + params["pitch_ratio"] = self.pitch_ratio + if self.emotion is not None: + params["emotion"] = self.emotion result: Dict[str, Any] = { "vendor": "bytedance", "params": params, } - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class BytedanceDuplexTTSOptions(BaseModel): +class BytedanceDuplexTTS(_BaseTTSCompat): model_config = ConfigDict(extra="forbid") token: str = Field(..., description="Bytedance Duplex TTS auth token") @@ -270,14 +260,9 @@ class BytedanceDuplexTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Bytedance Duplex TTS params") skip_patterns: Optional[List[int]] = Field(default=None) - -class BytedanceDuplexTTS(_BaseTTSCompat): - def __init__(self, **kwargs: Any): - self.options = BytedanceDuplexTTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: - audio_setting = (self.options.additional_params or {}).get("audio_setting") + audio_setting = (self.additional_params or {}).get("audio_setting") if isinstance(audio_setting, dict): sample_rate = audio_setting.get("sample_rate") if isinstance(sample_rate, int): @@ -285,12 +270,12 @@ def sample_rate(self) -> Optional[int]: return None def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update( { - "token": self.options.token, - "app_id": self.options.app_id, - "speaker": self.options.speaker, + "token": self.token, + "app_id": self.app_id, + "speaker": self.speaker, } ) @@ -298,12 +283,12 @@ def to_config(self) -> Dict[str, Any]: "vendor": "bytedance_duplex", "params": params, } - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class CosyVoiceTTSOptions(BaseModel): +class CosyVoiceTTS(_BaseTTSCompat): model_config = ConfigDict(extra="forbid") api_key: Optional[str] = Field(default=None, description="CosyVoice API key") @@ -313,42 +298,26 @@ class CosyVoiceTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="CosyVoice TTS params from REST doc") skip_patterns: Optional[List[int]] = Field(default=None) - -class CosyVoiceTTS(_BaseTTSCompat): - def __init__(self, **kwargs: Any): - self.options = CosyVoiceTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - if self.options.sample_rate is not None: - return self.options.sample_rate - audio_setting = (self.options.additional_params or {}).get("audio_setting") - if isinstance(audio_setting, dict): - sample_rate = audio_setting.get("sample_rate") - if isinstance(sample_rate, int): - return sample_rate - return None - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - if self.options.api_key is not None: - params["api_key"] = self.options.api_key - if self.options.model is not None: - params["model"] = self.options.model - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate - if self.options.voice is not None: - params["voice"] = self.options.voice + params: Dict[str, Any] = dict(self.additional_params or {}) + if self.api_key is not None: + params["api_key"] = self.api_key + if self.model is not None: + params["model"] = self.model + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate + if self.voice is not None: + params["voice"] = self.voice result: Dict[str, Any] = { "vendor": "cosyvoice", "params": params, } - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class StepFunTTSOptions(BaseModel): +class StepFunTTS(_BaseTTSCompat): model_config = ConfigDict(extra="forbid") api_key: Optional[str] = Field(default=None, description="StepFun TTS API key") @@ -357,14 +326,9 @@ class StepFunTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="StepFun TTS params from REST doc") skip_patterns: Optional[List[int]] = Field(default=None) - -class StepFunTTS(_BaseTTSCompat): - def __init__(self, **kwargs: Any): - self.options = StepFunTTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: - audio_setting = (self.options.additional_params or {}).get("audio_setting") + audio_setting = (self.additional_params or {}).get("audio_setting") if isinstance(audio_setting, dict): sample_rate = audio_setting.get("sample_rate") if isinstance(sample_rate, int): @@ -372,23 +336,23 @@ def sample_rate(self) -> Optional[int]: return None def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - if self.options.api_key is not None: - params["api_key"] = self.options.api_key - if self.options.model is not None: - params["model"] = self.options.model - if self.options.voice_id is not None: - params["voice_id"] = self.options.voice_id + params: Dict[str, Any] = dict(self.additional_params or {}) + if self.api_key is not None: + params["api_key"] = self.api_key + if self.model is not None: + params["model"] = self.model + if self.voice_id is not None: + params["voice_id"] = self.voice_id result: Dict[str, Any] = { "vendor": "stepfun", "params": params, } - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class MicrosoftTTSOptions(BaseModel): +class MicrosoftTTS(_BaseTTSCompat): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Azure subscription key") @@ -400,36 +364,27 @@ class MicrosoftTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Microsoft TTS params") skip_patterns: Optional[List[int]] = Field(default=None) - -class MicrosoftTTS(_BaseTTSCompat): - def __init__(self, **kwargs: Any): - self.options = MicrosoftTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return self.options.sample_rate - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update({ - "key": self.options.key, - "region": self.options.region, - "voice_name": self.options.voice_name, + "key": self.key, + "region": self.region, + "voice_name": self.voice_name, }) - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate - if self.options.speed is not None: - params["speed"] = self.options.speed - if self.options.volume is not None: - params["volume"] = self.options.volume + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate + if self.speed is not None: + params["speed"] = self.speed + if self.volume is not None: + params["volume"] = self.volume result: Dict[str, Any] = {"vendor": "microsoft", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class MiniMaxTTSOptions(BaseModel): +class MiniMaxTTS(_BaseTTSCompat): model_config = ConfigDict(extra="forbid") key: Optional[str] = Field(default=None, description="MiniMax API key") @@ -449,57 +404,48 @@ class MiniMaxTTSOptions(BaseModel): skip_patterns: Optional[List[int]] = Field(default=None) @model_validator(mode="after") - def _validate_params(self) -> "MiniMaxTTSOptions": + def _validate_params(self) -> "MiniMaxTTS": if self.voice_id is not None and self.timber_weights is not None: raise ValueError("MiniMaxTTS requires exactly one of voice_id or timber_weights") if self.voice_id is None and self.timber_weights is None: raise ValueError("MiniMaxTTS requires exactly one of voice_id or timber_weights") return self - -class MiniMaxTTS(_BaseTTSCompat): - def __init__(self, **kwargs: Any): - self.options = MiniMaxTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return self.options.sample_rate - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - if self.options.key is not None: - params["key"] = self.options.key - params["model"] = self.options.model + params: Dict[str, Any] = dict(self.additional_params or {}) + if self.key is not None: + params["key"] = self.key + params["model"] = self.model voice_setting: Dict[str, Any] = {} - if self.options.voice_id is not None: - voice_setting["voice_id"] = self.options.voice_id - if self.options.speed is not None: - voice_setting["speed"] = self.options.speed - if self.options.vol is not None: - voice_setting["vol"] = self.options.vol - if self.options.pitch is not None: - voice_setting["pitch"] = self.options.pitch - if self.options.emotion is not None: - voice_setting["emotion"] = self.options.emotion - if self.options.latex_read is not None: - voice_setting["latex_read"] = self.options.latex_read - if self.options.english_normalization is not None: - voice_setting["english_normalization"] = self.options.english_normalization + if self.voice_id is not None: + voice_setting["voice_id"] = self.voice_id + if self.speed is not None: + voice_setting["speed"] = self.speed + if self.vol is not None: + voice_setting["vol"] = self.vol + if self.pitch is not None: + voice_setting["pitch"] = self.pitch + if self.emotion is not None: + voice_setting["emotion"] = self.emotion + if self.latex_read is not None: + voice_setting["latex_read"] = self.latex_read + if self.english_normalization is not None: + voice_setting["english_normalization"] = self.english_normalization if voice_setting: params["voice_setting"] = voice_setting - if self.options.timber_weights is not None: - params["timber_weights"] = self.options.timber_weights - if self.options.sample_rate is not None: - params["audio_setting"] = {"sample_rate": self.options.sample_rate} - if self.options.pronunciation_dict is not None: - params["pronunciation_dict"] = self.options.pronunciation_dict - if self.options.language_boost is not None: - params["language_boost"] = self.options.language_boost + if self.timber_weights is not None: + params["timber_weights"] = self.timber_weights + if self.sample_rate is not None: + params["audio_setting"] = {"sample_rate": self.sample_rate} + if self.pronunciation_dict is not None: + params["pronunciation_dict"] = self.pronunciation_dict + if self.language_boost is not None: + params["language_boost"] = self.language_boost result: Dict[str, Any] = {"vendor": "minimax", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result diff --git a/src/agora_agent/agentkit/vendors/tts.py b/src/agora_agent/agentkit/vendors/tts.py index 254399b..eda1740 100644 --- a/src/agora_agent/agentkit/vendors/tts.py +++ b/src/agora_agent/agentkit/vendors/tts.py @@ -1,11 +1,12 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field, model_validator from .base import BaseTTS, CartesiaSampleRate, ElevenLabsSampleRate, GoogleTTSSampleRate, MicrosoftSampleRate from ..presets import MiniMaxPresetModels, OpenAITtsPresetModels -class ElevenLabsTTSOptions(BaseModel): + +class ElevenLabsTTS(BaseTTS): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="ElevenLabs API key") @@ -20,42 +21,34 @@ class ElevenLabsTTSOptions(BaseModel): style: Optional[float] = Field(default=None, ge=0.0, le=1.0) use_speaker_boost: Optional[bool] = Field(default=None) -class ElevenLabsTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = ElevenLabsTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return self.options.sample_rate - def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "key": self.options.key, - "base_url": self.options.base_url, - "model_id": self.options.model_id, - "voice_id": self.options.voice_id, + "key": self.key, + "base_url": self.base_url, + "model_id": self.model_id, + "voice_id": self.voice_id, } - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate - if self.options.optimize_streaming_latency is not None: - params["optimize_streaming_latency"] = self.options.optimize_streaming_latency - if self.options.stability is not None: - params["stability"] = self.options.stability - if self.options.similarity_boost is not None: - params["similarity_boost"] = self.options.similarity_boost - if self.options.style is not None: - params["style"] = self.options.style - if self.options.use_speaker_boost is not None: - params["use_speaker_boost"] = self.options.use_speaker_boost + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate + if self.optimize_streaming_latency is not None: + params["optimize_streaming_latency"] = self.optimize_streaming_latency + if self.stability is not None: + params["stability"] = self.stability + if self.similarity_boost is not None: + params["similarity_boost"] = self.similarity_boost + if self.style is not None: + params["style"] = self.style + if self.use_speaker_boost is not None: + params["use_speaker_boost"] = self.use_speaker_boost result: Dict[str, Any] = {"vendor": "elevenlabs", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class MicrosoftTTSOptions(BaseModel): +class MicrosoftTTS(BaseTTS): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Azure subscription key") @@ -67,36 +60,28 @@ class MicrosoftTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Microsoft TTS params") skip_patterns: Optional[List[int]] = Field(default=None) -class MicrosoftTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = MicrosoftTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return self.options.sample_rate - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update({ - "key": self.options.key, - "region": self.options.region, - "voice_name": self.options.voice_name, + "key": self.key, + "region": self.region, + "voice_name": self.voice_name, }) - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate - if self.options.speed is not None: - params["speed"] = self.options.speed - if self.options.volume is not None: - params["volume"] = self.options.volume + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate + if self.speed is not None: + params["speed"] = self.speed + if self.volume is not None: + params["volume"] = self.volume result: Dict[str, Any] = {"vendor": "microsoft", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class OpenAITTSOptions(BaseModel): +class OpenAITTS(BaseTTS): model_config = ConfigDict(extra="forbid") api_key: Optional[str] = Field(default=None, description="OpenAI API key") @@ -108,7 +93,7 @@ class OpenAITTSOptions(BaseModel): skip_patterns: Optional[List[int]] = Field(default=None) @model_validator(mode="after") - def _validate_byok_params(self) -> "OpenAITTSOptions": + def _validate_byok_params(self) -> "OpenAITTS": if self.api_key is not None: missing = [ name @@ -127,37 +112,33 @@ def _validate_byok_params(self) -> "OpenAITTSOptions": raise ValueError("OpenAITTS base_url is only valid when api_key is set") return self -class OpenAITTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = OpenAITTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: return 24000 def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "voice": self.options.voice, + "voice": self.voice, } - if self.options.api_key is not None: - params["api_key"] = self.options.api_key - params["base_url"] = self.options.base_url - params["model"] = self.options.model - elif self.options.model is not None: - params["model"] = self.options.model - - if self.options.instructions is not None: - params["instructions"] = self.options.instructions - if self.options.speed is not None: - params["speed"] = self.options.speed + if self.api_key is not None: + params["api_key"] = self.api_key + params["base_url"] = self.base_url + params["model"] = self.model + elif self.model is not None: + params["model"] = self.model + + if self.instructions is not None: + params["instructions"] = self.instructions + if self.speed is not None: + params["speed"] = self.speed result: Dict[str, Any] = {"vendor": "openai", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class CartesiaTTSOptions(BaseModel): +class CartesiaTTS(BaseTTS): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Cartesia API key") @@ -168,35 +149,27 @@ class CartesiaTTSOptions(BaseModel): sample_rate: Optional[CartesiaSampleRate] = Field(default=None, description="Sample rate in Hz") skip_patterns: Optional[List[int]] = Field(default=None) -class CartesiaTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = CartesiaTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return self.options.sample_rate - def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "api_key": self.options.api_key, - "model_id": self.options.model_id, - "voice": {"mode": "id", "id": self.options.voice_id}, + "api_key": self.api_key, + "model_id": self.model_id, + "voice": {"mode": "id", "id": self.voice_id}, } - if self.options.base_url is not None: - params["base_url"] = self.options.base_url - if self.options.sample_rate is not None: - params["output_format"] = {"container": "raw", "sample_rate": self.options.sample_rate} - if self.options.language is not None: - params["language"] = self.options.language + if self.base_url is not None: + params["base_url"] = self.base_url + if self.sample_rate is not None: + params["output_format"] = {"container": "raw", "sample_rate": self.sample_rate} + if self.language is not None: + params["language"] = self.language result: Dict[str, Any] = {"vendor": "cartesia", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class GoogleTTSOptions(BaseModel): +class GoogleTTS(BaseTTS): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Google Cloud service account credentials JSON string") @@ -205,32 +178,28 @@ class GoogleTTSOptions(BaseModel): sample_rate_hertz: Optional[GoogleTTSSampleRate] = Field(default=None, description="Sample rate in Hz") skip_patterns: Optional[List[int]] = Field(default=None) -class GoogleTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = GoogleTTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: - return self.options.sample_rate_hertz + return self.sample_rate_hertz def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "credentials": self.options.key, - "VoiceSelectionParams": {"name": self.options.voice_name}, + "credentials": self.key, + "VoiceSelectionParams": {"name": self.voice_name}, } - if self.options.language_code is not None: - params["VoiceSelectionParams"]["language_code"] = self.options.language_code - if self.options.sample_rate_hertz is not None: - params["AudioConfig"] = {"sample_rate_hertz": self.options.sample_rate_hertz} + if self.language_code is not None: + params["VoiceSelectionParams"]["language_code"] = self.language_code + if self.sample_rate_hertz is not None: + params["AudioConfig"] = {"sample_rate_hertz": self.sample_rate_hertz} result: Dict[str, Any] = {"vendor": "google", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class AmazonTTSOptions(BaseModel): +class AmazonTTS(BaseTTS): model_config = ConfigDict(extra="forbid") access_key: str = Field(..., description="AWS access key") @@ -240,30 +209,26 @@ class AmazonTTSOptions(BaseModel): engine: str = Field(..., description="Amazon Polly engine type") skip_patterns: Optional[List[int]] = Field(default=None) -class AmazonTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = AmazonTTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: return None def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "aws_access_key_id": self.options.access_key, - "aws_secret_access_key": self.options.secret_key, - "region_name": self.options.region, - "voice": self.options.voice_id, - "engine": self.options.engine, + "aws_access_key_id": self.access_key, + "aws_secret_access_key": self.secret_key, + "region_name": self.region, + "voice": self.voice_id, + "engine": self.engine, } result: Dict[str, Any] = {"vendor": "amazon", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class DeepgramTTSOptions(BaseModel): +class DeepgramTTS(BaseTTS): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Deepgram API key") @@ -273,32 +238,24 @@ class DeepgramTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Deepgram TTS parameters") skip_patterns: Optional[List[int]] = Field(default=None) -class DeepgramTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = DeepgramTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return self.options.sample_rate - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) + params: Dict[str, Any] = dict(self.additional_params or {}) params.update({ - "api_key": self.options.api_key, - "model": self.options.model, + "api_key": self.api_key, + "model": self.model, }) - if self.options.base_url is not None: - params["base_url"] = self.options.base_url - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate + if self.base_url is not None: + params["base_url"] = self.base_url + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate result: Dict[str, Any] = {"vendor": "deepgram", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class HumeAITTSOptions(BaseModel): +class HumeAITTS(BaseTTS): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Hume AI API key") @@ -310,37 +267,33 @@ class HumeAITTSOptions(BaseModel): trailing_silence: Optional[float] = Field(default=None, description="Trailing silence in seconds") skip_patterns: Optional[List[int]] = Field(default=None) -class HumeAITTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = HumeAITTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: return None def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "key": self.options.key, - "voice_id": self.options.voice_id, - "provider": self.options.provider, + "key": self.key, + "voice_id": self.voice_id, + "provider": self.provider, } - if self.options.config_id is not None: - params["config_id"] = self.options.config_id - if self.options.base_url is not None: - params["base_url"] = self.options.base_url - if self.options.speed is not None: - params["speed"] = self.options.speed - if self.options.trailing_silence is not None: - params["trailing_silence"] = self.options.trailing_silence + if self.config_id is not None: + params["config_id"] = self.config_id + if self.base_url is not None: + params["base_url"] = self.base_url + if self.speed is not None: + params["speed"] = self.speed + if self.trailing_silence is not None: + params["trailing_silence"] = self.trailing_silence result: Dict[str, Any] = {"vendor": "humeai", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class RimeTTSOptions(BaseModel): +class RimeTTS(BaseTTS): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Rime API key") @@ -349,30 +302,26 @@ class RimeTTSOptions(BaseModel): base_url: Optional[str] = Field(default=None, description="WebSocket URL") skip_patterns: Optional[List[int]] = Field(default=None) -class RimeTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = RimeTTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: return None def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "api_key": self.options.key, - "speaker": self.options.speaker, - "modelId": self.options.model_id, + "api_key": self.key, + "speaker": self.speaker, + "modelId": self.model_id, } - if self.options.base_url is not None: - params["base_url"] = self.options.base_url + if self.base_url is not None: + params["base_url"] = self.base_url result: Dict[str, Any] = {"vendor": "rime", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class FishAudioTTSOptions(BaseModel): +class FishAudioTTS(BaseTTS): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Fish Audio API key") @@ -380,28 +329,24 @@ class FishAudioTTSOptions(BaseModel): backend: str = Field(..., description="Backend") skip_patterns: Optional[List[int]] = Field(default=None) -class FishAudioTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = FishAudioTTSOptions(**kwargs) - @property def sample_rate(self) -> Optional[int]: return None def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "api_key": self.options.key, - "reference_id": self.options.reference_id, - "backend": self.options.backend, + "api_key": self.key, + "reference_id": self.reference_id, + "backend": self.backend, } result: Dict[str, Any] = {"vendor": "fishaudio", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class MiniMaxTTSOptions(BaseModel): +class MiniMaxTTS(BaseTTS): model_config = ConfigDict(extra="forbid") key: Optional[str] = Field(default=None, description="MiniMax API key") @@ -423,7 +368,7 @@ class MiniMaxTTSOptions(BaseModel): skip_patterns: Optional[List[int]] = Field(default=None) @model_validator(mode="after") - def _validate_byok_params(self) -> "MiniMaxTTSOptions": + def _validate_byok_params(self) -> "MiniMaxTTS": if self.voice_id is not None and self.timber_weights is not None: raise ValueError("MiniMaxTTS requires exactly one of voice_id or timber_weights") if self.key is not None: @@ -446,59 +391,51 @@ def _validate_byok_params(self) -> "MiniMaxTTSOptions": raise ValueError("MiniMaxTTS requires key unless using a supported Agora-managed model") return self -class MiniMaxTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = MiniMaxTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return self.options.sample_rate - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - if self.options.key is not None: - params["key"] = self.options.key - params["group_id"] = self.options.group_id - params["model"] = self.options.model - if self.options.url is not None: - params["url"] = self.options.url + params: Dict[str, Any] = dict(self.additional_params or {}) + if self.key is not None: + params["key"] = self.key + params["group_id"] = self.group_id + params["model"] = self.model + if self.url is not None: + params["url"] = self.url voice_setting: Dict[str, Any] = {} - if self.options.voice_id is not None: - voice_setting["voice_id"] = self.options.voice_id - if self.options.speed is not None: - voice_setting["speed"] = self.options.speed - if self.options.vol is not None: - voice_setting["vol"] = self.options.vol - if self.options.pitch is not None: - voice_setting["pitch"] = self.options.pitch - if self.options.emotion is not None: - voice_setting["emotion"] = self.options.emotion - if self.options.latex_read is not None: - voice_setting["latex_read"] = self.options.latex_read - if self.options.english_normalization is not None: - voice_setting["english_normalization"] = self.options.english_normalization + if self.voice_id is not None: + voice_setting["voice_id"] = self.voice_id + if self.speed is not None: + voice_setting["speed"] = self.speed + if self.vol is not None: + voice_setting["vol"] = self.vol + if self.pitch is not None: + voice_setting["pitch"] = self.pitch + if self.emotion is not None: + voice_setting["emotion"] = self.emotion + if self.latex_read is not None: + voice_setting["latex_read"] = self.latex_read + if self.english_normalization is not None: + voice_setting["english_normalization"] = self.english_normalization if voice_setting: params["voice_setting"] = voice_setting - if self.options.timber_weights is not None: - params["timber_weights"] = self.options.timber_weights - if self.options.sample_rate is not None: - params["audio_setting"] = {"sample_rate": self.options.sample_rate} - if self.options.pronunciation_dict is not None: - params["pronunciation_dict"] = self.options.pronunciation_dict - if self.options.language_boost is not None: - params["language_boost"] = self.options.language_boost + if self.timber_weights is not None: + params["timber_weights"] = self.timber_weights + if self.sample_rate is not None: + params["audio_setting"] = {"sample_rate": self.sample_rate} + if self.pronunciation_dict is not None: + params["pronunciation_dict"] = self.pronunciation_dict + if self.language_boost is not None: + params["language_boost"] = self.language_boost result: Dict[str, Any] = {"vendor": "minimax", "params": params} - if self.options.key is None: + if self.key is None: # Preset path: model not in params; stored as top-level hint for preset # inference. Stripped by strip_inferred_preset_fields before the POST body. - result["_minimax_preset_model"] = self.options.model - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + result["_minimax_preset_model"] = self.model + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class SarvamTTSOptions(BaseModel): +class SarvamTTS(BaseTTS): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Sarvam API subscription key") @@ -510,36 +447,28 @@ class SarvamTTSOptions(BaseModel): sample_rate: Optional[int] = Field(default=None, description="Audio sample rate in Hz") skip_patterns: Optional[List[int]] = Field(default=None) -class SarvamTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = SarvamTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return None - def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "api_subscription_key": self.options.key, - "speaker": self.options.speaker, - "target_language_code": self.options.target_language_code, + "api_subscription_key": self.key, + "speaker": self.speaker, + "target_language_code": self.target_language_code, } - if self.options.pitch is not None: - params["pitch"] = self.options.pitch - if self.options.pace is not None: - params["pace"] = self.options.pace - if self.options.loudness is not None: - params["loudness"] = self.options.loudness - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate + if self.pitch is not None: + params["pitch"] = self.pitch + if self.pace is not None: + params["pace"] = self.pace + if self.loudness is not None: + params["loudness"] = self.loudness + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate result: Dict[str, Any] = {"vendor": "sarvam", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class MurfTTSOptions(BaseModel): +class MurfTTS(BaseTTS): model_config = ConfigDict(extra="forbid") key: str = Field(..., description="Murf API key") @@ -552,39 +481,31 @@ class MurfTTSOptions(BaseModel): sample_rate: Optional[int] = Field(default=None, description="Audio sample rate") skip_patterns: Optional[List[int]] = Field(default=None) -class MurfTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = MurfTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return None - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = {"api_key": self.options.key} - - if self.options.base_url is not None: - params["base_url"] = self.options.base_url - if self.options.voice_id is not None: - params["voiceId"] = self.options.voice_id - if self.options.locale is not None: - params["locale"] = self.options.locale - if self.options.rate is not None: - params["rate"] = self.options.rate - if self.options.pitch is not None: - params["pitch"] = self.options.pitch - if self.options.model is not None: - params["model"] = self.options.model - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate + params: Dict[str, Any] = {"api_key": self.key} + + if self.base_url is not None: + params["base_url"] = self.base_url + if self.voice_id is not None: + params["voiceId"] = self.voice_id + if self.locale is not None: + params["locale"] = self.locale + if self.rate is not None: + params["rate"] = self.rate + if self.pitch is not None: + params["pitch"] = self.pitch + if self.model is not None: + params["model"] = self.model + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate result: Dict[str, Any] = {"vendor": "murf", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class GenericTTSOptions(BaseModel): +class GenericTTS(BaseTTS): model_config = ConfigDict(extra="forbid") url: str = Field(..., description="Callback address of the generic TTS service") @@ -599,42 +520,33 @@ class GenericTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional generic TTS params") skip_patterns: Optional[List[int]] = Field(default=None) - -class GenericTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = GenericTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return self.options.sample_rate - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - if self.options.api_key is not None: - params["api_key"] = self.options.api_key - params["model"] = self.options.model - params["voice"] = self.options.voice - if self.options.speed is not None: - params["speed"] = self.options.speed - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate - if self.options.response_format is not None: - params["response_format"] = self.options.response_format - if self.options.instruction is not None: - params["instruction"] = self.options.instruction + params: Dict[str, Any] = dict(self.additional_params or {}) + if self.api_key is not None: + params["api_key"] = self.api_key + params["model"] = self.model + params["voice"] = self.voice + if self.speed is not None: + params["speed"] = self.speed + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate + if self.response_format is not None: + params["response_format"] = self.response_format + if self.instruction is not None: + params["instruction"] = self.instruction result: Dict[str, Any] = { "vendor": "generic", - "url": self.options.url, - "headers": self.options.headers, + "url": self.url, + "headers": self.headers, "params": params, } - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result -class XaiTTSOptions(BaseModel): +class XaiTTS(BaseTTS): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="xAI API key") @@ -644,25 +556,16 @@ class XaiTTSOptions(BaseModel): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional xAI TTS params") skip_patterns: Optional[List[int]] = Field(default=None) - -class XaiTTS(BaseTTS): - def __init__(self, **kwargs: Any): - self.options = XaiTTSOptions(**kwargs) - - @property - def sample_rate(self) -> Optional[int]: - return self.options.sample_rate - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.additional_params or {}) - params["api_key"] = self.options.api_key - params["language"] = self.options.language - if self.options.voice_id is not None: - params["voice_id"] = self.options.voice_id - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate + params: Dict[str, Any] = dict(self.additional_params or {}) + params["api_key"] = self.api_key + params["language"] = self.language + if self.voice_id is not None: + params["voice_id"] = self.voice_id + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate result: Dict[str, Any] = {"vendor": "xai", "params": params} - if self.options.skip_patterns is not None: - result["skip_patterns"] = self.options.skip_patterns + if self.skip_patterns is not None: + result["skip_patterns"] = self.skip_patterns return result From 13339ed8c74704cf5bd07dac9400813fe9c89e72 Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 06:29:59 -0700 Subject: [PATCH 07/12] fix: preserve resolved_sample_rate behavior for Sarvam/Murf/CosyVoice --- src/agora_agent/agentkit/vendors/cn.py | 11 +++++++++++ src/agora_agent/agentkit/vendors/tts.py | 8 ++++++++ 2 files changed, 19 insertions(+) diff --git a/src/agora_agent/agentkit/vendors/cn.py b/src/agora_agent/agentkit/vendors/cn.py index 80ff398..96af7e6 100644 --- a/src/agora_agent/agentkit/vendors/cn.py +++ b/src/agora_agent/agentkit/vendors/cn.py @@ -298,6 +298,17 @@ class CosyVoiceTTS(_BaseTTSCompat): additional_params: Optional[Dict[str, Any]] = Field(default=None, description="CosyVoice TTS params from REST doc") skip_patterns: Optional[List[int]] = Field(default=None) + @property + def resolved_sample_rate(self) -> Optional[int]: + if self.sample_rate is not None: + return self.sample_rate + audio_setting = (self.additional_params or {}).get("audio_setting") + if isinstance(audio_setting, dict): + sample_rate = audio_setting.get("sample_rate") + if isinstance(sample_rate, int): + return sample_rate + return None + def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = dict(self.additional_params or {}) if self.api_key is not None: diff --git a/src/agora_agent/agentkit/vendors/tts.py b/src/agora_agent/agentkit/vendors/tts.py index eda1740..054b6cd 100644 --- a/src/agora_agent/agentkit/vendors/tts.py +++ b/src/agora_agent/agentkit/vendors/tts.py @@ -447,6 +447,10 @@ class SarvamTTS(BaseTTS): sample_rate: Optional[int] = Field(default=None, description="Audio sample rate in Hz") skip_patterns: Optional[List[int]] = Field(default=None) + @property + def resolved_sample_rate(self) -> Optional[int]: + return None + def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { "api_subscription_key": self.key, @@ -481,6 +485,10 @@ class MurfTTS(BaseTTS): sample_rate: Optional[int] = Field(default=None, description="Audio sample rate") skip_patterns: Optional[List[int]] = Field(default=None) + @property + def resolved_sample_rate(self) -> Optional[int]: + return None + def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = {"api_key": self.key} From 289e780c78b0ce353cbd4862ae0c7251cb6c436d Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 06:43:43 -0700 Subject: [PATCH 08/12] refactor: collapse MLLM vendor configs into pydantic models Co-Authored-By: Claude Sonnet 4.6 --- src/agora_agent/agentkit/vendors/base.py | 2 +- src/agora_agent/agentkit/vendors/mllm.py | 272 +++++++++++------------ 2 files changed, 128 insertions(+), 146 deletions(-) diff --git a/src/agora_agent/agentkit/vendors/base.py b/src/agora_agent/agentkit/vendors/base.py index b072ef6..2ec1b7b 100644 --- a/src/agora_agent/agentkit/vendors/base.py +++ b/src/agora_agent/agentkit/vendors/base.py @@ -62,7 +62,7 @@ def to_config(self) -> Dict[str, Any]: """Serialize the STT configuration to a dict for the REST API.""" -class BaseMLLM(ABC): +class BaseMLLM(BaseModel, ABC): """Abstract base class for all MLLM (multimodal LLM) vendor implementations. When an MLLM is configured via :meth:`~agora_agent.agentkit.Agent.with_mllm`, diff --git a/src/agora_agent/agentkit/vendors/mllm.py b/src/agora_agent/agentkit/vendors/mllm.py index 4a0846a..71d94fe 100644 --- a/src/agora_agent/agentkit/vendors/mllm.py +++ b/src/agora_agent/agentkit/vendors/mllm.py @@ -1,7 +1,6 @@ -import warnings from typing import Any, Dict, List, Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field from ...types.mllm_turn_detection import MllmTurnDetection from .base import BaseMLLM @@ -9,7 +8,7 @@ MllmTurnDetectionConfig = MllmTurnDetection -class OpenAIRealtimeOptions(BaseModel): +class OpenAIRealtime(BaseMLLM): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="OpenAI API key") @@ -26,49 +25,45 @@ class OpenAIRealtimeOptions(BaseModel): turn_detection: Optional[MllmTurnDetectionConfig] = Field(default=None, description="MLLM turn detection configuration") failure_message: Optional[str] = Field(default=None, description="Message played on failure") -class OpenAIRealtime(BaseMLLM): - def __init__(self, **kwargs: Any): - self.options = OpenAIRealtimeOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: config: Dict[str, Any] = { "vendor": "openai", - "api_key": self.options.api_key, + "api_key": self.api_key, } - if self.options.url is not None: - config["url"] = self.options.url + if self.url is not None: + config["url"] = self.url if ( - self.options.model is not None - or self.options.params is not None - or self.options.voice is not None - or self.options.instructions is not None - or self.options.input_audio_transcription is not None + self.model is not None + or self.params is not None + or self.voice is not None + or self.instructions is not None + or self.input_audio_transcription is not None ): - params: Dict[str, Any] = {} - if self.options.model is not None: - params["model"] = self.options.model - if self.options.params is not None: - params.update(self.options.params) - if self.options.voice is not None: - params["voice"] = self.options.voice - if self.options.instructions is not None: - params["instructions"] = self.options.instructions - if self.options.input_audio_transcription is not None: - params["input_audio_transcription"] = self.options.input_audio_transcription - config["params"] = params - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.input_modalities is not None: - config["input_modalities"] = self.options.input_modalities - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.messages is not None: - config["messages"] = self.options.messages - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.turn_detection is not None: - config["turn_detection"] = self.options.turn_detection + inner_params: Dict[str, Any] = {} + if self.model is not None: + inner_params["model"] = self.model + if self.params is not None: + inner_params.update(self.params) + if self.voice is not None: + inner_params["voice"] = self.voice + if self.instructions is not None: + inner_params["instructions"] = self.instructions + if self.input_audio_transcription is not None: + inner_params["input_audio_transcription"] = self.input_audio_transcription + config["params"] = inner_params + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.input_modalities is not None: + config["input_modalities"] = self.input_modalities + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.messages is not None: + config["messages"] = self.messages + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.turn_detection is not None: + config["turn_detection"] = self.turn_detection return config @@ -77,7 +72,9 @@ def to_config(self) -> Dict[str, Any]: # is deprecated and reserved naming for future XaiSTT / XaiTTS cascading vendors. -class XaiGrokOptions(BaseModel): +class XaiGrok(BaseMLLM): + """xAI Grok MLLM vendor (`mllm.vendor`: ``xai``).""" + model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="xAI API key") @@ -93,46 +90,39 @@ class XaiGrokOptions(BaseModel): turn_detection: Optional[MllmTurnDetectionConfig] = Field(default=None, description="MLLM turn detection configuration") failure_message: Optional[str] = Field(default=None, description="Message played on failure") - -class XaiGrok(BaseMLLM): - """xAI Grok MLLM vendor (`mllm.vendor`: ``xai``).""" - - def __init__(self, **kwargs: Any): - self.options = XaiGrokOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = dict(self.options.params or {}) - if self.options.voice is not None: - params["voice"] = self.options.voice - if self.options.language is not None: - params["language"] = self.options.language - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate + inner_params: Dict[str, Any] = dict(self.params or {}) + if self.voice is not None: + inner_params["voice"] = self.voice + if self.language is not None: + inner_params["language"] = self.language + if self.sample_rate is not None: + inner_params["sample_rate"] = self.sample_rate config: Dict[str, Any] = { "vendor": "xai", - "api_key": self.options.api_key, - "url": self.options.url, - "params": params, + "api_key": self.api_key, + "url": self.url, + "params": inner_params, } - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.input_modalities is not None: - config["input_modalities"] = self.options.input_modalities - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.messages is not None: - config["messages"] = self.options.messages - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.turn_detection is not None: - config["turn_detection"] = self.options.turn_detection + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.input_modalities is not None: + config["input_modalities"] = self.input_modalities + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.messages is not None: + config["messages"] = self.messages + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.turn_detection is not None: + config["turn_detection"] = self.turn_detection return config -class VertexAIOptions(BaseModel): +class VertexAI(BaseMLLM): model_config = ConfigDict(extra="forbid") model: str = Field(..., description="Model name") @@ -155,55 +145,51 @@ class VertexAIOptions(BaseModel): turn_detection: Optional[MllmTurnDetectionConfig] = Field(default=None, description="MLLM turn detection configuration") failure_message: Optional[str] = Field(default=None, description="Message played on failure") -class VertexAI(BaseMLLM): - def __init__(self, **kwargs: Any): - self.options = VertexAIOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: # additional_params spread first so that explicit fields always win, # matching the TypeScript SDK. - params: Dict[str, Any] = dict(self.options.additional_params or {}) - params["model"] = self.options.model - params["project_id"] = self.options.project_id - params["location"] = self.options.location - params["adc_credentials_string"] = self.options.adc_credentials_string - if self.options.instructions is not None: - params["instructions"] = self.options.instructions - if self.options.voice is not None: - params["voice"] = self.options.voice - if self.options.affective_dialog is not None: - params["affective_dialog"] = self.options.affective_dialog - if self.options.proactive_audio is not None: - params["proactive_audio"] = self.options.proactive_audio - if self.options.transcribe_agent is not None: - params["transcribe_agent"] = self.options.transcribe_agent - if self.options.transcribe_user is not None: - params["transcribe_user"] = self.options.transcribe_user - if self.options.http_options is not None: - params["http_options"] = self.options.http_options + inner_params: Dict[str, Any] = dict(self.additional_params or {}) + inner_params["model"] = self.model + inner_params["project_id"] = self.project_id + inner_params["location"] = self.location + inner_params["adc_credentials_string"] = self.adc_credentials_string + if self.instructions is not None: + inner_params["instructions"] = self.instructions + if self.voice is not None: + inner_params["voice"] = self.voice + if self.affective_dialog is not None: + inner_params["affective_dialog"] = self.affective_dialog + if self.proactive_audio is not None: + inner_params["proactive_audio"] = self.proactive_audio + if self.transcribe_agent is not None: + inner_params["transcribe_agent"] = self.transcribe_agent + if self.transcribe_user is not None: + inner_params["transcribe_user"] = self.transcribe_user + if self.http_options is not None: + inner_params["http_options"] = self.http_options config: Dict[str, Any] = { "vendor": "vertexai", - "url": self.options.url if self.options.url is not None else "", - "params": params, + "url": self.url if self.url is not None else "", + "params": inner_params, } - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.input_modalities is not None: - config["input_modalities"] = self.options.input_modalities - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.messages is not None: - config["messages"] = self.options.messages - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.turn_detection is not None: - config["turn_detection"] = self.options.turn_detection + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.input_modalities is not None: + config["input_modalities"] = self.input_modalities + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.messages is not None: + config["messages"] = self.messages + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.turn_detection is not None: + config["turn_detection"] = self.turn_detection return config -class GeminiLiveOptions(BaseModel): +class GeminiLive(BaseMLLM): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Google API key") @@ -224,47 +210,43 @@ class GeminiLiveOptions(BaseModel): turn_detection: Optional[MllmTurnDetectionConfig] = Field(default=None, description="MLLM turn detection configuration") failure_message: Optional[str] = Field(default=None, description="Message played on failure") -class GeminiLive(BaseMLLM): - def __init__(self, **kwargs: Any): - self.options = GeminiLiveOptions(**kwargs) - def to_config(self) -> Dict[str, Any]: - params: Dict[str, Any] = {} - if self.options.additional_params is not None: - params.update(self.options.additional_params) - params["model"] = self.options.model - if self.options.instructions is not None: - params["instructions"] = self.options.instructions - if self.options.voice is not None: - params["voice"] = self.options.voice - if self.options.affective_dialog is not None: - params["affective_dialog"] = self.options.affective_dialog - if self.options.proactive_audio is not None: - params["proactive_audio"] = self.options.proactive_audio - if self.options.transcribe_agent is not None: - params["transcribe_agent"] = self.options.transcribe_agent - if self.options.transcribe_user is not None: - params["transcribe_user"] = self.options.transcribe_user - if self.options.http_options is not None: - params["http_options"] = self.options.http_options + inner_params: Dict[str, Any] = {} + if self.additional_params is not None: + inner_params.update(self.additional_params) + inner_params["model"] = self.model + if self.instructions is not None: + inner_params["instructions"] = self.instructions + if self.voice is not None: + inner_params["voice"] = self.voice + if self.affective_dialog is not None: + inner_params["affective_dialog"] = self.affective_dialog + if self.proactive_audio is not None: + inner_params["proactive_audio"] = self.proactive_audio + if self.transcribe_agent is not None: + inner_params["transcribe_agent"] = self.transcribe_agent + if self.transcribe_user is not None: + inner_params["transcribe_user"] = self.transcribe_user + if self.http_options is not None: + inner_params["http_options"] = self.http_options config: Dict[str, Any] = { "vendor": "gemini", - "api_key": self.options.api_key, - "url": self.options.url if self.options.url is not None else "", - "params": params, + "api_key": self.api_key, + "url": self.url if self.url is not None else "", + "params": inner_params, } - if self.options.greeting_message is not None: - config["greeting_message"] = self.options.greeting_message - if self.options.input_modalities is not None: - config["input_modalities"] = self.options.input_modalities - if self.options.output_modalities is not None: - config["output_modalities"] = self.options.output_modalities - if self.options.messages is not None: - config["messages"] = self.options.messages - if self.options.failure_message is not None: - config["failure_message"] = self.options.failure_message - if self.options.turn_detection is not None: - config["turn_detection"] = self.options.turn_detection + if self.greeting_message is not None: + config["greeting_message"] = self.greeting_message + if self.input_modalities is not None: + config["input_modalities"] = self.input_modalities + if self.output_modalities is not None: + config["output_modalities"] = self.output_modalities + if self.messages is not None: + config["messages"] = self.messages + if self.failure_message is not None: + config["failure_message"] = self.failure_message + if self.turn_detection is not None: + config["turn_detection"] = self.turn_detection return config From 0be1014e333ddba6db82de376198330ceca827b2 Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 06:51:30 -0700 Subject: [PATCH 09/12] refactor: collapse avatar vendor configs; HeyGen copy + SenseTime populate_by_name Co-Authored-By: Claude Opus 4.8 --- src/agora_agent/agentkit/vendors/avatar.py | 165 ++++++++++----------- src/agora_agent/agentkit/vendors/base.py | 2 +- src/agora_agent/agentkit/vendors/cn.py | 76 +++++----- tests/custom/test_sensetime_avatar.py | 8 + 4 files changed, 122 insertions(+), 129 deletions(-) diff --git a/src/agora_agent/agentkit/vendors/avatar.py b/src/agora_agent/agentkit/vendors/avatar.py index 7bbd8dc..88b3f32 100644 --- a/src/agora_agent/agentkit/vendors/avatar.py +++ b/src/agora_agent/agentkit/vendors/avatar.py @@ -1,7 +1,7 @@ import warnings from typing import Any, Dict, Optional -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field, field_validator from .base import BaseAvatar @@ -10,7 +10,7 @@ AKOOL_SAMPLE_RATE = 16000 -class LiveAvatarAvatarOptions(BaseModel): +class LiveAvatarAvatar(BaseAvatar): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="LiveAvatar API key") @@ -31,51 +31,61 @@ def validate_quality(cls, v: str) -> str: raise ValueError(f"Invalid quality '{v}'. Must be one of: {', '.join(valid)}") return v - -class LiveAvatarAvatar(BaseAvatar): - def __init__(self, **kwargs: Any): - self.options = LiveAvatarAvatarOptions(**kwargs) - @property def required_sample_rate(self) -> int: return LIVEAVATAR_SAMPLE_RATE def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "api_key": self.options.api_key, - "quality": self.options.quality, - "agora_uid": self.options.agora_uid, + "api_key": self.api_key, + "quality": self.quality, + "agora_uid": self.agora_uid, } - if self.options.agora_token is not None: - params["agora_token"] = self.options.agora_token - if self.options.avatar_id is not None: - params["avatar_id"] = self.options.avatar_id - if self.options.disable_idle_timeout is not None: - params["disable_idle_timeout"] = self.options.disable_idle_timeout - if self.options.activity_idle_timeout is not None: - params["activity_idle_timeout"] = self.options.activity_idle_timeout - if self.options.additional_params is not None: - params = {**self.options.additional_params, **params} - - enable = self.options.enable if self.options.enable is not None else True + if self.agora_token is not None: + params["agora_token"] = self.agora_token + if self.avatar_id is not None: + params["avatar_id"] = self.avatar_id + if self.disable_idle_timeout is not None: + params["disable_idle_timeout"] = self.disable_idle_timeout + if self.activity_idle_timeout is not None: + params["activity_idle_timeout"] = self.activity_idle_timeout + if self.additional_params is not None: + params = {**self.additional_params, **params} + + enable = self.enable if self.enable is not None else True return {"enable": enable, "vendor": "liveavatar", "params": params} -class HeyGenAvatarOptions(LiveAvatarAvatarOptions): - """Deprecated: use :class:`LiveAvatarAvatarOptions` instead.""" - - class HeyGenAvatar(BaseAvatar): """Deprecated: HeyGen has been renamed to LiveAvatar. Use LiveAvatarAvatar instead.""" - def __init__(self, **kwargs: Any): + model_config = ConfigDict(extra="forbid") + + api_key: str = Field(..., description="LiveAvatar API key") + quality: str = Field(..., description="Avatar quality: low, medium, or high") + agora_uid: str = Field(..., description="Agora UID for the avatar stream") + agora_token: Optional[str] = Field(default=None, description="RTC token for avatar authentication") + avatar_id: Optional[str] = Field(default=None, description="Avatar ID") + enable: Optional[bool] = Field(default=None, description="Enable avatar (default: true)") + disable_idle_timeout: Optional[bool] = Field(default=None, description="Whether to disable idle timeout") + activity_idle_timeout: Optional[int] = Field(default=None, description="Idle timeout in seconds") + additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional vendor-specific parameters") + + @field_validator("quality") + @classmethod + def validate_quality(cls, v: str) -> str: + valid = ("low", "medium", "high") + if v not in valid: + raise ValueError(f"Invalid quality '{v}'. Must be one of: {', '.join(valid)}") + return v + + def model_post_init(self, __context: Any) -> None: warnings.warn( "HeyGenAvatar is deprecated; use LiveAvatarAvatar instead.", DeprecationWarning, stacklevel=2, ) - self.options = HeyGenAvatarOptions(**kwargs) @property def required_sample_rate(self) -> int: @@ -83,27 +93,27 @@ def required_sample_rate(self) -> int: def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "api_key": self.options.api_key, - "quality": self.options.quality, - "agora_uid": self.options.agora_uid, + "api_key": self.api_key, + "quality": self.quality, + "agora_uid": self.agora_uid, } - if self.options.agora_token is not None: - params["agora_token"] = self.options.agora_token - if self.options.avatar_id is not None: - params["avatar_id"] = self.options.avatar_id - if self.options.disable_idle_timeout is not None: - params["disable_idle_timeout"] = self.options.disable_idle_timeout - if self.options.activity_idle_timeout is not None: - params["activity_idle_timeout"] = self.options.activity_idle_timeout - if self.options.additional_params is not None: - params = {**self.options.additional_params, **params} - - enable = self.options.enable if self.options.enable is not None else True + if self.agora_token is not None: + params["agora_token"] = self.agora_token + if self.avatar_id is not None: + params["avatar_id"] = self.avatar_id + if self.disable_idle_timeout is not None: + params["disable_idle_timeout"] = self.disable_idle_timeout + if self.activity_idle_timeout is not None: + params["activity_idle_timeout"] = self.activity_idle_timeout + if self.additional_params is not None: + params = {**self.additional_params, **params} + + enable = self.enable if self.enable is not None else True return {"enable": enable, "vendor": "heygen", "params": params} -class AkoolAvatarOptions(BaseModel): +class AkoolAvatar(BaseAvatar): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Akool API key") @@ -111,30 +121,25 @@ class AkoolAvatarOptions(BaseModel): enable: Optional[bool] = Field(default=None, description="Enable avatar (default: true)") additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional vendor-specific parameters") - -class AkoolAvatar(BaseAvatar): - def __init__(self, **kwargs: Any): - self.options = AkoolAvatarOptions(**kwargs) - @property def required_sample_rate(self) -> int: return AKOOL_SAMPLE_RATE def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "api_key": self.options.api_key, + "api_key": self.api_key, } - if self.options.avatar_id is not None: - params["avatar_id"] = self.options.avatar_id - if self.options.additional_params is not None: - params = {**self.options.additional_params, **params} + if self.avatar_id is not None: + params["avatar_id"] = self.avatar_id + if self.additional_params is not None: + params = {**self.additional_params, **params} - enable = self.options.enable if self.options.enable is not None else True + enable = self.enable if self.enable is not None else True return {"enable": enable, "vendor": "akool", "params": params} -class GenericAvatarOptions(BaseModel): +class GenericAvatar(BaseAvatar): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Generic avatar provider API key") @@ -147,37 +152,32 @@ class GenericAvatarOptions(BaseModel): enable: Optional[bool] = Field(default=None, description="Enable avatar (default: true)") additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional vendor-specific parameters") - -class GenericAvatar(BaseAvatar): - def __init__(self, **kwargs: Any): - self.options = GenericAvatarOptions(**kwargs) - @property def required_sample_rate(self) -> int: return 0 def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "api_key": self.options.api_key, - "api_base_url": self.options.api_base_url, - "avatar_id": self.options.avatar_id, - "agora_uid": self.options.agora_uid, + "api_key": self.api_key, + "api_base_url": self.api_base_url, + "avatar_id": self.avatar_id, + "agora_uid": self.agora_uid, } - if self.options.agora_appid is not None: - params["agora_appid"] = self.options.agora_appid - if self.options.agora_token is not None: - params["agora_token"] = self.options.agora_token - if self.options.agora_channel is not None: - params["agora_channel"] = self.options.agora_channel - if self.options.additional_params is not None: - params = {**self.options.additional_params, **params} + if self.agora_appid is not None: + params["agora_appid"] = self.agora_appid + if self.agora_token is not None: + params["agora_token"] = self.agora_token + if self.agora_channel is not None: + params["agora_channel"] = self.agora_channel + if self.additional_params is not None: + params = {**self.additional_params, **params} - enable = self.options.enable if self.options.enable is not None else True + enable = self.enable if self.enable is not None else True return {"enable": enable, "vendor": "generic", "params": params} -class AnamAvatarOptions(BaseModel): +class AnamAvatar(BaseAvatar): model_config = ConfigDict(extra="forbid") api_key: str = Field(..., description="Anam API key") @@ -185,23 +185,18 @@ class AnamAvatarOptions(BaseModel): enable: Optional[bool] = Field(default=None, description="Enable avatar (default: true)") additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional vendor-specific parameters") - -class AnamAvatar(BaseAvatar): - def __init__(self, **kwargs: Any): - self.options = AnamAvatarOptions(**kwargs) - @property def required_sample_rate(self) -> int: return 0 def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "api_key": self.options.api_key, - "avatar_id": self.options.avatar_id, + "api_key": self.api_key, + "avatar_id": self.avatar_id, } - if self.options.additional_params is not None: - params = {**self.options.additional_params, **params} + if self.additional_params is not None: + params = {**self.additional_params, **params} - enable = self.options.enable if self.options.enable is not None else True + enable = self.enable if self.enable is not None else True return {"enable": enable, "vendor": "anam", "params": params} diff --git a/src/agora_agent/agentkit/vendors/base.py b/src/agora_agent/agentkit/vendors/base.py index 2ec1b7b..680f668 100644 --- a/src/agora_agent/agentkit/vendors/base.py +++ b/src/agora_agent/agentkit/vendors/base.py @@ -76,7 +76,7 @@ def to_config(self) -> Dict[str, Any]: """Serialize the MLLM configuration to a dict for the REST API.""" -class BaseAvatar(ABC): +class BaseAvatar(BaseModel, ABC): """Abstract base class for all avatar vendor implementations. Avatars render a visual representation of the agent and impose a specific TTS diff --git a/src/agora_agent/agentkit/vendors/cn.py b/src/agora_agent/agentkit/vendors/cn.py index 96af7e6..ee888c0 100644 --- a/src/agora_agent/agentkit/vendors/cn.py +++ b/src/agora_agent/agentkit/vendors/cn.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field, model_validator from .avatar import BaseAvatar from .base import BaseLLM @@ -788,8 +788,8 @@ def to_config(self) -> Dict[str, Any]: return config -class SenseTimeAvatarOptions(BaseModel): - model_config = ConfigDict(extra="forbid") +class SenseTimeAvatar(BaseAvatar): + model_config = ConfigDict(extra="forbid", populate_by_name=True) agora_token: Optional[str] = Field(default=None, description="RTC token for avatar publisher; generated by AgentSession when omitted") agora_uid: str = Field(..., description="Avatar RTC publisher uid") @@ -799,34 +799,29 @@ class SenseTimeAvatarOptions(BaseModel): enable: Optional[bool] = Field(default=None) additional_params: Optional[Dict[str, Any]] = Field(default=None) - -class SenseTimeAvatar(BaseAvatar): - def __init__(self, **kwargs: Any): - self.options = SenseTimeAvatarOptions(**kwargs) - @property def required_sample_rate(self) -> int: return 0 def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "agora_uid": self.options.agora_uid, - "app_key": self.options.app_key, + "agora_uid": self.agora_uid, + "app_key": self.app_key, } - if self.options.agora_token is not None: - params["agora_token"] = self.options.agora_token - if self.options.app_id is not None: - params["appId"] = self.options.app_id - if self.options.scene_list is not None: - params["sceneList"] = self.options.scene_list - if self.options.additional_params is not None: - params = {**self.options.additional_params, **params} - - enable = self.options.enable if self.options.enable is not None else True + if self.agora_token is not None: + params["agora_token"] = self.agora_token + if self.app_id is not None: + params["appId"] = self.app_id + if self.scene_list is not None: + params["sceneList"] = self.scene_list + if self.additional_params is not None: + params = {**self.additional_params, **params} + + enable = self.enable if self.enable is not None else True return {"enable": enable, "vendor": "sensetime", "params": params} -class SpatiusAvatarOptions(BaseModel): +class SpatiusAvatar(BaseAvatar): model_config = ConfigDict(extra="forbid") spatius_api_key: str = Field(..., description="Spatius API key") @@ -840,32 +835,27 @@ class SpatiusAvatarOptions(BaseModel): enable: Optional[bool] = Field(default=None) additional_params: Optional[Dict[str, Any]] = Field(default=None) - -class SpatiusAvatar(BaseAvatar): - def __init__(self, **kwargs: Any): - self.options = SpatiusAvatarOptions(**kwargs) - @property def required_sample_rate(self) -> int: - return self.options.sample_rate or 0 + return self.sample_rate or 0 def to_config(self) -> Dict[str, Any]: params: Dict[str, Any] = { - "spatius_api_key": self.options.spatius_api_key, - "spatius_app_id": self.options.spatius_app_id, - "spatius_avatar_id": self.options.spatius_avatar_id, - "agora_uid": self.options.agora_uid, + "spatius_api_key": self.spatius_api_key, + "spatius_app_id": self.spatius_app_id, + "spatius_avatar_id": self.spatius_avatar_id, + "agora_uid": self.agora_uid, } - if self.options.agora_token is not None: - params["agora_token"] = self.options.agora_token - if self.options.region is not None: - params["region"] = self.options.region - if self.options.sample_rate is not None: - params["sample_rate"] = self.options.sample_rate - if self.options.session_expire_minutes is not None: - params["session_expire_minutes"] = self.options.session_expire_minutes - if self.options.additional_params is not None: - params = {**self.options.additional_params, **params} - - enable = self.options.enable if self.options.enable is not None else True + if self.agora_token is not None: + params["agora_token"] = self.agora_token + if self.region is not None: + params["region"] = self.region + if self.sample_rate is not None: + params["sample_rate"] = self.sample_rate + if self.session_expire_minutes is not None: + params["session_expire_minutes"] = self.session_expire_minutes + if self.additional_params is not None: + params = {**self.additional_params, **params} + + enable = self.enable if self.enable is not None else True return {"enable": enable, "vendor": "spatius", "params": params} diff --git a/tests/custom/test_sensetime_avatar.py b/tests/custom/test_sensetime_avatar.py index bc582fb..4d69e73 100644 --- a/tests/custom/test_sensetime_avatar.py +++ b/tests/custom/test_sensetime_avatar.py @@ -40,6 +40,14 @@ def _scene_list() -> List[Dict[str, Any]]: return [{"digital_role": {"face_feature_id": "role-1"}}] +def test_sensetime_avatar_accepts_snake_case() -> None: + from agora_agent.agentkit.vendors.cn import SenseTimeAvatar + camel = SenseTimeAvatar(agora_uid="2", appId="app", app_key="key").to_config() + snake = SenseTimeAvatar(agora_uid="2", app_id="app", app_key="key").to_config() + assert camel == snake + assert snake["params"]["appId"] == "app" + + def test_sensetime_avatar_to_config_shape() -> None: config = SenseTimeAvatar( agora_token="avatar-token", From 38619a782887655d4394bb22e8b9ca499026f714 Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 06:57:59 -0700 Subject: [PATCH 10/12] fix: correct HeyGen deprecation warning stacklevel for model_post_init --- src/agora_agent/agentkit/vendors/avatar.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/agora_agent/agentkit/vendors/avatar.py b/src/agora_agent/agentkit/vendors/avatar.py index 88b3f32..23a709e 100644 --- a/src/agora_agent/agentkit/vendors/avatar.py +++ b/src/agora_agent/agentkit/vendors/avatar.py @@ -81,10 +81,12 @@ def validate_quality(cls, v: str) -> str: return v def model_post_init(self, __context: Any) -> None: + # stacklevel=3: warn() <- model_post_init <- pydantic __init__ <- user code, + # so the warning points at the user's construction site, not pydantic internals. warnings.warn( "HeyGenAvatar is deprecated; use LiveAvatarAvatar instead.", DeprecationWarning, - stacklevel=2, + stacklevel=3, ) @property From d7304eabb2b66222c518560184fbed04fd8a3329 Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 07:03:27 -0700 Subject: [PATCH 11/12] test: align vendor construction with typed pydantic constructors (mypy) Co-Authored-By: Claude Opus 4.8 --- tests/custom/test_agentkit_vendors.py | 2 +- tests/custom/test_sensetime_avatar.py | 14 +++++++------- tests/custom/test_vendor_collapse_golden.py | 9 +++++---- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/custom/test_agentkit_vendors.py b/tests/custom/test_agentkit_vendors.py index 888c87a..e74f70a 100644 --- a/tests/custom/test_agentkit_vendors.py +++ b/tests/custom/test_agentkit_vendors.py @@ -89,7 +89,7 @@ def test_anam_avatar_serializes_avatar_id() -> None: def test_anam_avatar_requires_avatar_id() -> None: with pytest.raises(ValidationError): - AnamAvatar(api_key="anam-key") + AnamAvatar(api_key="anam-key") # type: ignore[call-arg] def test_vertex_ai_explicit_fields_override_additional_params(): diff --git a/tests/custom/test_sensetime_avatar.py b/tests/custom/test_sensetime_avatar.py index 4d69e73..91975b2 100644 --- a/tests/custom/test_sensetime_avatar.py +++ b/tests/custom/test_sensetime_avatar.py @@ -42,8 +42,8 @@ def _scene_list() -> List[Dict[str, Any]]: def test_sensetime_avatar_accepts_snake_case() -> None: from agora_agent.agentkit.vendors.cn import SenseTimeAvatar - camel = SenseTimeAvatar(agora_uid="2", appId="app", app_key="key").to_config() snake = SenseTimeAvatar(agora_uid="2", app_id="app", app_key="key").to_config() + camel = SenseTimeAvatar(**{"agora_uid": "2", "appId": "app", "app_key": "key"}).to_config() assert camel == snake assert snake["params"]["appId"] == "app" @@ -52,9 +52,9 @@ def test_sensetime_avatar_to_config_shape() -> None: config = SenseTimeAvatar( agora_token="avatar-token", agora_uid="2", - appId="sensetime-app-id", + app_id="sensetime-app-id", app_key="sensetime-app-key", - sceneList=_scene_list(), + scene_list=_scene_list(), ).to_config() assert config == { @@ -75,7 +75,7 @@ def test_sensetime_avatar_to_config_omits_token_when_not_provided() -> None: config = SenseTimeAvatar( agora_uid="2", app_key="sensetime-app-key", - sceneList=_scene_list(), + scene_list=_scene_list(), ).to_config() assert "agora_token" not in config["params"] @@ -142,7 +142,7 @@ def test_sensetime_avatar_session_validation_and_token_passthrough() -> None: agora_token="avatar-token", agora_uid="2", app_key="sensetime-app-key", - sceneList=_scene_list(), + scene_list=_scene_list(), ) ) session = _session(agent) @@ -164,7 +164,7 @@ def test_sensetime_avatar_enrichment_generates_token() -> None: SenseTimeAvatar( agora_uid="2", app_key="sensetime-app-key", - sceneList=_scene_list(), + scene_list=_scene_list(), ) ) session = _session(agent) @@ -187,7 +187,7 @@ def test_sensetime_avatar_user_token_is_not_overwritten() -> None: agora_uid="2", agora_token="user-token", app_key="sensetime-app-key", - sceneList=_scene_list(), + scene_list=_scene_list(), ) ) session = _session(agent) diff --git a/tests/custom/test_vendor_collapse_golden.py b/tests/custom/test_vendor_collapse_golden.py index 254648a..9122cf6 100644 --- a/tests/custom/test_vendor_collapse_golden.py +++ b/tests/custom/test_vendor_collapse_golden.py @@ -85,9 +85,10 @@ def test_aliyun_llm_pins_vendor_golden() -> None: def test_sensetime_avatar_camelcase_golden() -> None: - # SenseTimeAvatarOptions uses alias "appId" for the app_id field; - # pydantic v2 requires the alias keyword in the constructor. - cfg = SenseTimeAvatar(agora_uid="2", appId="app", app_key="key").to_config() + # Verifies that the camelCase alias "appId" still works as constructor input + # (populate_by_name=True means both the field name and alias are accepted at runtime). + # Dict-unpack is used so mypy does not call-arg-check the alias keyword. + cfg = SenseTimeAvatar(**{"agora_uid": "2", "appId": "app", "app_key": "key"}).to_config() assert cfg["vendor"] == "sensetime" assert cfg["params"]["appId"] == "app" @@ -104,5 +105,5 @@ def test_fengming_rejects_kwargs() -> None: import pytest from pydantic import ValidationError with pytest.raises(ValidationError): - FengmingSTT(unexpected="x") + FengmingSTT(unexpected="x") # type: ignore[call-arg] assert FengmingSTT().to_config() == {"vendor": "fengming"} From be39ea40554f39a93949dea32341245a050211db Mon Sep 17 00:00:00 2001 From: plutoless Date: Thu, 2 Jul 2026 09:12:03 -0700 Subject: [PATCH 12/12] feat: export global vendor types at top-level; add agora_agent.cn module Global vendor classes now resolve to real types via the top-level TYPE_CHECKING block, so `from agora_agent import DeepgramSTT` gives IDE autocomplete (was Any). CN vendors get a dedicated `agora_agent.cn` module (natural names), which also resolves the MicrosoftSTT/MicrosoftTTS/MiniMaxTTS collisions. SpatiusAvatar (CN) is now cn-only; XaiGrok/GenericAvatar demoted to TYPE_CHECKING-only. Co-Authored-By: Claude Opus 4.8 --- .fernignore | 1 + docs/guides/avatars.md | 3 +- src/agora_agent/__init__.py | 57 +++++++++++++++++++++++++----- src/agora_agent/cn.py | 60 ++++++++++++++++++++++++++++++++ tests/custom/test_llm_vendors.py | 14 ++++---- 5 files changed, 119 insertions(+), 16 deletions(-) create mode 100644 src/agora_agent/cn.py diff --git a/.fernignore b/.fernignore index f3790e8..3d06feb 100644 --- a/.fernignore +++ b/.fernignore @@ -1,6 +1,7 @@ # Specify files that shouldn't be modified by Fern src/agora_agent/pool_client.py src/agora_agent/__init__.py +src/agora_agent/cn.py src/agora_agent/core/domain.py changelog.md diff --git a/docs/guides/avatars.md b/docs/guides/avatars.md index af50888..abd82a4 100644 --- a/docs/guides/avatars.md +++ b/docs/guides/avatars.md @@ -132,7 +132,8 @@ agent = ( `SpatiusAvatar` is available for `Area.CN` sessions. Provide `spatius_api_key`, `spatius_app_id`, `spatius_avatar_id`, and `agora_uid` when constructing the avatar. `agora_token` is optional and is generated at session start when omitted, like SenseTime and Generic avatars. ```python -from agora_agent import Agora, Area, CNAgent, GenericTTS, SpatiusAvatar, TencentSTT +from agora_agent import Agora, Area, CNAgent, GenericTTS +from agora_agent.cn import SpatiusAvatar, TencentSTT client = Agora( area=Area.CN, diff --git a/src/agora_agent/__init__.py b/src/agora_agent/__init__.py index 0fa00e2..b132d4b 100644 --- a/src/agora_agent/__init__.py +++ b/src/agora_agent/__init__.py @@ -20,12 +20,57 @@ AgentSessionOptions, CNAgent, GlobalAgent, - GenericAvatar, - SpatiusAvatar, RegionalAgent, - XaiGrok, generate_rtc_token, GenerateTokenOptions, + # Global (non-CN) vendor classes — re-exported from agentkit for static typing so + # `from agora_agent import DeepgramSTT` autocompletes. Runtime resolution and + # `__all__` membership come from the `__getattr__` fallback + the `__all__` union + # below, so these need no `_dynamic_imports` / `_ROOT_ALL` entries. + # CN vendors are intentionally not here — import them from `agora_agent.cn`. + AkoolAvatar, + AmazonBedrock, + AmazonSTT, + AmazonTTS, + AnamAvatar, + Anthropic, + AresSTT, + AssemblyAISTT, + AzureOpenAI, + CartesiaTTS, + CustomLLM, + DeepgramSTT, + DeepgramTTS, + Dify, + ElevenLabsTTS, + FishAudioTTS, + Gemini, + GeminiLive, + GenericAvatar, + GenericTTS, + GoogleSTT, + GoogleTTS, + Groq, + HeyGenAvatar, + HumeAITTS, + LiveAvatarAvatar, + MicrosoftSTT, + MicrosoftTTS, + MiniMaxTTS, + MurfTTS, + OpenAI, + OpenAIRealtime, + OpenAISTT, + OpenAITTS, + RimeTTS, + SarvamSTT, + SarvamTTS, + SpeechmaticsSTT, + VertexAI, + VertexAILLM, + XaiGrok, + XaiSTT, + XaiTTS, ) from .agentkit.agent_session import AsyncAgentSession @@ -39,9 +84,6 @@ "AsyncAgentSession": ".agentkit.agent_session", "AsyncAgora": ".pool_client", "AsyncAgentClient": ".pool_client", - "GenericAvatar": ".agentkit", - "SpatiusAvatar": ".agentkit", - "XaiGrok": ".agentkit", "GenerateTokenOptions": ".agentkit", "__version__": ".version", "agentkit": ".agentkit", @@ -63,9 +105,6 @@ "AsyncAgentSession", "AsyncAgora", "AsyncAgentClient", - "GenericAvatar", - "SpatiusAvatar", - "XaiGrok", "GenerateTokenOptions", "Pool", "__version__", diff --git a/src/agora_agent/cn.py b/src/agora_agent/cn.py new file mode 100644 index 0000000..9df1a0f --- /dev/null +++ b/src/agora_agent/cn.py @@ -0,0 +1,60 @@ +# isort: skip_file +"""China (CN) vendor classes for the Agora Conversational AI SDK. + +Import CN vendors explicitly from this module so they stay separate from the global +top-level namespace:: + + from agora_agent.cn import AliyunLLM, MiniMaxTTS, TencentSTT + +``MicrosoftSTT`` / ``MicrosoftTTS`` / ``MiniMaxTTS`` here are the CN variants; the +global variants of those names are available from the top-level package +(e.g. ``from agora_agent import MiniMaxTTS``). +""" + +from .agentkit.vendors.cn import ( + AliyunLLM, + BytedanceDuplexTTS, + BytedanceLLM, + BytedanceTTS, + CosyVoiceTTS, + DeepSeekLLM, + FengmingSTT, + MicrosoftSTT, + MicrosoftTTS, + MiniMaxTTS, + SenseTimeAvatar, + SpatiusAvatar, + StepFunTTS, + TencentLLM, + TencentSTT, + TencentTTS, + XfyunBigModelSTT, + XfyunDialectSTT, + XfyunSTT, +) + +__all__ = [ + # STT + "FengmingSTT", + "MicrosoftSTT", + "TencentSTT", + "XfyunBigModelSTT", + "XfyunDialectSTT", + "XfyunSTT", + # TTS + "BytedanceDuplexTTS", + "BytedanceTTS", + "CosyVoiceTTS", + "MicrosoftTTS", + "MiniMaxTTS", + "StepFunTTS", + "TencentTTS", + # LLM + "AliyunLLM", + "BytedanceLLM", + "DeepSeekLLM", + "TencentLLM", + # Avatar + "SenseTimeAvatar", + "SpatiusAvatar", +] diff --git a/tests/custom/test_llm_vendors.py b/tests/custom/test_llm_vendors.py index b9cb90a..c2a5037 100644 --- a/tests/custom/test_llm_vendors.py +++ b/tests/custom/test_llm_vendors.py @@ -128,11 +128,13 @@ def test_dify_serializes_conversation_fields() -> None: def test_llm_vendors_reject_missing_required_models() -> None: + # These intentionally omit the required `model` to assert it raises; the top-level + # vendor types are now real, so mypy needs the call-arg ignores. with pytest.raises(Exception, match="model"): - OpenAI(api_key="openai-key", base_url="https://api.openai.com/v1/chat/completions") + OpenAI(api_key="openai-key", base_url="https://api.openai.com/v1/chat/completions") # type: ignore[call-arg] with pytest.raises(Exception, match="model"): - Anthropic( + Anthropic( # type: ignore[call-arg] api_key="anthropic-key", url="https://api.anthropic.com/v1/messages", headers={"anthropic-version": "2023-06-01"}, @@ -140,16 +142,16 @@ def test_llm_vendors_reject_missing_required_models() -> None: ) with pytest.raises(Exception, match="model"): - Gemini(api_key="google-key") + Gemini(api_key="google-key") # type: ignore[call-arg] with pytest.raises(Exception, match="model"): - Groq(api_key="groq-key", base_url="https://api.groq.com/openai/v1/chat/completions") + Groq(api_key="groq-key", base_url="https://api.groq.com/openai/v1/chat/completions") # type: ignore[call-arg] with pytest.raises(Exception, match="model"): - VertexAILLM(api_key="vertex-token", project_id="project", location="us-central1") + VertexAILLM(api_key="vertex-token", project_id="project", location="us-central1") # type: ignore[call-arg] with pytest.raises(Exception, match="model"): - AmazonBedrock(access_key="aws-access", secret_key="aws-secret", region="us-east-1") + AmazonBedrock(access_key="aws-access", secret_key="aws-secret", region="us-east-1") # type: ignore[call-arg] def test_openai_managed_mode_is_restricted_to_supported_models() -> None: