Commit 24ad343

Add model name to response body

1 parent 416ac39 commit 24ad343

8 files changed: +219 -127 lines changed

nemoguardrails/rails/llm/llmrails.py

Lines changed: 5 additions & 0 deletions
@@ -490,6 +490,11 @@ def _init_llms(self):
         if not self.llm:
             self.llm = llm_model
             self.runtime.register_action_param("llm", self.llm)
+            self._configure_main_llm_streaming(
+                self.llm,
+                model_name=llm_config.model,
+                provider_name=llm_config.engine,
+            )
         else:
             model_name = f"{llm_config.type}_llm"
             if not hasattr(self, model_name):
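For reference, `llm_config.model` and `llm_config.engine` here come from the `models` section of the guardrails configuration. A minimal `config.yml` entry might look like this (engine and model values are illustrative):

    models:
      - type: main
        engine: openai
        model: gpt-3.5-turbo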

nemoguardrails/server/api.py

Lines changed: 76 additions & 57 deletions
@@ -24,21 +24,24 @@
 import uuid
 import warnings
 from contextlib import asynccontextmanager
-from typing import Any, AsyncIterator, Callable, List, Optional, Union
+from typing import Any, AsyncIterator, Callable, List, Optional

 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
-from openai.types.model import Model
 from pydantic import BaseModel, Field, root_validator, validator
 from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles

 from nemoguardrails import LLMRails, RailsConfig, utils
 from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
 from nemoguardrails.server.datastore.datastore import DataStore
-from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody
+from nemoguardrails.server.schemas.openai import (
+    GuardrailsModel,
+    ModelsResponse,
+    ResponseBody,
+)
 from nemoguardrails.streaming import StreamingHandler

 logging.basicConfig(level=logging.INFO)

@@ -90,9 +93,9 @@ async def lifespan(app: GuardrailsApp):

     # If there is a `config.yml` in the root `app.rails_config_path`, then
     # that means we are in single config mode.
-    if os.path.exists(
-        os.path.join(app.rails_config_path, "config.yml")
-    ) or os.path.exists(os.path.join(app.rails_config_path, "config.yaml")):
+    if os.path.exists(os.path.join(app.rails_config_path, "config.yml")) or os.path.exists(
+        os.path.join(app.rails_config_path, "config.yaml")
+    ):
         app.single_config_mode = True
         app.single_config_id = os.path.basename(app.rails_config_path)
     else:

@@ -232,8 +235,8 @@ class RequestBody(BaseModel):
     )
     # Standard OpenAI completion parameters
     model: Optional[str] = Field(
-        default=None,
-        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
+        default="main",
+        description="The model to use for chat completion. Maps to the main model in the config.",
     )
     max_tokens: Optional[int] = Field(
         default=None,

@@ -278,15 +281,11 @@ def ensure_config_id(cls, data: Any) -> Any:
         if data.get("model") is not None and data.get("config_id") is None:
             data["config_id"] = data["model"]
         if data.get("config_id") is not None and data.get("config_ids") is not None:
-            raise ValueError(
-                "Only one of config_id or config_ids should be specified"
-            )
+            raise ValueError("Only one of config_id or config_ids should be specified")
         if data.get("config_id") is None and data.get("config_ids") is not None:
             data["config_id"] = None
         if data.get("config_id") is None and data.get("config_ids") is None:
-            warnings.warn(
-                "No config_id or config_ids provided, using default config_id"
-            )
+            warnings.warn("No config_id or config_ids provided, using default config_id")
         return data

     @validator("config_ids", pre=True, always=True)
@@ -309,6 +308,7 @@ async def get_models():
     # Use the same logic as get_rails_configs to find available configurations
     if app.single_config_mode:
         config_ids = [app.single_config_id] if app.single_config_id else []
+
     else:
         config_ids = [
             f

@@ -323,16 +323,43 @@ async def get_models():
             )
         ]

-    # Convert configurations to OpenAI model format
     models = []
     for config_id in config_ids:
-        model = Model(
-            id=config_id,
-            object="model",
-            created=int(time.time()),  # Use current time as created timestamp
-            owned_by="nemo-guardrails",
-        )
-        models.append(model)
+        try:
+            # Load the RailsConfig to extract model information
+            if app.single_config_mode:
+                config_path = app.rails_config_path
+            else:
+                config_path = os.path.join(app.rails_config_path, config_id)
+
+            rails_config = RailsConfig.from_path(config_path)
+            # Extract all models from this config
+            config_models = rails_config.models
+
+            if len(config_models) == 0:
+                guardrails_model = GuardrailsModel(
+                    id=config_id,
+                    object="model",
+                    created=int(time.time()),
+                    owned_by="nemo-guardrails",
+                    guardrails_config_id=config_id,
+                )
+                models.append(guardrails_model)
+            else:
+                for model in config_models:
+                    # Only include models with a model name
+                    if model.model:
+                        guardrails_model = GuardrailsModel(
+                            id=model.model,
+                            object="model",
+                            created=int(time.time()),
+                            owned_by="nemo-guardrails",
+                            guardrails_config_id=config_id,
+                        )
+                        models.append(guardrails_model)
+        except Exception as ex:
+            log.warning(f"Could not load model info for config {config_id}: {ex}")
+            continue

     return ModelsResponse(data=models)
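With this change, `GET /v1/models` lists the configured model names rather than bare config ids, each tagged with the config it came from. A hypothetical response for a config `my_config` whose main model is `gpt-3.5-turbo` (all values illustrative):

    {
      "object": "list",
      "data": [
        {
          "id": "gpt-3.5-turbo",
          "object": "model",
          "created": 1700000000,
          "owned_by": "nemo-guardrails",
          "guardrails_config_id": "my_config"
        }
      ]
    }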

@@ -377,6 +404,14 @@ def _generate_cache_key(config_ids: List[str]) -> str:
     return "-".join((config_ids))  # remove sorted


+def _get_main_model_name(rails_config: RailsConfig) -> Optional[str]:
+    """Extracts the main model name from a RailsConfig."""
+    main_models = [m for m in rails_config.models if m.type == "main"]
+    if main_models and main_models[0].model:
+        return main_models[0].model
+    return None
+
+
 def _get_rails(config_ids: List[str]) -> LLMRails:
     """Returns the rails instance for the given config id."""
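A quick sketch of the helper's behavior, as if called from within nemoguardrails.server.api (it is assumed here that `RailsConfig.from_content` accepts a `yaml_content` keyword; values are illustrative): it returns the `model` of the first entry typed `main`, or `None` when no such entry names a model, letting callers fall back to the config id.

    from nemoguardrails import RailsConfig

    yaml_content = """
    models:
      - type: main
        engine: openai
        model: gpt-3.5-turbo
    """
    config = RailsConfig.from_content(yaml_content=yaml_content)
    # Returns None instead if no "main" entry names a model.
    assert _get_main_model_name(config) == "gpt-3.5-turbo"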

@@ -422,9 +457,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
     llm_rails_instances[configs_cache_key] = llm_rails

     # If we have a cache for the events, we restore it
-    llm_rails.events_history_cache = llm_rails_events_history_cache.get(
-        configs_cache_key, {}
-    )
+    llm_rails.events_history_cache = llm_rails_events_history_cache.get(configs_cache_key, {})

     return llm_rails

@@ -508,26 +541,24 @@ async def chat_completion(body: RequestBody, request: Request):
     """
     log.info("Got request for config %s", body.config_id)
     for logger in registered_loggers:
-        asyncio.get_event_loop().create_task(
-            logger({"endpoint": "/v1/chat/completions", "body": body.json()})
-        )
+        asyncio.get_event_loop().create_task(logger({"endpoint": "/v1/chat/completions", "body": body.json()}))

     # Save the request headers in a context variable.
     api_request_headers.set(request.headers)

     # Use Request config_ids if set, otherwise use the FastAPI default config.
     # If neither is available we can't generate any completions as we have no config_id
     config_ids = body.config_ids
+
     if not config_ids:
         if app.default_config_id:
             config_ids = [app.default_config_id]
         else:
-            raise GuardrailsConfigurationError(
-                "No request config_ids provided and server has no default configuration"
-            )
+            raise GuardrailsConfigurationError("No request config_ids provided and server has no default configuration")

     try:
         llm_rails = _get_rails(config_ids)
+
     except ValueError as ex:
         log.exception(ex)
         return ResponseBody(

@@ -550,6 +581,10 @@ async def chat_completion(body: RequestBody, request: Request):
     )

     try:
+        main_model_name = _get_main_model_name(llm_rails.config)
+        if main_model_name is None:
+            main_model_name = config_ids[0] if config_ids else "unknown"
+
         messages = body.messages or []
         if body.context:
             messages.insert(0, {"role": "context", "content": body.context})
@@ -560,14 +595,13 @@ async def chat_completion(body: RequestBody, request: Request):
         if body.thread_id:
             if datastore is None:
                 raise RuntimeError("No DataStore has been configured.")
-
             # We make sure the `thread_id` meets the minimum complexity requirement.
             if len(body.thread_id) < 16:
                 return ResponseBody(
                     id=f"chatcmpl-{uuid.uuid4()}",
                     object="chat.completion",
                     created=int(time.time()),
-                    model=config_ids[0] if config_ids else "unknown",
+                    model=main_model_name,
                     choices=[
                         Choice(
                             index=0,

@@ -608,12 +642,7 @@ async def chat_completion(body: RequestBody, request: Request):
            generation_options.llm_params["presence_penalty"] = body.presence_penalty
        if body.frequency_penalty is not None:
            generation_options.llm_params["frequency_penalty"] = body.frequency_penalty
-
-        if (
-            body.stream
-            and llm_rails.config.streaming_supported
-            and llm_rails.main_llm_supports_streaming
-        ):
+        if body.stream and llm_rails.config.streaming_supported and llm_rails.main_llm_supports_streaming:
            # Create the streaming handler instance
            streaming_handler = StreamingHandler()

@@ -628,15 +657,11 @@ async def chat_completion(body: RequestBody, request: Request):
             )

             return StreamingResponse(
-                _format_streaming_response(
-                    streaming_handler, model_name=config_ids[0] if config_ids else None
-                ),
+                _format_streaming_response(streaming_handler, model_name=main_model_name),
                 media_type="text/event-stream",
             )
         else:
-            res = await llm_rails.generate_async(
-                messages=messages, options=generation_options, state=body.state
-            )
+            res = await llm_rails.generate_async(messages=messages, options=generation_options, state=body.state)

             if isinstance(res, GenerationResponse):
                 bot_message_content = res.response[0]

@@ -654,12 +679,12 @@ async def chat_completion(body: RequestBody, request: Request):
         if body.thread_id and datastore is not None and datastore_key is not None:
             await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

-        # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
+        # Build the response with OpenAI-compatible format
         response_kwargs = {
             "id": f"chatcmpl-{uuid.uuid4()}",
             "object": "chat.completion",
             "created": int(time.time()),
-            "model": config_ids[0] if config_ids else "unknown",
+            "model": main_model_name,
             "choices": [
                 Choice(
                     index=0,

@@ -688,7 +713,7 @@ async def chat_completion(body: RequestBody, request: Request):
            id=f"chatcmpl-{uuid.uuid4()}",
            object="chat.completion",
            created=int(time.time()),
-            model="unknown",
+            model=config_ids[0] if config_ids else "unknown",
            choices=[
                Choice(
                    index=0,

@@ -750,9 +775,7 @@ def on_any_event(self, event):
             return None

         elif event.event_type == "created" or event.event_type == "modified":
-            log.info(
-                f"Watchdog received {event.event_type} event for file {event.src_path}"
-            )
+            log.info(f"Watchdog received {event.event_type} event for file {event.src_path}")

             # Compute the relative path
             src_path_str = str(event.src_path)

@@ -776,9 +799,7 @@ def on_any_event(self, event):
                     # We save the events history cache, to restore it on the new instance
                     llm_rails_events_history_cache[config_id] = val

-                    log.info(
-                        f"Configuration {config_id} has changed. Clearing cache."
-                    )
+                    log.info(f"Configuration {config_id} has changed. Clearing cache.")

     observer = Observer()
     event_handler = Handler()

@@ -793,9 +814,7 @@ def on_any_event(self, event):

     except ImportError:
         # Since this is running in a separate thread, we just print the error.
-        print(
-            "The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`."
-        )
+        print("The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`.")
         # Force close everything.
         os._exit(-1)
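Taken together, a chat completion response now reports the main model name instead of the config id. A hypothetical call against a locally running guardrails server (URL and config id are illustrative):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "config_id": "my_config",
            "messages": [{"role": "user", "content": "Hello!"}],
        },
    )
    # "model" now carries the main model name from the config,
    # e.g. "gpt-3.5-turbo", falling back to the config id if unset.
    print(resp.json()["model"])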

nemoguardrails/server/schemas/openai.py

Lines changed: 18 additions & 13 deletions
@@ -25,22 +25,27 @@
 class ResponseBody(ChatCompletion):
     """OpenAI API response body with NeMo-Guardrails extensions."""

-    state: Optional[dict] = Field(
-        default=None, description="State object for continuing the conversation."
-    )
-    llm_output: Optional[dict] = Field(
-        default=None, description="Additional LLM output data."
-    )
-    output_data: Optional[dict] = Field(
-        default=None, description="Additional output data."
+    guardrails_config_id: Optional[str] = Field(
+        default=None,
+        description="The guardrails configuration ID associated with this response.",
     )
+    state: Optional[dict] = Field(default=None, description="State object for continuing the conversation.")
+    llm_output: Optional[dict] = Field(default=None, description="Additional LLM output data.")
+    output_data: Optional[dict] = Field(default=None, description="Additional output data.")
     log: Optional[dict] = Field(default=None, description="Generation log data.")


-class ModelsResponse(BaseModel):
-    """OpenAI API models list response."""
+class GuardrailsModel(Model):
+    """OpenAI API model with NeMo-Guardrails extensions."""

-    object: str = Field(
-        default="list", description="The object type, which is always 'list'."
+    guardrails_config_id: Optional[str] = Field(
+        default=None,
+        description="[NeMo Guardrails extension] The guardrails configuration ID associated with this model.",
     )
-    data: List[Model] = Field(description="The list of models.")
+
+
+class ModelsResponse(BaseModel):
+    """OpenAI API models list response with NeMo-Guardrails extensions."""
+
+    object: str = Field(default="list", description="The object type, which is always 'list'.")
+    data: List[GuardrailsModel] = Field(description="The list of models.")
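Because `GuardrailsModel` subclasses the OpenAI `Model` type, standard clients can consume it unchanged while the extra field rides along. A minimal sketch, with illustrative values:

    from nemoguardrails.server.schemas.openai import GuardrailsModel

    m = GuardrailsModel(
        id="gpt-3.5-turbo",
        object="model",
        created=1700000000,
        owned_by="nemo-guardrails",
        guardrails_config_id="my_config",
    )
    print(m)  # OpenAI fields plus the guardrails_config_id extension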

nemoguardrails/streaming.py

Lines changed: 1 addition & 3 deletions
@@ -196,9 +196,7 @@ async def _process(
                 {
                     "text": chunk,
                     "generation_info": (
-                        self.current_generation_info.copy()
-                        if self.current_generation_info
-                        else {}
+                        self.current_generation_info.copy() if self.current_generation_info else {}
                     ),
                 }
             )

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.
