import uuid
import warnings
from contextlib import asynccontextmanager
-from typing import Any, AsyncIterator, Callable, List, Optional, Union
+from typing import Any, AsyncIterator, Callable, List, Optional

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
-from openai.types.model import Model
from pydantic import BaseModel, Field, root_validator, validator
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from nemoguardrails import LLMRails, RailsConfig, utils
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
from nemoguardrails.server.datastore.datastore import DataStore
-from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody
+from nemoguardrails.server.schemas.openai import (
+    GuardrailsModel,
+    ModelsResponse,
+    ResponseBody,
+)
from nemoguardrails.streaming import StreamingHandler

logging.basicConfig(level=logging.INFO)
@@ -90,9 +93,9 @@ async def lifespan(app: GuardrailsApp):

    # If there is a `config.yml` in the root `app.rails_config_path`, then
    # that means we are in single config mode.
-    if os.path.exists(
-        os.path.join(app.rails_config_path, "config.yml")
-    ) or os.path.exists(os.path.join(app.rails_config_path, "config.yaml")):
+    if os.path.exists(os.path.join(app.rails_config_path, "config.yml")) or os.path.exists(
+        os.path.join(app.rails_config_path, "config.yaml")
+    ):
        app.single_config_mode = True
        app.single_config_id = os.path.basename(app.rails_config_path)
    else:
@@ -232,8 +235,8 @@ class RequestBody(BaseModel):
    )
    # Standard OpenAI completion parameters
    model: Optional[str] = Field(
-        default=None,
-        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
+        default="main",
+        description="The model to use for chat completion. Maps to the main model in the config.",
    )
    max_tokens: Optional[int] = Field(
        default=None,
@@ -278,15 +281,11 @@ def ensure_config_id(cls, data: Any) -> Any:
        if data.get("model") is not None and data.get("config_id") is None:
            data["config_id"] = data["model"]
        if data.get("config_id") is not None and data.get("config_ids") is not None:
-            raise ValueError(
-                "Only one of config_id or config_ids should be specified"
-            )
+            raise ValueError("Only one of config_id or config_ids should be specified")
        if data.get("config_id") is None and data.get("config_ids") is not None:
            data["config_id"] = None
        if data.get("config_id") is None and data.get("config_ids") is None:
-            warnings.warn(
-                "No config_id or config_ids provided, using default config_id"
-            )
+            warnings.warn("No config_id or config_ids provided, using default config_id")
        return data

    @validator("config_ids", pre=True, always=True)
@@ -309,6 +308,7 @@ async def get_models():
    # Use the same logic as get_rails_configs to find available configurations
    if app.single_config_mode:
        config_ids = [app.single_config_id] if app.single_config_id else []
+
    else:
        config_ids = [
            f
@@ -323,16 +323,43 @@ async def get_models():
            )
        ]

-    # Convert configurations to OpenAI model format
    models = []
    for config_id in config_ids:
-        model = Model(
-            id=config_id,
-            object="model",
-            created=int(time.time()),  # Use current time as created timestamp
-            owned_by="nemo-guardrails",
-        )
-        models.append(model)
+        try:
+            # Load the RailsConfig to extract model information
+            if app.single_config_mode:
+                config_path = app.rails_config_path
+            else:
+                config_path = os.path.join(app.rails_config_path, config_id)
+
+            rails_config = RailsConfig.from_path(config_path)
+            # Extract all models from this config
+            config_models = rails_config.models
+
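+            # Fall back to listing the config id itself when the config declares no models.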
+            if len(config_models) == 0:
+                guardrails_model = GuardrailsModel(
+                    id=config_id,
+                    object="model",
+                    created=int(time.time()),
+                    owned_by="nemo-guardrails",
+                    guardrails_config_id=config_id,
+                )
+                models.append(guardrails_model)
+            else:
+                for model in config_models:
+                    # Only include models with a model name
+                    if model.model:
+                        guardrails_model = GuardrailsModel(
+                            id=model.model,
+                            object="model",
+                            created=int(time.time()),
+                            owned_by="nemo-guardrails",
+                            guardrails_config_id=config_id,
+                        )
+                        models.append(guardrails_model)
+        except Exception as ex:
+            log.warning(f"Could not load model info for config {config_id}: {ex}")
+            continue

    return ModelsResponse(data=models)

@@ -377,6 +404,14 @@ def _generate_cache_key(config_ids: List[str]) -> str:
    return "-".join((config_ids))  # remove sorted


+def _get_main_model_name(rails_config: RailsConfig) -> Optional[str]:
+    """Extracts the main model name from a RailsConfig."""
+    main_models = [m for m in rails_config.models if m.type == "main"]
+    if main_models and main_models[0].model:
+        return main_models[0].model
+    return None
+
+
def _get_rails(config_ids: List[str]) -> LLMRails:
    """Returns the rails instance for the given config id."""

@@ -422,9 +457,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
    llm_rails_instances[configs_cache_key] = llm_rails

    # If we have a cache for the events, we restore it
-    llm_rails.events_history_cache = llm_rails_events_history_cache.get(
-        configs_cache_key, {}
-    )
+    llm_rails.events_history_cache = llm_rails_events_history_cache.get(configs_cache_key, {})

    return llm_rails

@@ -508,26 +541,24 @@ async def chat_completion(body: RequestBody, request: Request):
508541 """
509542 log .info ("Got request for config %s" , body .config_id )
510543 for logger in registered_loggers :
511- asyncio .get_event_loop ().create_task (
512- logger ({"endpoint" : "/v1/chat/completions" , "body" : body .json ()})
513- )
544+ asyncio .get_event_loop ().create_task (logger ({"endpoint" : "/v1/chat/completions" , "body" : body .json ()}))
514545
515546 # Save the request headers in a context variable.
516547 api_request_headers .set (request .headers )
517548
518549 # Use Request config_ids if set, otherwise use the FastAPI default config.
519550 # If neither is available we can't generate any completions as we have no config_id
520551 config_ids = body .config_ids
552+
521553 if not config_ids :
522554 if app .default_config_id :
523555 config_ids = [app .default_config_id ]
524556 else :
525- raise GuardrailsConfigurationError (
526- "No request config_ids provided and server has no default configuration"
527- )
557+ raise GuardrailsConfigurationError ("No request config_ids provided and server has no default configuration" )
528558
529559 try :
530560 llm_rails = _get_rails (config_ids )
561+
531562 except ValueError as ex :
532563 log .exception (ex )
533564 return ResponseBody (
@@ -550,6 +581,10 @@ async def chat_completion(body: RequestBody, request: Request):
        )

    try:
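+        # Report the main model's name in OpenAI-style responses, falling back to the config id.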
+        main_model_name = _get_main_model_name(llm_rails.config)
+        if main_model_name is None:
+            main_model_name = config_ids[0] if config_ids else "unknown"
+
        messages = body.messages or []
        if body.context:
            messages.insert(0, {"role": "context", "content": body.context})
@@ -560,14 +595,13 @@ async def chat_completion(body: RequestBody, request: Request):
        if body.thread_id:
            if datastore is None:
                raise RuntimeError("No DataStore has been configured.")
-
            # We make sure the `thread_id` meets the minimum complexity requirement.
            if len(body.thread_id) < 16:
                return ResponseBody(
                    id=f"chatcmpl-{uuid.uuid4()}",
                    object="chat.completion",
                    created=int(time.time()),
-                    model=config_ids[0] if config_ids else "unknown",
+                    model=main_model_name,
                    choices=[
                        Choice(
                            index=0,
@@ -608,12 +642,7 @@ async def chat_completion(body: RequestBody, request: Request):
            generation_options.llm_params["presence_penalty"] = body.presence_penalty
        if body.frequency_penalty is not None:
            generation_options.llm_params["frequency_penalty"] = body.frequency_penalty
-
-        if (
-            body.stream
-            and llm_rails.config.streaming_supported
-            and llm_rails.main_llm_supports_streaming
-        ):
+        if body.stream and llm_rails.config.streaming_supported and llm_rails.main_llm_supports_streaming:
            # Create the streaming handler instance
            streaming_handler = StreamingHandler()

@@ -628,15 +657,11 @@ async def chat_completion(body: RequestBody, request: Request):
            )

            return StreamingResponse(
-                _format_streaming_response(
-                    streaming_handler, model_name=config_ids[0] if config_ids else None
-                ),
+                _format_streaming_response(streaming_handler, model_name=main_model_name),
                media_type="text/event-stream",
            )
        else:
-            res = await llm_rails.generate_async(
-                messages=messages, options=generation_options, state=body.state
-            )
+            res = await llm_rails.generate_async(messages=messages, options=generation_options, state=body.state)

            if isinstance(res, GenerationResponse):
                bot_message_content = res.response[0]
@@ -654,12 +679,12 @@ async def chat_completion(body: RequestBody, request: Request):
            if body.thread_id and datastore is not None and datastore_key is not None:
                await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

-            # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
+            # Build the response with OpenAI-compatible format
            response_kwargs = {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
-                "model": config_ids[0] if config_ids else "unknown",
+                "model": main_model_name,
                "choices": [
                    Choice(
                        index=0,
@@ -688,7 +713,7 @@ async def chat_completion(body: RequestBody, request: Request):
            id=f"chatcmpl-{uuid.uuid4()}",
            object="chat.completion",
            created=int(time.time()),
-            model="unknown",
+            model=config_ids[0] if config_ids else "unknown",
            choices=[
                Choice(
                    index=0,
@@ -750,9 +775,7 @@ def on_any_event(self, event):
                    return None

                elif event.event_type == "created" or event.event_type == "modified":
-                    log.info(
-                        f"Watchdog received {event.event_type} event for file {event.src_path}"
-                    )
+                    log.info(f"Watchdog received {event.event_type} event for file {event.src_path}")

                    # Compute the relative path
                    src_path_str = str(event.src_path)
@@ -776,9 +799,7 @@ def on_any_event(self, event):
                        # We save the events history cache, to restore it on the new instance
                        llm_rails_events_history_cache[config_id] = val

-                        log.info(
-                            f"Configuration {config_id} has changed. Clearing cache."
-                        )
+                        log.info(f"Configuration {config_id} has changed. Clearing cache.")

        observer = Observer()
        event_handler = Handler()
@@ -793,9 +814,7 @@ def on_any_event(self, event):

    except ImportError:
        # Since this is running in a separate thread, we just print the error.
-        print(
-            "The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`."
-        )
+        print("The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`.")
        # Force close everything.
        os._exit(-1)
