2424import uuid
2525import warnings
2626from contextlib import asynccontextmanager
27- from typing import Any , AsyncIterator , Callable , List , Optional , Union
27+ from typing import Any , AsyncIterator , Callable , List , Optional
2828
2929from fastapi import FastAPI , Request
3030from fastapi .middleware .cors import CORSMiddleware
3131from openai .types .chat .chat_completion import Choice
3232from openai .types .chat .chat_completion_message import ChatCompletionMessage
33- from openai .types .model import Model
3433from pydantic import BaseModel , Field , root_validator , validator
3534from starlette .responses import StreamingResponse
3635from starlette .staticfiles import StaticFiles
3736
3837from nemoguardrails import LLMRails , RailsConfig , utils
3938from nemoguardrails .rails .llm .options import GenerationOptions , GenerationResponse
4039from nemoguardrails .server .datastore .datastore import DataStore
41- from nemoguardrails .server .schemas .openai import ModelsResponse , ResponseBody
40+ from nemoguardrails .server .schemas .openai import (
41+ GuardrailsModel ,
42+ ModelsResponse ,
43+ ResponseBody ,
44+ )
4245from nemoguardrails .streaming import StreamingHandler
4346
4447logging .basicConfig (level = logging .INFO )
@@ -232,8 +235,8 @@ class RequestBody(BaseModel):
232235 )
233236 # Standard OpenAI completion parameters
234237 model : Optional [str ] = Field (
235- default = None ,
236- description = "The model to use for chat completion. Maps to config_id for backward compatibility ." ,
238+ default = "main" ,
239+ description = "The model to use for chat completion. Maps to the main model in the config ." ,
237240 )
238241 max_tokens : Optional [int ] = Field (
239242 default = None ,
@@ -309,6 +312,7 @@ async def get_models():
309312 # Use the same logic as get_rails_configs to find available configurations
310313 if app .single_config_mode :
311314 config_ids = [app .single_config_id ] if app .single_config_id else []
315+
312316 else :
313317 config_ids = [
314318 f
@@ -323,16 +327,43 @@ async def get_models():
323327 )
324328 ]
325329
326- # Convert configurations to OpenAI model format
327330 models = []
328331 for config_id in config_ids :
329- model = Model (
330- id = config_id ,
331- object = "model" ,
332- created = int (time .time ()), # Use current time as created timestamp
333- owned_by = "nemo-guardrails" ,
334- )
335- models .append (model )
332+ try :
333+ # Load the RailsConfig to extract model information
334+ if app .single_config_mode :
335+ config_path = app .rails_config_path
336+ else :
337+ config_path = os .path .join (app .rails_config_path , config_id )
338+
339+ rails_config = RailsConfig .from_path (config_path )
340+ # Extract all models from this config
341+ config_models = rails_config .models
342+
343+ if len (config_models ) == 0 :
344+ guardrails_model = GuardrailsModel (
345+ id = config_id ,
346+ object = "model" ,
347+ created = int (time .time ()),
348+ owned_by = "nemo-guardrails" ,
349+ guardrails_config_id = config_id ,
350+ )
351+ models .append (guardrails_model )
352+ else :
353+ for model in config_models :
354+ # Only include models with a model name
355+ if model .model :
356+ guardrails_model = GuardrailsModel (
357+ id = model .model ,
358+ object = "model" ,
359+ created = int (time .time ()),
360+ owned_by = "nemo-guardrails" ,
361+ guardrails_config_id = config_id ,
362+ )
363+ models .append (guardrails_model )
364+ except Exception as ex :
365+ log .warning (f"Could not load model info for config { config_id } : { ex } " )
366+ continue
336367
337368 return ModelsResponse (data = models )
338369
@@ -377,6 +408,14 @@ def _generate_cache_key(config_ids: List[str]) -> str:
377408 return "-" .join ((config_ids )) # remove sorted
378409
379410
def _get_main_model_name(rails_config: RailsConfig) -> Optional[str]:
    """Return the name of the first model of type "main" in *rails_config*.

    Args:
        rails_config: The loaded guardrails configuration to inspect.

    Returns:
        The ``model`` attribute of the first "main"-typed model entry if it
        is set to a non-empty value; ``None`` when no "main" model exists or
        its name is empty.
    """
    # Lazily scan for the first "main" entry instead of materializing a list.
    main_entry = next((m for m in rails_config.models if m.type == "main"), None)
    if main_entry is not None and main_entry.model:
        return main_entry.model
    return None
