10 changes: 8 additions & 2 deletions examples/llm_ptq/example_utils.py
@@ -327,6 +327,12 @@ def get_model(
device_map=device_map,
**model_kwargs,
)
elif hf_config.quantization_config.get("format", None) == "pack-quantized":
model = AutoModelForCausalLM.from_pretrained(
ckpt_path,
device_map="auto",
trust_remote_code=trust_remote_code,
)
else:
architecture = hf_config.architectures[0]
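As a point of reference for the pack-quantized branch added above, here is a hedged illustration of the quantization_config shape it keys on; the checkpoint id and the exact fields are assumptions based on compressed-tensors conventions, not something taken from this PR.

```python
# Hedged illustration only: a compressed-tensors "pack-quantized" checkpoint
# typically carries a quantization_config like the one sketched below, which is
# what hf_config.quantization_config.get("format") == "pack-quantized" detects.
from transformers import AutoConfig

ckpt_path = "some-org/w4a16-pack-quantized-model"  # placeholder checkpoint id
hf_config = AutoConfig.from_pretrained(ckpt_path)

# Assumed shape of the relevant config.json section:
#   "quantization_config": {
#       "quant_method": "compressed-tensors",
#       "format": "pack-quantized",
#       ...
#   }
if getattr(hf_config, "quantization_config", {}).get("format", None) == "pack-quantized":
    print("pack-quantized checkpoint: load via AutoModelForCausalLM with device_map='auto'")
```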

@@ -346,9 +352,9 @@ def get_model(
from_config = auto_model_module._from_config

with init_empty_weights():
# When computing the device_map, assuming half precision by default,
# When computing the device_map, assuming bfloat16 precision by default,
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
Contributor

This change is suspicious to me.

Won't this impact other models?
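
On the scope of this fallback, a minimal sketch (not from the PR): the bfloat16 default is picked up only when the config object carries no torch_dtype attribute at all, so a config with an explicit torch_dtype still wins.

```python
# Minimal sketch of the getattr fallback semantics; SimpleNamespace stands in
# for an hf_config here purely for illustration.
import torch
from types import SimpleNamespace

cfg_with_dtype = SimpleNamespace(torch_dtype=torch.float16)
cfg_without_dtype = SimpleNamespace()

assert getattr(cfg_with_dtype, "torch_dtype", torch.bfloat16) is torch.float16
assert getattr(cfg_without_dtype, "torch_dtype", torch.bfloat16) is torch.bfloat16
```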

model_kwargs2 = model_kwargs.copy()
if auto_model_module != AutoModelForCausalLM:
model_kwargs2.pop("trust_remote_code", None)
12 changes: 10 additions & 2 deletions examples/llm_ptq/hf_ptq.py
@@ -511,7 +511,12 @@ def main(args):
][0:1]

# Generate preview before quantization
if is_nemotron_vl_model and tokenizer is not None:
if model_type == "deepseek":
print(
"Deepseek model may hit OOM during preview generation. Skipping preview generation."
)
generated_ids_before_ptq = None
elif is_nemotron_vl_model and tokenizer is not None:
generated_ids_before_ptq = run_nemotron_vl_preview(
full_model,
tokenizer,
@@ -523,6 +528,7 @@
else:
# Standard generation for non-Nemotron VL models
generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)

if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
print("Applying nvfp4 quantization (MoE only) for gpt-oss")

@@ -542,7 +548,9 @@
# Run some samples
torch.cuda.empty_cache()
generated_ids_after_ptq = None
if model_type != "llama4" and not is_nemotron_vl_model:
if generated_ids_before_ptq is None:
pass
elif model_type != "llama4" and not is_nemotron_vl_model:
# Our fake quantizer may not be fully compatible with torch.compile.
generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
elif is_nemotron_vl_model and tokenizer is not None:
4 changes: 2 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only]" >&2
exit 1
;;
esac
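
A hedged usage sketch for the expanded allow-list; the flag names and model id below are assumptions about how the script is typically invoked, not taken from this diff.

```bash
# Assumed invocation style (flags and model id are illustrative, not from this diff):
# with the case statement above extended, nvfp4_mlp_only now passes the qformat check.
scripts/huggingface_example.sh --model openai/gpt-oss-20b --quant nvfp4_mlp_only
```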
28 changes: 28 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -22,6 +22,8 @@
from typing import TYPE_CHECKING

import torch
from torch import Tensor
from torch.nn.functional import linear

try:
from torch.distributed.tensor import Shard
@@ -501,6 +503,22 @@ def top_k(self, value):
self.router.moe_top_k = value


class _QuantCompressedLinear(QuantModule):
def _setup(self):
self.input_quantizer = TensorQuantizer()
self.weight_quantizer = TensorQuantizer()

def forward(self, input: Tensor) -> Tensor:
from compressed_tensors.quantization import QuantizationStatus

if self.quantization_status == QuantizationStatus.COMPRESSED:
weight_data = self.compressor.decompress_module(self)
else:
weight_data = self.weight

return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias)


try:
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe

@@ -576,6 +594,16 @@ def top_k(self, value):
except ImportError:
pass

try:
from compressed_tensors.linear.compressed_linear import CompressedLinear
Contributor

Should we add compressed-tensors as an optional dependency?

Collaborator Author

@kevalmorabia97 @realAsma what do you think?

Collaborator

If a user is quantizing a model with CompressedLinear, wouldn't they already have compressed-tensors pre-installed? What benefit do we have by having it added as an optional dependency?

Collaborator

compressed-tensors' main dependencies are torch and transformers, so it should be pretty lightweight to add as a dependency; fine if you want to add it. But if it's not commonly used by customers, perhaps we can skip it.

Contributor

Can we move this to a separate file modelopt/torch/quantization/plugins/compressed_tensor.py?

Contributor

> If a user is quantizing a model with CompressedLinear, wouldn't they already have compressed-tensors pre-installed?

This is a good point. +1
Are we planning to have any unit tests for the compressed-tensors integration?

Collaborator Author

not right now

Collaborator Author

> Can we move this to a separate file modelopt/torch/quantization/plugins/compressed_tensor.py?

How strongly do you feel about it? Right now I feel this still falls under the HF plugins, as it's part of the HF invocation.


if CompressedLinear not in QuantModuleRegistry:
QuantModuleRegistry.register({CompressedLinear: "hf.CompressedLinear"})(
_QuantCompressedLinear
)
except ImportError:
pass


class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.
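
Finally, a hedged end-to-end sketch of how the CompressedLinear registration above might be exercised; the checkpoint id, calibration text, and choice of FP8_DEFAULT_CFG are assumptions for illustration, not part of the PR.

```python
# Hedged sketch: with CompressedLinear registered in QuantModuleRegistry, a
# compressed-tensors checkpoint can flow through the usual mtq.quantize() path.
# Checkpoint id, calibration data, and config choice are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.quantization as mtq

ckpt = "some-org/w4a16-pack-quantized-model"  # placeholder, not from this PR
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(ckpt)

def forward_loop(m):
    # A real calibration loop would iterate over a representative dataset.
    batch = tokenizer("Hello world", return_tensors="pt").to(m.device)
    with torch.no_grad():
        m(**batch)

model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```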