22from typing import Any , Dict , List , Optional
33
44import openai
5- from openai import APIConnectionError , APITimeoutError , AsyncOpenAI , RateLimitError
5+ from openai import APIConnectionError , APITimeoutError , AsyncOpenAI , AsyncAzureOpenAI , RateLimitError
66from tenacity import (
77 retry ,
88 retry_if_exception_type ,
@@ -35,17 +35,20 @@ def __init__(
3535 model : str = "gpt-4o-mini" ,
3636 api_key : Optional [str ] = None ,
3737 base_url : Optional [str ] = None ,
38+ api_version : Optional [str ] = None ,
3839 json_mode : bool = False ,
3940 seed : Optional [int ] = None ,
4041 topk_per_token : int = 5 , # number of topk tokens to generate for each token
4142 request_limit : bool = False ,
4243 rpm : Optional [RPM ] = None ,
4344 tpm : Optional [TPM ] = None ,
45+ backend : str = "openai_api" ,
4446 ** kwargs : Any ,
4547 ):
4648 super ().__init__ (** kwargs )
4749 self .model = model
4850 self .api_key = api_key
51+ self .api_version = api_version # required for Azure OpenAI
4952 self .base_url = base_url
5053 self .json_mode = json_mode
5154 self .seed = seed
@@ -56,13 +59,32 @@ def __init__(
5659 self .rpm = rpm or RPM ()
5760 self .tpm = tpm or TPM ()
5861
62+ assert (
63+ backend in ("openai_api" , "azure_openai_api" )
64+ ), f"Unsupported backend '{ backend } '. Use 'openai_api' or 'azure_openai_api'."
65+ self .backend = backend
66+
5967 self .__post_init__ ()
6068
6169 def __post_init__ (self ):
62- assert self .api_key is not None , "Please provide api key to access openai api."
63- self .client = AsyncOpenAI (
64- api_key = self .api_key or "dummy" , base_url = self .base_url
65- )
70+
71+ api_name = self .backend .replace ("_" , " " )
72+ assert self .api_key is not None , f"Please provide api key to access { api_name } ."
73+ if self .backend == "openai_api" :
74+ self .client = AsyncOpenAI (
75+ api_key = self .api_key or "dummy" , base_url = self .base_url
76+ )
77+ elif self .backend == "azure_openai_api" :
78+ assert self .api_version is not None , f"Please provide api_version for { api_name } ."
79+ assert self .base_url is not None , f"Please provide base_url for { api_name } ."
80+ self .client = AsyncAzureOpenAI (
81+ api_key = self .api_key ,
82+ azure_endpoint = self .base_url ,
83+ api_version = self .api_version ,
84+ azure_deployment = self .model ,
85+ )
86+ else :
87+ raise ValueError (f"Unsupported backend { self .backend } . Use 'openai_api' or 'azure_openai_api'." )
6688
6789 def _pre_generate (self , text : str , history : List [str ]) -> Dict :
6890 kwargs = {
0 commit comments