Commit 613483e

Address PR feedback for Cerebras integration
- Update docs/models/cerebras.md: use pip/uv-add format, link to Cerebras docs
- Fix httpx.AsyncClient typo in cerebras.md, groq.md, mistral.md
- Add docs/api/models/cerebras.md and update mkdocs.yml
- Remove Cerebras section from openai.md, move to main list in overview.md
- Add str | to CerebrasModelName for arbitrary model names
- Add CerebrasModelSettings with cerebras_disable_reasoning field
- Add zai_model_profile, restore unsupported_model_settings and json_schema_transformer
- Pass lowercase model name to profile functions
- Add tests/providers/test_cerebras.py with full coverage
- Remove type ignore in models/__init__.py
1 parent 01d0009 commit 613483e

File tree: 11 files changed, +189 −55 lines

docs/api/models/cerebras.md

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+# `pydantic_ai.models.cerebras`
+
+## Setup
+
+For details on how to set up authentication with this model, see [model configuration for Cerebras](../../models/cerebras.md).
+
+::: pydantic_ai.models.cerebras

docs/models/cerebras.md

Lines changed: 4 additions & 10 deletions
@@ -5,20 +5,14 @@
 To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group:
 
 ```bash
-pip install "pydantic-ai-slim[cerebras]"
-```
-
-or
-
-```bash
-uv add "pydantic-ai-slim[cerebras]"
+pip/uv-add "pydantic-ai-slim[cerebras]"
 ```
 
 ## Configuration
 
-To use [Cerebras](https://cerebras.ai/) through their API, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) and follow your nose until you find the place to generate an API key.
+To use [Cerebras](https://cerebras.ai/) through their API, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) and generate an API key.
 
-`CerebrasModelName` contains a list of available Cerebras models.
+For a list of available models, see the [Cerebras models documentation](https://inference-docs.cerebras.ai/models).
 
 ## Environment variable
 
@@ -64,7 +58,7 @@ agent = Agent(model)
 ...
 ```
 
-You can also customize the `CerebrasProvider` with a custom `httpx.AsyncHTTPClient`:
+You can also customize the `CerebrasProvider` with a custom `httpx.AsyncClient`:
 
 ```python
 from httpx import AsyncClient

docs/models/groq.md

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ agent = Agent(model)
 ...
 ```
 
-You can also customize the `GroqProvider` with a custom `httpx.AsyncHTTPClient`:
+You can also customize the `GroqProvider` with a custom `httpx.AsyncClient`:
 
 ```python
 from httpx import AsyncClient

docs/models/mistral.md

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ agent = Agent(model)
 ...
 ```
 
-You can also customize the provider with a custom `httpx.AsyncHTTPClient`:
+You can also customize the provider with a custom `httpx.AsyncClient`:
 
 ```python
 from httpx import AsyncClient

docs/models/openai.md

Lines changed: 0 additions & 33 deletions
@@ -632,39 +632,6 @@ agent = Agent(model)
 ...
 ```
 
-### Cerebras
-
-To use [Cerebras](https://cerebras.ai/), you need to create an API key in the [Cerebras Console](https://cloud.cerebras.ai/).
-
-You can set the `CEREBRAS_API_KEY` environment variable and use [`CerebrasProvider`][pydantic_ai.providers.cerebras.CerebrasProvider] by name:
-
-```python
-from pydantic_ai import Agent
-
-agent = Agent('cerebras:llama3.3-70b')
-result = agent.run_sync('What is the capital of France?')
-print(result.output)
-#> The capital of France is Paris.
-```
-
-Or initialise the model and provider directly:
-
-```python
-from pydantic_ai import Agent
-from pydantic_ai.models.openai import OpenAIChatModel
-from pydantic_ai.providers.cerebras import CerebrasProvider
-
-model = OpenAIChatModel(
-    'llama3.3-70b',
-    provider=CerebrasProvider(api_key='your-cerebras-api-key'),
-)
-agent = Agent(model)
-
-result = agent.run_sync('What is the capital of France?')
-print(result.output)
-#> The capital of France is Paris.
-```
-
 ### LiteLLM
 
 To use [LiteLLM](https://www.litellm.ai/), set the configs as outlined in the [doc](https://docs.litellm.ai/docs/set_keys). In `LiteLLMProvider`, you can pass `api_base` and `api_key`. The value of these configs will depend on your setup. For example, if you are using OpenAI models, then you need to pass `https://api.openai.com/v1` as the `api_base` and your OpenAI API key as the `api_key`. If you are using a LiteLLM proxy server running on your local machine, then you need to pass `http://localhost:<port>` as the `api_base` and your LiteLLM API key (or a placeholder) as the `api_key`.

docs/models/overview.md

Lines changed: 1 addition & 1 deletion
@@ -9,6 +9,7 @@ Pydantic AI is model-agnostic and has built-in support for multiple model provid
 * [Mistral](mistral.md)
 * [Cohere](cohere.md)
 * [Bedrock](bedrock.md)
+* [Cerebras](cerebras.md)
 * [Hugging Face](huggingface.md)
 * [Outlines](outlines.md)
 
@@ -27,7 +28,6 @@ In addition, many providers are compatible with the OpenAI API, and can be used
 - [Azure AI Foundry](openai.md#azure-ai-foundry)
 - [Heroku](openai.md#heroku-ai)
 - [GitHub Models](openai.md#github-models)
-- [Cerebras](openai.md#cerebras)
 - [LiteLLM](openai.md#litellm)
 - [Nebius AI Studio](openai.md#nebius-ai-studio)
 - [OVHcloud AI Endpoints](openai.md#ovhcloud-ai-endpoints)

mkdocs.yml

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@ nav:
     - api/models/openai.md
     - api/models/anthropic.md
    - api/models/bedrock.md
+    - api/models/cerebras.md
     - api/models/cohere.md
     - api/models/google.md
     - api/models/groq.md

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -819,7 +819,7 @@ def infer_model( # noqa: C901
     if model_kind == 'cerebras':
         from .cerebras import CerebrasModel
 
-        return CerebrasModel(model_name, provider=provider)  # type: ignore[arg-type]
+        return CerebrasModel(model_name, provider=provider)
     elif model_kind == 'openai-chat':
         from .openai import OpenAIChatModel
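The `infer_model` branch above dispatches on the provider prefix of a `provider:model` string such as `'cerebras:llama3.3-70b'`. A minimal standalone sketch of that parsing step (the function name here is a placeholder, not pydantic-ai's API):

```python
# Sketch of splitting a 'provider:model' string, as infer_model does above.
# parse_model_string is a hypothetical helper name for illustration only.

def parse_model_string(model: str) -> tuple[str, str]:
    """Split 'cerebras:llama3.3-70b' into ('cerebras', 'llama3.3-70b')."""
    kind, sep, name = model.partition(':')
    if not sep:
        raise ValueError(f'expected "provider:model", got {model!r}')
    return kind, name


kind, name = parse_model_string('cerebras:llama3.3-70b')
print(kind, name)
```

Once the provider kind is known, the matching model class is imported lazily inside the branch, so optional dependencies are only loaded when actually used.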

pydantic_ai_slim/pydantic_ai/models/cerebras.py

Lines changed: 25 additions & 3 deletions
@@ -18,9 +18,9 @@
     'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"'
 ) from _import_error
 
-__all__ = ('CerebrasModel', 'CerebrasModelName')
+__all__ = ('CerebrasModel', 'CerebrasModelName', 'CerebrasModelSettings')
 
-CerebrasModelName = Literal[
+_KnownCerebrasModelName = Literal[
     'gpt-oss-120b',
     'llama-3.3-70b',
     'llama3.1-8b',
@@ -29,6 +29,28 @@
     'zai-glm-4.6',
 ]
 
+CerebrasModelName = str | _KnownCerebrasModelName
+"""Possible Cerebras model names.
+
+Since Cerebras supports a variety of models and the list changes frequently, we explicitly list known models
+but allow any name in the type hints.
+
+See <https://inference-docs.cerebras.ai/models/overview> for an up to date list of models.
+"""
+
+
+class CerebrasModelSettings(ModelSettings, total=False):
+    """Settings used for a Cerebras model request.
+
+    ALL FIELDS MUST BE `cerebras_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    """
+
+    cerebras_disable_reasoning: bool
+    """Disable reasoning for the model.
+
+    See [the Cerebras docs](https://inference-docs.cerebras.ai/resources/openai#passing-non-standard-parameters) for more details.
+    """
+
 
 @dataclass(init=False)
 class CerebrasModel(OpenAIChatModel):
@@ -45,7 +67,7 @@ def __init__(
         *,
         provider: Literal['cerebras'] | Provider[AsyncOpenAI] = 'cerebras',
         profile: ModelProfileSpec | None = None,
-        settings: ModelSettings | None = None,
+        settings: CerebrasModelSettings | None = None,
     ):
         """Initialize a Cerebras model.
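The `cerebras_` prefix requirement in `CerebrasModelSettings` exists so provider-specific settings can be merged with base `ModelSettings` from other models without key collisions. A minimal sketch of the pattern using plain `TypedDict`s (the base class is simplified to one field for illustration):

```python
from typing import TypedDict


class ModelSettings(TypedDict, total=False):
    """Base settings shared by all models (simplified sketch)."""
    temperature: float


class CerebrasModelSettings(ModelSettings, total=False):
    """All fields are cerebras_-prefixed so they merge safely with other settings dicts."""
    cerebras_disable_reasoning: bool


base: ModelSettings = {'temperature': 0.7}
cerebras: CerebrasModelSettings = {'cerebras_disable_reasoning': True}
# Because provider-specific keys carry a unique prefix, a plain dict merge
# can never clobber another provider's settings with the same base name.
merged = {**base, **cerebras}
print(merged)
```

`total=False` makes every field optional, so callers only supply the settings they care about.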

pydantic_ai_slim/pydantic_ai/providers/cerebras.py

Lines changed: 26 additions & 5 deletions
@@ -10,7 +10,7 @@
 from pydantic_ai.models import cached_async_http_client
 from pydantic_ai.profiles.harmony import harmony_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
-from pydantic_ai.profiles.openai import OpenAIModelProfile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
 
@@ -23,6 +23,15 @@
 ) from _import_error
 
 
+def zai_model_profile(model_name: str) -> ModelProfile | None:
+    """The model profile for ZAI models on Cerebras."""
+    return ModelProfile(
+        supports_json_object_output=True,
+        supports_json_schema_output=True,
+        json_schema_transformer=OpenAIJsonSchemaTransformer,
+    )
+
+
 class CerebrasProvider(Provider[AsyncOpenAI]):
     """Provider for Cerebras API."""
 
@@ -43,18 +52,30 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
             'llama': meta_model_profile,
             'qwen': qwen_model_profile,
             'gpt-oss': harmony_model_profile,
+            'zai': zai_model_profile,
         }
 
         profile = None
         model_name_lower = model_name.lower()
         for prefix, profile_func in prefix_to_profile.items():
             if model_name_lower.startswith(prefix):
-                profile = profile_func(model_name)
+                profile = profile_func(model_name_lower)
                 break
 
-        # Wrap in OpenAIModelProfile with web search disabled
-        # Cerebras doesn't support web search
-        return OpenAIModelProfile(openai_chat_supports_web_search=False).update(profile)
+        # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features,
+        # Cerebras doesn't support some model settings.
+        # openai_chat_supports_web_search=False is default, so not required here
+        unsupported_model_settings = (
+            'frequency_penalty',
+            'logit_bias',
+            'presence_penalty',
+            'parallel_tool_calls',
+            'service_tier',
+        )
+        return OpenAIModelProfile(
+            json_schema_transformer=OpenAIJsonSchemaTransformer,
+            openai_unsupported_model_settings=unsupported_model_settings,
+        ).update(profile)
 
     @overload
     def __init__(self) -> None: ...
