diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py
index 255aa6bd..d7cefec9 100644
--- a/chatlas/_provider_openai.py
+++ b/chatlas/_provider_openai.py
@@ -49,6 +49,9 @@ def ChatOpenAI(
     model: "Optional[ResponsesModel | str]" = None,
     api_key: Optional[str] = None,
     base_url: str = "https://api.openai.com/v1",
+    service_tier: Optional[
+        Literal["auto", "default", "flex", "scale", "priority"]
+    ] = None,
     kwargs: Optional["ChatClientArgs"] = None,
 ) -> Chat["SubmitInputArgs", Response]:
     """
@@ -93,6 +96,13 @@ def ChatOpenAI(
         variable.
     base_url
         The base URL to the endpoint; the default uses OpenAI.
+    service_tier
+        Request a specific service tier. Options:
+        - `"auto"` (default): uses the service tier configured in Project settings.
+        - `"default"`: standard pricing and performance.
+        - `"flex"`: slower and cheaper.
+        - `"scale"`: batch-like pricing for high-volume use.
+        - `"priority"`: faster and more expensive.
     kwargs
         Additional arguments to pass to the `openai.OpenAI()` client
         constructor.
@@ -146,6 +156,10 @@ def ChatOpenAI(
     if model is None:
         model = log_model_default("gpt-4.1")
 
+    kwargs_chat: "SubmitInputArgs" = {}
+    if service_tier is not None:
+        kwargs_chat["service_tier"] = service_tier
+
     return Chat(
         provider=OpenAIProvider(
             api_key=api_key,
@@ -154,6 +168,7 @@
             kwargs=kwargs,
         ),
         system_prompt=system_prompt,
+        kwargs_chat=kwargs_chat,
     )
 
 
@@ -260,6 +275,16 @@ def stream_text(self, chunk):
     def stream_merge_chunks(self, completion, chunk):
         if chunk.type == "response.completed":
             return chunk.response
+        elif chunk.type == "response.failed":
+            error = chunk.response.error
+            if error is None:
+                msg = "Request failed with an unknown error."
+            else:
+                msg = f"Request failed ({error.code}): {error.message}"
+            raise RuntimeError(msg)
+        elif chunk.type == "error":
+            raise RuntimeError(f"Request errored: {chunk.message}")
+
         # Since this value won't actually be used, we can lie about the type
         return cast(Response, None)
 
diff --git a/chatlas/_provider_openai_azure.py b/chatlas/_provider_openai_azure.py
index 7a1b78bc..0f2bc39d 100644
--- a/chatlas/_provider_openai_azure.py
+++ b/chatlas/_provider_openai_azure.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Literal, Optional
 
 from openai import AsyncAzureOpenAI, AzureOpenAI
 from openai.types.chat import ChatCompletion
@@ -21,6 +21,9 @@ def ChatAzureOpenAI(
     api_version: str,
     api_key: Optional[str] = None,
     system_prompt: Optional[str] = None,
+    service_tier: Optional[
+        Literal["auto", "default", "flex", "scale", "priority"]
+    ] = None,
     kwargs: Optional["ChatAzureClientArgs"] = None,
 ) -> Chat["SubmitInputArgs", ChatCompletion]:
     """
@@ -62,6 +65,13 @@ def ChatAzureOpenAI(
         variable.
     system_prompt
         A system prompt to set the behavior of the assistant.
+    service_tier
+        Request a specific service tier. Options:
+        - `"auto"` (default): uses the service tier configured in Project settings.
+        - `"default"`: standard pricing and performance.
+        - `"flex"`: slower and cheaper.
+        - `"scale"`: batch-like pricing for high-volume use.
+        - `"priority"`: faster and more expensive.
     kwargs
         Additional arguments to pass to the `openai.AzureOpenAI()` client
         constructor.
@@ -71,6 +81,10 @@ def ChatAzureOpenAI(
     Chat
         A Chat object.
     """
+    kwargs_chat: "SubmitInputArgs" = {}
+    if service_tier is not None:
+        kwargs_chat["service_tier"] = service_tier
+
     return Chat(
         provider=OpenAIAzureProvider(
             endpoint=endpoint,
@@ -80,6 +94,7 @@
             kwargs=kwargs,
         ),
         system_prompt=system_prompt,
+        kwargs_chat=kwargs_chat,
     )
 
 
diff --git a/tests/test_provider_openai.py b/tests/test_provider_openai.py
index a36e14dc..93908e93 100644
--- a/tests/test_provider_openai.py
+++ b/tests/test_provider_openai.py
@@ -105,3 +105,8 @@ def test_openai_custom_http_client():
 
 def test_openai_list_models():
     assert_list_models(ChatOpenAI)
+
+
+def test_openai_service_tier():
+    chat = ChatOpenAI(service_tier="flex")
+    assert chat.kwargs_chat.get("service_tier") == "flex"
\ No newline at end of file