Skip to content

Commit f34792c

Browse files
Merge pull request #99 from muhammadyaseen/feat/azure-openai-support
2 parents 2024c9d + dca04cd commit f34792c

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

.env.example

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ TRAINEE_MODEL=gpt-4o-mini
1414
TRAINEE_BASE_URL=
1515
TRAINEE_API_KEY=
1616

17+
# azure_openai_api
18+
# SYNTHESIZER_BACKEND=azure_openai_api
19+
# The following is the same as your "Deployment name" in Azure
20+
# SYNTHESIZER_MODEL=<your-deployment-name>
21+
# SYNTHESIZER_BASE_URL=https://<your-resource-name>.openai.azure.com/openai/deployments/<your-deployment-name>/chat/completions
22+
# SYNTHESIZER_API_KEY=
23+
# SYNTHESIZER_API_VERSION=<api-version>
24+
1725
# # ollama_api
1826
# SYNTHESIZER_BACKEND=ollama_api
1927
# SYNTHESIZER_MODEL=gemma3

graphgen/models/llm/api/openai_client.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Any, Dict, List, Optional
33

44
import openai
5-
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
5+
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, AsyncAzureOpenAI, RateLimitError
66
from tenacity import (
77
retry,
88
retry_if_exception_type,
@@ -35,17 +35,20 @@ def __init__(
3535
model: str = "gpt-4o-mini",
3636
api_key: Optional[str] = None,
3737
base_url: Optional[str] = None,
38+
api_version: Optional[str] = None,
3839
json_mode: bool = False,
3940
seed: Optional[int] = None,
4041
topk_per_token: int = 5, # number of topk tokens to generate for each token
4142
request_limit: bool = False,
4243
rpm: Optional[RPM] = None,
4344
tpm: Optional[TPM] = None,
45+
backend: str = "openai_api",
4446
**kwargs: Any,
4547
):
4648
super().__init__(**kwargs)
4749
self.model = model
4850
self.api_key = api_key
51+
self.api_version = api_version # required for Azure OpenAI
4952
self.base_url = base_url
5053
self.json_mode = json_mode
5154
self.seed = seed
@@ -56,13 +59,32 @@ def __init__(
5659
self.rpm = rpm or RPM()
5760
self.tpm = tpm or TPM()
5861

62+
assert (
63+
backend in ("openai_api", "azure_openai_api")
64+
), f"Unsupported backend '{backend}'. Use 'openai_api' or 'azure_openai_api'."
65+
self.backend = backend
66+
5967
self.__post_init__()
6068

6169
def __post_init__(self):
62-
assert self.api_key is not None, "Please provide api key to access openai api."
63-
self.client = AsyncOpenAI(
64-
api_key=self.api_key or "dummy", base_url=self.base_url
65-
)
70+
71+
api_name = self.backend.replace("_", " ")
72+
assert self.api_key is not None, f"Please provide api key to access {api_name}."
73+
if self.backend == "openai_api":
74+
self.client = AsyncOpenAI(
75+
api_key=self.api_key or "dummy", base_url=self.base_url
76+
)
77+
elif self.backend == "azure_openai_api":
78+
assert self.api_version is not None, f"Please provide api_version for {api_name}."
79+
assert self.base_url is not None, f"Please provide base_url for {api_name}."
80+
self.client = AsyncAzureOpenAI(
81+
api_key=self.api_key,
82+
azure_endpoint=self.base_url,
83+
api_version=self.api_version,
84+
azure_deployment=self.model,
85+
)
86+
else:
87+
raise ValueError(f"Unsupported backend {self.backend}. Use 'openai_api' or 'azure_openai_api'.")
6688

6789
def _pre_generate(self, text: str, history: List[str]) -> Dict:
6890
kwargs = {

graphgen/operators/init/init_llm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
2727
from graphgen.models.llm.api.http_client import HTTPClient
2828

2929
return HTTPClient(**config)
30-
if backend == "openai_api":
30+
if backend in ("openai_api", "azure_openai_api"):
3131
from graphgen.models.llm.api.openai_client import OpenAIClient
32-
33-
return OpenAIClient(**config)
32+
# pass in concrete backend to the OpenAIClient so that internally we can distinguish
33+
# between OpenAI and Azure OpenAI
34+
return OpenAIClient(**config, backend=backend)
3435
if backend == "ollama_api":
3536
from graphgen.models.llm.api.ollama_client import OllamaClient
3637

0 commit comments

Comments (0)