|
337 | 337 |
|
338 | 338 | fused_weights = not args.unfuse_weights |
339 | 339 | if args.quantization == "gptq": |
| 340 | + if fused_weights and is_aiu_backend: |
| 341 | + raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights") |
| 342 | + if default_dtype is not None: |
| 343 | + raise ValueError( |
| 344 | + "GPTQ default_dtype must be None to preserve the checkpoint data types." |
| 345 | + ) |
| 346 | + |
340 | 347 | if "aiu" in args.device_type: |
341 | 348 | linear_type = "gptq_aiu" |
342 | 349 | elif args.device_type == "cpu": |
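For reference, a minimal, self-contained sketch of the guard this hunk hoists to the top of the GPTQ branch; the argparse wiring is illustrative only, and `is_aiu_backend` / `default_dtype` stand in for values the real script derives elsewhere.

```python
# Illustrative sketch: mirrors the hoisted GPTQ preconditions using the names
# from the diff (fused_weights, is_aiu_backend, default_dtype). The argparse
# setup is a stand-in for the script's real argument parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--unfuse_weights", action="store_true")
parser.add_argument("--quantization", default="gptq")
parser.add_argument("--device_type", default="aiu")
args = parser.parse_args([])  # no flags: fused weights on an AIU device

fused_weights = not args.unfuse_weights
is_aiu_backend = "aiu" in args.device_type
default_dtype = None  # assumed: whatever the script derives from its dtype flags

if args.quantization == "gptq":
    # Fail fast, before any GPTQ linear_config is assembled.
    if fused_weights and is_aiu_backend:
        raise ValueError(
            "GPTQ checkpoints on AIU must always run with --unfuse_weights"
        )
    if default_dtype is not None:
        raise ValueError(
            "GPTQ default_dtype must be None to preserve the checkpoint data types."
        )
```

Run as-is (without `--unfuse_weights`), the first check trips immediately, a configuration the old code only rejected after the GPTQ `linear_config` had already been built.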
|
370 | 377 | "group_size": group_size, |
371 | 378 | "desc_act": desc_act, |
372 | 379 | } |
373 | | - # [ATTENTION] for GPTQ on AIU, we must always instantiate an unfused |
374 | | - # model, the adapter will take care of converting key/values from |
375 | | - # ckpt into the appropriate form for the model |
376 | | - if fused_weights and is_aiu_backend: |
377 | | - raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights") |
378 | | - default_dtype = None # GPTQ dtype always comes from ckpt, can't be enforced |
379 | 380 | elif args.quantization == "int8": |
| 381 | + if fused_weights and is_aiu_backend: |
| 382 | + raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights") |
| 383 | + if default_dtype is not None: |
| 384 | + raise ValueError( |
| 385 | + "INT8 default_dtype must be None to preserve the checkpoint data types." |
| 386 | + ) |
| 387 | + |
380 | 388 | def select_int8_module( |
381 | 389 | module_name: str | None = None, |
382 | 390 | smoothquant: bool = True, |
@@ -414,12 +422,6 @@ def select_int8_module
414 | 422 | "weight_per_channel": args.int8_weight_per_channel, |
415 | 423 | "activ_quant_type": args.int8_activ_quant_type, |
416 | 424 | } |
417 | | - if fused_weights and is_aiu_backend: |
418 | | - raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights") |
419 | | - if default_dtype is not None: |
420 | | - raise ValueError( |
421 | | - "INT8 default_dtype must be None to preserve the checkpoint data types." |
422 | | - ) |
423 | 425 | else: |
424 | 426 | linear_config = {"linear_type": "torch_linear"} |
425 | 427 |
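Beyond moving the checks earlier, the GPTQ branch also changes behavior slightly: the old code silently forced `default_dtype = None` (the dtype always comes from the checkpoint), whereas the new code rejects a non-None `default_dtype` outright, and the INT8 branch now applies the same two preconditions before `select_int8_module` and the INT8 `linear_config` are set up. A small sketch of that rejection path, using a hypothetical `check_quantized_dtype` helper (not part of this PR) and a placeholder dtype value:

```python
# Hypothetical helper illustrating the default_dtype precondition now shared
# by the GPTQ and INT8 branches; "float16" is a placeholder value.
def check_quantized_dtype(quant_name: str, default_dtype) -> None:
    if default_dtype is not None:
        raise ValueError(
            f"{quant_name} default_dtype must be None to preserve "
            "the checkpoint data types."
        )

check_quantized_dtype("GPTQ", None)  # OK: dtype will be read from the checkpoint
try:
    check_quantized_dtype("GPTQ", "float16")  # old code silently reset this to None
except ValueError as err:
    print(err)
```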
|
|