418+
380419def _get_rails (config_ids : List [str ]) -> LLMRails :
381420 """Returns the rails instance for the given config id."""
382421
@@ -518,6 +557,7 @@ async def chat_completion(body: RequestBody, request: Request):
518557 # Use Request config_ids if set, otherwise use the FastAPI default config.
519558 # If neither is available we can't generate any completions as we have no config_id
520559 config_ids = body .config_ids
560+
521561 if not config_ids :
522562 if app .default_config_id :
523563 config_ids = [app .default_config_id ]
@@ -528,6 +568,7 @@ async def chat_completion(body: RequestBody, request: Request):
528568
529569 try :
530570 llm_rails = _get_rails (config_ids )
571+
531572 except ValueError as ex :
532573 log .exception (ex )
533574 return ResponseBody (
@@ -550,6 +591,10 @@ async def chat_completion(body: RequestBody, request: Request):
550591 )
551592
552593 try :
594+ main_model_name = _get_main_model_name (llm_rails .config )
595+ if main_model_name is None :
596+ main_model_name = config_ids [0 ] if config_ids else "unknown"
597+
553598 messages = body .messages or []
554599 if body .context :
555600 messages .insert (0 , {"role" : "context" , "content" : body .context })
@@ -560,14 +605,13 @@ async def chat_completion(body: RequestBody, request: Request):
560605 if body .thread_id :
561606 if datastore is None :
562607 raise RuntimeError ("No DataStore has been configured." )
563-
564608 # We make sure the `thread_id` meets the minimum complexity requirement.
565609 if len (body .thread_id ) < 16 :
566610 return ResponseBody (
567611 id = f"chatcmpl-{ uuid .uuid4 ()} " ,
568612 object = "chat.completion" ,
569613 created = int (time .time ()),
570- model = config_ids [ 0 ] if config_ids else "unknown" ,
614+ model = main_model_name ,
571615 choices = [
572616 Choice (
573617 index = 0 ,
@@ -608,7 +652,6 @@ async def chat_completion(body: RequestBody, request: Request):
608652 generation_options .llm_params ["presence_penalty" ] = body .presence_penalty
609653 if body .frequency_penalty is not None :
610654 generation_options .llm_params ["frequency_penalty" ] = body .frequency_penalty
611-
612655 if (
613656 body .stream
614657 and llm_rails .config .streaming_supported
@@ -629,7 +672,7 @@ async def chat_completion(body: RequestBody, request: Request):
629672
630673 return StreamingResponse (
631674 _format_streaming_response (
632- streaming_handler , model_name = config_ids [ 0 ] if config_ids else None
675+ streaming_handler , model_name = main_model_name
633676 ),
634677 media_type = "text/event-stream" ,
635678 )
@@ -654,12 +697,12 @@ async def chat_completion(body: RequestBody, request: Request):
654697 if body .thread_id and datastore is not None and datastore_key is not None :
655698 await datastore .set (datastore_key , json .dumps (messages + [bot_message ]))
656699
657- # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
700+ # Build the response with OpenAI-compatible format
658701 response_kwargs = {
659702 "id" : f"chatcmpl-{ uuid .uuid4 ()} " ,
660703 "object" : "chat.completion" ,
661704 "created" : int (time .time ()),
662- "model" : config_ids [ 0 ] if config_ids else "unknown" ,
705+ "model" : main_model_name ,
663706 "choices" : [
664707 Choice (
665708 index = 0 ,
@@ -688,7 +731,7 @@ async def chat_completion(body: RequestBody, request: Request):
688731 id = f"chatcmpl-{ uuid .uuid4 ()} " ,
689732 object = "chat.completion" ,
690733 created = int (time .time ()),
691- model = "unknown" ,
734+ model = config_ids [ 0 ] if config_ids else "unknown" ,
692735 choices = [
693736 Choice (
694737 index = 0 ,
0 commit comments