
Commit b73985c

Test this

1 parent 22ada54 commit b73985c

File tree

1 file changed: +20, -22 lines

app.py

Lines changed: 20 additions & 22 deletions
@@ -14,7 +14,7 @@
 from huggingface_hub import snapshot_download
 from starlette.concurrency import run_in_threadpool
 from supabase import create_client, Client
-from peft import PeftConfig, LoraConfig, get_peft_model, PeftModel, TaskType
+from peft import PeftConfig, PeftModel

 MODEL_REPO_ID = os.getenv("MODEL_REPO_ID")
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -127,8 +127,8 @@ def load_model():
         print("Loading LoRA model…")

         print("Loading adapter config…")
-        adapter_cfg = PeftConfig.from_pretrained(MODEL_PATH)
-        base_model_name_or_path = adapter_cfg.base_model_name_or_path
+        adapter_config = PeftConfig.from_pretrained(MODEL_PATH)
+        base_model_name_or_path = adapter_config.base_model_name_or_path

         print("Loading tokenizer…")
         tokenizer = GPT2TokenizerFast.from_pretrained(base_model_name_or_path)
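Note: PeftConfig.from_pretrained reads the adapter's adapter_config.json and exposes the base checkpoint name, which is all the loader needs here. A minimal sketch, assuming a hypothetical local adapter directory:

    from peft import PeftConfig

    # Hypothetical adapter directory containing adapter_config.json
    adapter_config = PeftConfig.from_pretrained("./my-lora-adapter")
    print(adapter_config.base_model_name_or_path)  # e.g. "gpt2"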
@@ -137,27 +137,15 @@ def load_model():
         print("Loading base model…")
         base_model = AutoModelForCausalLM.from_pretrained(
             base_model_name_or_path,
-            torch_dtype=torch.float16
+            torch_dtype=torch.float32
         )
-
-        lora_cfg = LoraConfig(
-            task_type=TaskType.CAUSAL_LM,
-            inference_mode=True,
-            r=adapter_cfg.r,
-            lora_alpha=adapter_cfg.lora_alpha,
-            lora_dropout=adapter_cfg.lora_dropout,
-            bias=adapter_cfg.bias,
-            target_modules=adapter_cfg.target_modules,
-        )
-
-        print("Attaching LoRA adapter & merging weights…")
-        peft_wrapped = get_peft_model(base_model, lora_cfg)
-        peft_wrapped = PeftModel.from_pretrained(
-            peft_wrapped,
-            MODEL_PATH,
-            local_files_only=True,
+        model = PeftModel.from_pretrained(
+            base_model,
+            MODEL_PATH,
+            torch_dtype=torch.float32,
+            local_files_only=True
         )
-        model = peft_wrapped.merge_and_unload()
+        model = model.merge_and_unload()

     else:
         print("Loading vanilla model…")
@@ -169,6 +157,16 @@ def load_model():
         print("Loading base model…")
         model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

+    print("Supported quantization engines:", torch.backends.quantized.supported_engines)
+    torch.backends.quantized.engine = 'qnnpack'
+
+    print("Quantizing model…")
+    model = torch.quantization.quantize_dynamic(
+        model,
+        {torch.nn.Linear},
+        dtype=torch.qint8
+    )
+
     model.eval()

     print("Warming up (1 token)…")
