@@ -1,5 +1,13 @@
+from typing import Iterator, List, Union
+
 import structlog
-from llama_cpp import Llama
+from llama_cpp import (
+    CreateChatCompletionResponse,
+    CreateChatCompletionStreamResponse,
+    CreateCompletionResponse,
+    CreateCompletionStreamResponse,
+    Llama,
+)

 logger = structlog.get_logger("codegate")

@@ -35,7 +43,9 @@ def _close_models(self):
             model._sampler.close()
             model.close()

-    async def __get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers=0) -> Llama:
+    async def __get_model(
+        self, model_path: str, embedding: bool = False, n_ctx: int = 512, n_gpu_layers: int = 0
+    ) -> Llama:
         """
         Returns Llama model object from __models if present. Otherwise, the model
         is loaded and added to __models and returned.
@@ -55,7 +65,9 @@ async def __get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers

         return self.__models[model_path]

-    async def complete(self, model_path, n_ctx=512, n_gpu_layers=0, **completion_request):
+    async def complete(
+        self, model_path: str, n_ctx: int = 512, n_gpu_layers: int = 0, **completion_request
+    ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """
         Generates a completion using the specified model and request parameters.
         """
@@ -64,7 +76,9 @@ async def complete(self, model_path, n_ctx=512, n_gpu_layers=0, **completion_req
         )
         return model.create_completion(**completion_request)

-    async def chat(self, model_path, n_ctx=512, n_gpu_layers=0, **chat_completion_request):
+    async def chat(
+        self, model_path: str, n_ctx: int = 512, n_gpu_layers: int = 0, **chat_completion_request
+    ) -> Union[CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]]:
         """
         Generates a chat completion using the specified model and request parameters.
         """
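A minimal caller sketch for the typed signatures above. InferenceEngine and the model path are placeholders, since the diff does not show the enclosing class; the point is that llama-cpp-python returns a single response dict when stream=False and an iterator of stream chunks when stream=True, which is exactly what the Union annotations capture:

import asyncio

# Hypothetical caller; InferenceEngine stands in for the class in this diff,
# and the model path is made up for illustration.
async def demo_chat() -> None:
    engine = InferenceEngine()

    # stream=False: one CreateChatCompletionResponse dict comes back.
    resp = await engine.chat(
        "models/example.gguf",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
    )
    print(resp["choices"][0]["message"]["content"])

    # stream=True: an Iterator[CreateChatCompletionStreamResponse]; a type
    # checker would narrow the Union here based on the stream flag.
    chunks = await engine.chat(
        "models/example.gguf",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    for chunk in chunks:
        print(chunk["choices"][0]["delta"].get("content", ""), end="")

asyncio.run(demo_chat())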
@@ -73,18 +87,20 @@ async def chat(self, model_path, n_ctx=512, n_gpu_layers=0, **chat_completion_re
         )
         return model.create_chat_completion(**chat_completion_request)

-    async def embed(self, model_path, content):
+    async def embed(self, model_path: str, content: List[str], n_gpu_layers=0) -> List[List[float]]:
         """
         Generates an embedding for the given content using the specified model.
         """
         logger.debug(
             "Generating embedding",
             model=model_path.split("/")[-1],
-            content=content,
+            content=content[0][0 : min(100, len(content[0]))],
             content_length=len(content[0]) if content else 0,
         )

-        model = await self.__get_model(model_path=model_path, embedding=True)
+        model = await self.__get_model(
+            model_path=model_path, embedding=True, n_gpu_layers=n_gpu_layers
+        )
         embedding = model.embed(content)

         logger.debug(