Commit 262b7b1

Add model name to response body

1 parent 416ac39 commit 262b7b1

6 files changed: +211 -70 lines changed

nemoguardrails/rails/llm/llmrails.py

Lines changed: 5 additions & 0 deletions

@@ -490,6 +490,11 @@ def _init_llms(self):
             if not self.llm:
                 self.llm = llm_model
                 self.runtime.register_action_param("llm", self.llm)
+                self._configure_main_llm_streaming(
+                    self.llm,
+                    model_name=llm_config.model,
+                    provider_name=llm_config.engine,
+                )
             else:
                 model_name = f"{llm_config.type}_llm"
                 if not hasattr(self, model_name):
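For context, the `model_name` and `provider_name` arguments come from the `models` section of the guardrails configuration. A minimal sketch of where those values originate (the config content below is illustrative, not part of this commit):

```python
from nemoguardrails import RailsConfig

# Hypothetical config: the entry of type "main" supplies llm_config.model
# and llm_config.engine, which are forwarded to _configure_main_llm_streaming.
config = RailsConfig.from_content(
    yaml_content="""
models:
  - type: main
    engine: openai
    model: gpt-4
"""
)

main = next(m for m in config.models if m.type == "main")
print(main.model, main.engine)  # gpt-4 openai
```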

nemoguardrails/server/api.py

Lines changed: 63 additions & 20 deletions
@@ -24,21 +24,24 @@
 import uuid
 import warnings
 from contextlib import asynccontextmanager
-from typing import Any, AsyncIterator, Callable, List, Optional, Union
+from typing import Any, AsyncIterator, Callable, List, Optional
 
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
-from openai.types.model import Model
 from pydantic import BaseModel, Field, root_validator, validator
 from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles
 
 from nemoguardrails import LLMRails, RailsConfig, utils
 from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
 from nemoguardrails.server.datastore.datastore import DataStore
-from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody
+from nemoguardrails.server.schemas.openai import (
+    GuardrailsModel,
+    ModelsResponse,
+    ResponseBody,
+)
 from nemoguardrails.streaming import StreamingHandler
 
 logging.basicConfig(level=logging.INFO)
@@ -232,8 +235,8 @@ class RequestBody(BaseModel):
     )
     # Standard OpenAI completion parameters
     model: Optional[str] = Field(
-        default=None,
-        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
+        default="main",
+        description="The model to use for chat completion. Maps to the main model in the config.",
     )
     max_tokens: Optional[int] = Field(
         default=None,
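A quick sketch of the new default in action, assuming `RequestBody` validates without a config id:

```python
from nemoguardrails.server.api import RequestBody

# "model" no longer defaults to None; omitting it now targets the
# main model of the selected config.
body = RequestBody(messages=[{"role": "user", "content": "Hello!"}])
print(body.model)  # "main"
```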
@@ -309,6 +312,7 @@ async def get_models():
     # Use the same logic as get_rails_configs to find available configurations
     if app.single_config_mode:
         config_ids = [app.single_config_id] if app.single_config_id else []
+
     else:
         config_ids = [
             f
@@ -323,16 +327,43 @@ async def get_models():
             )
         ]
 
-    # Convert configurations to OpenAI model format
     models = []
     for config_id in config_ids:
-        model = Model(
-            id=config_id,
-            object="model",
-            created=int(time.time()),  # Use current time as created timestamp
-            owned_by="nemo-guardrails",
-        )
-        models.append(model)
+        try:
+            # Load the RailsConfig to extract model information
+            if app.single_config_mode:
+                config_path = app.rails_config_path
+            else:
+                config_path = os.path.join(app.rails_config_path, config_id)
+
+            rails_config = RailsConfig.from_path(config_path)
+            # Extract all models from this config
+            config_models = rails_config.models
+
+            if len(config_models) == 0:
+                guardrails_model = GuardrailsModel(
+                    id=config_id,
+                    object="model",
+                    created=int(time.time()),
+                    owned_by="nemo-guardrails",
+                    guardrails_config_id=config_id,
+                )
+                models.append(guardrails_model)
+            else:
+                for model in config_models:
+                    # Only include models with a model name
+                    if model.model:
+                        guardrails_model = GuardrailsModel(
+                            id=model.model,
+                            object="model",
+                            created=int(time.time()),
+                            owned_by="nemo-guardrails",
+                            guardrails_config_id=config_id,
+                        )
+                        models.append(guardrails_model)
+        except Exception as ex:
+            log.warning(f"Could not load model info for config {config_id}: {ex}")
+            continue
 
     return ModelsResponse(data=models)
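The net effect: `/v1/models` now lists the model names declared in each config (falling back to the config id when a config declares none), each tagged with its source config. A sketch of the expected shape, with illustrative values:

```python
from fastapi.testclient import TestClient

from nemoguardrails.server import api

client = TestClient(api.app)
result = client.get("/v1/models").json()
# Expected shape (values illustrative):
# {
#   "object": "list",
#   "data": [{
#       "id": "gpt-4",                       # model name from the config
#       "object": "model",
#       "created": 1700000000,
#       "owned_by": "nemo-guardrails",
#       "guardrails_config_id": "my_config"  # NeMo Guardrails extension
#   }]
# }
```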

@@ -377,6 +408,14 @@ def _generate_cache_key(config_ids: List[str]) -> str:
     return "-".join((config_ids))  # remove sorted
 
 
+def _get_main_model_name(rails_config: RailsConfig) -> Optional[str]:
+    """Extracts the main model name from a RailsConfig."""
+    main_models = [m for m in rails_config.models if m.type == "main"]
+    if main_models and main_models[0].model:
+        return main_models[0].model
+    return None
+
+
 def _get_rails(config_ids: List[str]) -> LLMRails:
     """Returns the rails instance for the given config id."""
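`chat_completion` (next hunks) resolves the reported model name with a fallback chain: the main model name, else the first config id, else `"unknown"`. A minimal sketch of that resolution, reusing the illustrative config from above:

```python
from nemoguardrails import RailsConfig
from nemoguardrails.server.api import _get_main_model_name

config = RailsConfig.from_content(
    yaml_content="""
models:
  - type: main
    engine: openai
    model: gpt-4
"""
)

config_ids = ["my_config"]  # hypothetical config id
main_model_name = _get_main_model_name(config)
if main_model_name is None:
    main_model_name = config_ids[0] if config_ids else "unknown"
print(main_model_name)  # gpt-4
```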

@@ -518,6 +557,7 @@ async def chat_completion(body: RequestBody, request: Request):
     # Use Request config_ids if set, otherwise use the FastAPI default config.
     # If neither is available we can't generate any completions as we have no config_id
     config_ids = body.config_ids
+
     if not config_ids:
         if app.default_config_id:
             config_ids = [app.default_config_id]

@@ -528,6 +568,7 @@
 
     try:
         llm_rails = _get_rails(config_ids)
+
     except ValueError as ex:
         log.exception(ex)
         return ResponseBody(
@@ -550,6 +591,10 @@
         )
 
     try:
+        main_model_name = _get_main_model_name(llm_rails.config)
+        if main_model_name is None:
+            main_model_name = config_ids[0] if config_ids else "unknown"
+
         messages = body.messages or []
         if body.context:
             messages.insert(0, {"role": "context", "content": body.context})
@@ -560,14 +605,13 @@
         if body.thread_id:
             if datastore is None:
                 raise RuntimeError("No DataStore has been configured.")
-
             # We make sure the `thread_id` meets the minimum complexity requirement.
             if len(body.thread_id) < 16:
                 return ResponseBody(
                     id=f"chatcmpl-{uuid.uuid4()}",
                     object="chat.completion",
                     created=int(time.time()),
-                    model=config_ids[0] if config_ids else "unknown",
+                    model=main_model_name,
                     choices=[
                         Choice(
                             index=0,
@@ -608,7 +652,6 @@
         generation_options.llm_params["presence_penalty"] = body.presence_penalty
     if body.frequency_penalty is not None:
         generation_options.llm_params["frequency_penalty"] = body.frequency_penalty
-
     if (
         body.stream
         and llm_rails.config.streaming_supported
@@ -629,7 +672,7 @@
 
         return StreamingResponse(
             _format_streaming_response(
-                streaming_handler, model_name=config_ids[0] if config_ids else None
+                streaming_handler, model_name=main_model_name
             ),
             media_type="text/event-stream",
         )
@@ -654,12 +697,12 @@
     if body.thread_id and datastore is not None and datastore_key is not None:
         await datastore.set(datastore_key, json.dumps(messages + [bot_message]))
 
-    # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
+    # Build the response with OpenAI-compatible format
     response_kwargs = {
         "id": f"chatcmpl-{uuid.uuid4()}",
         "object": "chat.completion",
        "created": int(time.time()),
-        "model": config_ids[0] if config_ids else "unknown",
+        "model": main_model_name,
         "choices": [
             Choice(
                 index=0,
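With `main_model_name` threaded through, a successful completion now reports the LLM behind the config rather than the config id. Expected response shape (values illustrative):

```python
# Illustrative /v1/chat/completions response after this change:
response = {
    "id": "chatcmpl-1234",    # uuid-based in practice
    "object": "chat.completion",
    "created": 1700000000,
    "model": "gpt-4",         # main model name, no longer the config id
    "choices": [{
        "index": 0,
        "message": {"role": "assistant", "content": "Hello!"},
    }],
}
```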
@@ -688,7 +731,7 @@
             id=f"chatcmpl-{uuid.uuid4()}",
             object="chat.completion",
             created=int(time.time()),
-            model="unknown",
+            model=config_ids[0] if config_ids else "unknown",
             choices=[
                 Choice(
                     index=0,

nemoguardrails/server/schemas/openai.py

Lines changed: 15 additions & 2 deletions
@@ -25,6 +25,10 @@
 class ResponseBody(ChatCompletion):
     """OpenAI API response body with NeMo-Guardrails extensions."""
 
+    guardrails_config_id: Optional[str] = Field(
+        default=None,
+        description="The guardrails configuration ID associated with this response.",
+    )
     state: Optional[dict] = Field(
         default=None, description="State object for continuing the conversation."
     )
@@ -37,10 +41,19 @@ class ResponseBody(ChatCompletion):
     log: Optional[dict] = Field(default=None, description="Generation log data.")
 
 
+class GuardrailsModel(Model):
+    """OpenAI API model with NeMo-Guardrails extensions."""
+
+    guardrails_config_id: Optional[str] = Field(
+        default=None,
+        description="[NeMo Guardrails extension] The guardrails configuration ID associated with this model.",
+    )
+
+
 class ModelsResponse(BaseModel):
-    """OpenAI API models list response."""
+    """OpenAI API models list response with NeMo-Guardrails extensions."""
 
     object: str = Field(
         default="list", description="The object type, which is always 'list'."
     )
-    data: List[Model] = Field(description="The list of models.")
+    data: List[GuardrailsModel] = Field(description="The list of models.")
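Since `GuardrailsModel` subclasses the OpenAI `Model` type, clients that expect the stock schema can simply ignore the extension field. A minimal construction sketch (pydantic v2 assumed):

```python
import time

from nemoguardrails.server.schemas.openai import GuardrailsModel

model = GuardrailsModel(
    id="gpt-4",
    object="model",
    created=int(time.time()),
    owned_by="nemo-guardrails",
    guardrails_config_id="my_config",  # hypothetical config id
)
print(model.model_dump())
```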

tests/test_api.py

Lines changed: 11 additions & 2 deletions
@@ -24,6 +24,8 @@
 from nemoguardrails.server.api import RequestBody, _format_streaming_response
 from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler
 
+LIVE_TEST_MODE = os.environ.get("LIVE_TEST_MODE") or os.environ.get("TEST_LIVE_MODE")
+
 client = TestClient(api.app)
 
 
@@ -59,12 +61,16 @@ def test_get_models():
     # Check each model has the required OpenAI format
     for model in result["data"]:
         assert "id" in model
+        assert "guardrails_config_id" in model
         assert model["object"] == "model"
         assert "created" in model
         assert model["owned_by"] == "nemo-guardrails"
 
 
-@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.")
+@pytest.mark.skipif(
+    not LIVE_TEST_MODE,
+    reason="This test requires LIVE_TEST_MODE or TEST_LIVE_MODE environment variable to be set for live testing",
+)
 def test_chat_completion():
     response = client.post(
         "/v1/chat/completions",
@@ -90,7 +96,10 @@ def test_chat_completion():
     assert res["choices"][0]["message"]["role"] == "assistant"
 
 
-@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.")
+@pytest.mark.skipif(
+    not LIVE_TEST_MODE,
+    reason="This test requires LIVE_TEST_MODE or TEST_LIVE_MODE environment variable to be set for live testing",
+)
 def test_chat_completion_with_default_configs():
     api.set_default_config_id("general")
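With the `skipif` markers, the live tests can be enabled locally (an OpenAI key must still be configured) by setting either variable, e.g. `LIVE_TEST_MODE=1 pytest tests/test_api.py`.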
