diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index ce3fb0853..f31d11b8f 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -327,6 +327,12 @@ def get_model(
             device_map=device_map,
             **model_kwargs,
         )
+    elif hf_config.quantization_config.get("format", None) == "pack-quantized":
+        model = AutoModelForCausalLM.from_pretrained(
+            ckpt_path,
+            device_map="auto",
+            trust_remote_code=trust_remote_code,
+        )
     else:
         architecture = hf_config.architectures[0]
 
@@ -346,9 +352,9 @@ def get_model(
             from_config = auto_model_module._from_config
 
         with init_empty_weights():
-            # When computing the device_map, assuming half precision by default,
+            # When computing the device_map, assuming bfloat16 precision by default,
             # unless specified by the hf_config.
-            torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
+            torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
             model_kwargs2 = model_kwargs.copy()
             if auto_model_module != AutoModelForCausalLM:
                 model_kwargs2.pop("trust_remote_code", None)
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 57f0b5a89..d34f9fdbb 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -511,7 +511,12 @@ def main(args):
         ][0:1]
 
         # Generate preview before quantization
-        if is_nemotron_vl_model and tokenizer is not None:
+        if model_type == "deepseek":
+            print(
+                "Deepseek model may hit OOM during preview generation. Skipping preview generation."
+            )
+            generated_ids_before_ptq = None
+        elif is_nemotron_vl_model and tokenizer is not None:
             generated_ids_before_ptq = run_nemotron_vl_preview(
                 full_model,
                 tokenizer,
@@ -523,6 +528,7 @@ def main(args):
         else:
             # Standard generation for non-Nemotron VL models
             generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
+
     if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
         print("Applying nvfp4 quantization (MoE only) for gpt-oss")
 
@@ -542,7 +548,9 @@ def main(args):
     # Run some samples
     torch.cuda.empty_cache()
     generated_ids_after_ptq = None
-    if model_type != "llama4" and not is_nemotron_vl_model:
+    if generated_ids_before_ptq is None:
+        pass
+    elif model_type != "llama4" and not is_nemotron_vl_model:
         # Our fake quantizer may not be fully compatible with torch.compile.
         generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
     elif is_nemotron_vl_model and tokenizer is not None:
diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh
index 043b690e5..3ea85de9e 100755
--- a/examples/llm_ptq/scripts/huggingface_example.sh
+++ b/examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
 IFS=","
 for qformat in $QFORMAT; do
     case $qformat in
-        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only) ;;
     *)
-        echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
+        echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only]" >&2
         exit 1
         ;;
     esac
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index 31ac2bbbd..458c72bce 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING
 
 import torch
+from torch import Tensor
+from torch.nn.functional import linear
 
 try:
     from torch.distributed.tensor import Shard
 
@@ -501,6 +503,22 @@ def top_k(self, value):
         self.router.moe_top_k = value
 
 
+class _QuantCompressedLinear(QuantModule):
+    def _setup(self):
+        self.input_quantizer = TensorQuantizer()
+        self.weight_quantizer = TensorQuantizer()
+
+    def forward(self, input: Tensor) -> Tensor:
+        from compressed_tensors.quantization import QuantizationStatus
+
+        if self.quantization_status == QuantizationStatus.COMPRESSED:
+            weight_data = self.compressor.decompress_module(self)
+        else:
+            weight_data = self.weight
+
+        return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias)
+
+
 try:
     from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
 
@@ -576,6 +594,16 @@ def top_k(self, value):
 except ImportError:
     pass
 
+try:
+    from compressed_tensors.linear.compressed_linear import CompressedLinear
+
+    if CompressedLinear not in QuantModuleRegistry:
+        QuantModuleRegistry.register({CompressedLinear: "hf.CompressedLinear"})(
+            _QuantCompressedLinear
+        )
+except ImportError:
+    pass
+
 
 class _QuantGptOssExperts(_QuantFunctionalMixin):
     """Quantized wrapper for `transformers.GptOssExperts`.
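For context, a minimal usage sketch of the workflow this change enables: loading a compressed-tensors "pack-quantized" checkpoint and quantizing it with ModelOpt, so that the newly registered `_QuantCompressedLinear` wraps the checkpoint's `CompressedLinear` layers. This is not part of the PR; the checkpoint path, calibration prompt, and the `NVFP4_DEFAULT_CFG` config name are illustrative assumptions.

```python
# Hypothetical end-to-end sketch (not part of this PR): quantize a compressed-tensors
# "pack-quantized" checkpoint with ModelOpt. Paths and the calibration prompt are
# placeholders; NVFP4_DEFAULT_CFG is assumed to be available in modelopt.torch.quantization.
import torch
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_path = "path/to/pack-quantized-checkpoint"  # placeholder path

# Mirrors the new branch in get_model(): the pack-quantized checkpoint is loaded directly,
# which materializes compressed_tensors CompressedLinear modules inside the model.
model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)


def forward_loop(m):
    # Single-sample calibration pass for illustration; real calibration iterates a dataloader.
    inputs = tokenizer("Hello, world!", return_tensors="pt").to(m.device)
    with torch.no_grad():
        m(**inputs)


# Because CompressedLinear is now in QuantModuleRegistry, mtq.quantize replaces those layers
# with _QuantCompressedLinear, which decompresses the packed weight before the quantized
# linear call in forward().
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)
```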