 from huggingface_hub import snapshot_download
 from starlette.concurrency import run_in_threadpool
 from supabase import create_client, Client
-from peft import PeftConfig, LoraConfig, get_peft_model, PeftModel, TaskType
+from peft import PeftConfig, PeftModel

 MODEL_REPO_ID = os.getenv("MODEL_REPO_ID")
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -127,8 +127,8 @@ def load_model():
         print("Loading LoRA model…")

         print("Loading adapter config…")
-        adapter_cfg = PeftConfig.from_pretrained(MODEL_PATH)
-        base_model_name_or_path = adapter_cfg.base_model_name_or_path
+        adapter_config = PeftConfig.from_pretrained(MODEL_PATH)
+        base_model_name_or_path = adapter_config.base_model_name_or_path

         print("Loading tokenizer…")
         tokenizer = GPT2TokenizerFast.from_pretrained(base_model_name_or_path)
@@ -137,27 +137,15 @@ def load_model():
         print("Loading base model…")
         base_model = AutoModelForCausalLM.from_pretrained(
             base_model_name_or_path,
-            torch_dtype=torch.float16
+            torch_dtype=torch.float32
         )
-
-        lora_cfg = LoraConfig(
-            task_type=TaskType.CAUSAL_LM,
-            inference_mode=True,
-            r=adapter_cfg.r,
-            lora_alpha=adapter_cfg.lora_alpha,
-            lora_dropout=adapter_cfg.lora_dropout,
-            bias=adapter_cfg.bias,
-            target_modules=adapter_cfg.target_modules,
-        )
-
-        print("Attaching LoRA adapter & merging weights…")
-        peft_wrapped = get_peft_model(base_model, lora_cfg)
-        peft_wrapped = PeftModel.from_pretrained(
-            peft_wrapped,
-            MODEL_PATH,
-            local_files_only=True,
+        model = PeftModel.from_pretrained(
+            base_model,
+            MODEL_PATH,
+            torch_dtype=torch.float32,
+            local_files_only=True
         )
-        model = peft_wrapped.merge_and_unload()
+        model = model.merge_and_unload()

     else:
         print("Loading vanilla model…")
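Why the block above got simpler: `PeftModel.from_pretrained` reads `r`, `lora_alpha`, `lora_dropout`, `bias`, and `target_modules` straight from the adapter's saved `adapter_config.json`, so hand-building a `LoraConfig` and wrapping the base model with `get_peft_model` first was redundant. A minimal standalone sketch of the new load path, assuming a local GPT-2-style adapter directory (`ADAPTER_PATH` is illustrative; the real code uses `MODEL_PATH`):

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, GPT2TokenizerFast

ADAPTER_PATH = "./adapter"  # illustrative; stands in for MODEL_PATH

# The saved adapter config records which base checkpoint it was trained on.
adapter_config = PeftConfig.from_pretrained(ADAPTER_PATH)

tokenizer = GPT2TokenizerFast.from_pretrained(adapter_config.base_model_name_or_path)
base_model = AutoModelForCausalLM.from_pretrained(
    adapter_config.base_model_name_or_path,
    torch_dtype=torch.float32,  # float32, matching the diff above
)

# from_pretrained restores the LoRA hyperparameters from
# adapter_config.json, so no hand-built LoraConfig is needed.
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)

# Fold the LoRA deltas into the base weights and drop the PEFT wrapper,
# leaving a plain transformers model for inference.
model = model.merge_and_unload()
model.eval()
```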
@@ -169,6 +157,16 @@ def load_model():
         print("Loading base model…")
         model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

+    print("Supported quantization engines:", torch.backends.quantized.supported_engines)
+    torch.backends.quantized.engine = 'qnnpack'
+
+    print("Quantizing model…")
+    model = torch.quantization.quantize_dynamic(
+        model,
+        {torch.nn.Linear},
+        dtype=torch.qint8
+    )
+
     model.eval()

     print("Warming up (1 token)…")
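The new step applies PyTorch dynamic quantization to whichever model was loaded, which likely also explains the float16 → float32 switch earlier in the diff, since `quantize_dynamic` expects float32 weights as input. A small self-contained sketch of the same call; note that the hard-coded `'qnnpack'` engine targets ARM, while x86 builds usually ship `'fbgemm'` instead, so the sketch falls back gracefully:

```python
import torch
import torch.nn as nn

# Engine choice: qnnpack targets ARM, fbgemm targets x86. Checking
# supported_engines keeps the sketch runnable on either host.
if "qnnpack" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "qnnpack"

# Stand-in for the loaded language model; quantize_dynamic only
# touches the module types listed in the second argument.
model = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))

# Dynamic quantization converts the weights of every nn.Linear to int8
# ahead of time; activations are quantized on the fly at inference.
qmodel = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8,
)

with torch.no_grad():
    print(qmodel(torch.randn(1, 768)).shape)  # torch.Size([1, 768])
```

Since only the `nn.Linear` weights are converted, embeddings and layer norms stay in float32; for a GPT-2-class model the linear layers dominate the parameter count, so this roughly quarters the bulk of the weight memory at some cost in accuracy.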