Skip to content

Commit e19b055

Browse files
committed
Relocate checks on get_model arguments for quantized models
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent be62d32 commit e19b055

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

scripts/inference.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -337,6 +337,13 @@
337337

338338
fused_weights = not args.unfuse_weights
339339
if args.quantization == "gptq":
340+
if fused_weights and is_aiu_backend:
341+
raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights")
342+
if default_dtype is not None:
343+
raise ValueError(
344+
"GPTQ default_dtype must be None to preserve the checkpoint data types."
345+
)
346+
340347
if "aiu" in args.device_type:
341348
linear_type = "gptq_aiu"
342349
elif args.device_type == "cpu":
@@ -370,13 +377,14 @@
370377
"group_size": group_size,
371378
"desc_act": desc_act,
372379
}
373-
# [ATTENTION] for GPTQ on AIU, we must always instantiate an unfused
374-
# model, the adapter will take care of converting key/values from
375-
# ckpt into the appropriate form for the model
376-
if fused_weights and is_aiu_backend:
377-
raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights")
378-
default_dtype = None # GPTQ dtype always comes from ckpt, can't be enforced
379380
elif args.quantization == "int8":
381+
if fused_weights and is_aiu_backend:
382+
raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights")
383+
if default_dtype is not None:
384+
raise ValueError(
385+
"INT8 default_dtype must be None to preserve the checkpoint data types."
386+
)
387+
380388
def select_int8_module(
381389
module_name: str | None = None,
382390
smoothquant: bool = True,
@@ -414,12 +422,6 @@ def select_int8_module(
414422
"weight_per_channel": args.int8_weight_per_channel,
415423
"activ_quant_type": args.int8_activ_quant_type,
416424
}
417-
if fused_weights and is_aiu_backend:
418-
raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights")
419-
if default_dtype is not None:
420-
raise ValueError(
421-
"INT8 default_dtype must be None to preserve the checkpoint data types."
422-
)
423425
else:
424426
linear_config = {"linear_type": "torch_linear"}
425427

0 commit comments

Comments
 (0)