33import re
44import shutil
55import subprocess
6+ from typing import Literal
67
78# from dsp.modules.adapter import TurboAdapter, DavinciAdapter, LlamaAdapter
89import backoff
@@ -114,7 +115,7 @@ def send_hftgi_request_v00(arg, **kwargs):
114115
115116
116117class HFClientVLLM (HFModel ):
117- def __init__ (self , model , port , url = "http://localhost" , ** kwargs ):
118+ def __init__ (self , model , port , model_type : Literal [ 'chat' , 'text' ] = 'text' , url = "http://localhost" , ** kwargs ):
118119 super ().__init__ (model = model , is_client = True )
119120
120121 if isinstance (url , list ):
@@ -126,27 +127,24 @@ def __init__(self, model, port, url="http://localhost", **kwargs):
126127 else :
127128 raise ValueError (f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type { type (url )} ." )
128129
130+ self .model_type = model_type
129131 self .headers = {"Content-Type" : "application/json" }
130132 self .kwargs |= kwargs
131133 # kwargs needs to have model, port and url for the lm.copy() to work properly
132134 self .kwargs .update ({
133- 'model' : model ,
134135 'port' : port ,
135- 'url' : url
136+ 'url' : url ,
136137 })
137138
138139
139140 def _generate (self , prompt , ** kwargs ):
140141 kwargs = {** self .kwargs , ** kwargs }
141142
142- # get model_type
143- model_type = kwargs .get ("model_type" ,None )
144-
145143 # Round robin the urls.
146144 url = self .urls .pop (0 )
147145 self .urls .append (url )
148146
149- if model_type == "chat" :
147+ if self . model_type == "chat" :
150148 system_prompt = kwargs .get ("system_prompt" ,None )
151149 messages = [{"role" : "user" , "content" : prompt }]
152150 if system_prompt :
0 commit comments