30 changes: 29 additions & 1 deletion examples/llm_ptq/example_utils.py
@@ -33,7 +33,7 @@
snapshot_download = None

import modelopt.torch.quantization as mtq
from modelopt.torch.utils.image_processor import MllamaImageProcessor
from modelopt.torch.utils.image_processor import MllamaImageProcessor, Qwen3OmniImageProcessor

SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]

@@ -174,6 +174,21 @@ def build_quant_cfg(
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}

# Qwen3-specific quantizer disabling patterns (thinker.model.layers only).
# "qkvo_disabled" is matched explicitly because it does not contain the substring "qkv_disabled".
if "qkv_disabled" in qformat or "qkvo_disabled" in qformat:
quant_cfg = copy.deepcopy(quant_cfg)  # Don't modify the global config
for proj in ["q_proj", "k_proj", "v_proj"]:
quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
"enable": False
}
if "qkvo_disabled" in qformat:
# Additionally disable the o_proj quantizer
quant_cfg["quant_cfg"]["*thinker.model.layers.*.self_attn.o_proj*"] = {
"enable": False
}

return quant_cfg
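For reference, a minimal standalone sketch of how such wildcard "enable": False entries are consumed (assuming the standard ModelOpt flow; the commented mtq.quantize call and the forward-loop name are placeholders, not part of this PR):

import copy

import modelopt.torch.quantization as mtq

# Start from the default NVFP4 recipe and disable the thinker attention projections,
# mirroring what build_quant_cfg does for the *_disabled qformats above.
quant_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
for proj in ("q_proj", "k_proj", "v_proj", "o_proj"):
    quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {"enable": False}

# model = mtq.quantize(model, quant_cfg, forward_loop=calib_forward_loop)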


@@ -240,6 +255,19 @@ def get_processor(
)

return MllamaImageProcessor(processor, device)
elif model_type == "qwen3omni":
processor = AutoProcessor.from_pretrained(
ckpt_path,
padding_side="left",
**model_kwargs,
)
if processor.tokenizer.pad_token is None:
processor.tokenizer.pad_token = processor.tokenizer.eos_token
assert processor.tokenizer.pad_token is not None, (
f"Pad token for {ckpt_path} cannot be set!"
)

return Qwen3OmniImageProcessor(processor, device)


def get_dtype(dtype):
155 changes: 146 additions & 9 deletions examples/llm_ptq/hf_ptq.py
@@ -14,6 +14,8 @@
# limitations under the License.

import argparse
import contextlib
import os
import random
import time
import warnings
@@ -57,12 +59,26 @@
create_forward_loop,
get_dataset_dataloader,
get_max_batch_size,
get_qwen3omni_text_dataloader,
get_supported_datasets,
)
from modelopt.torch.utils.image_processor import MllamaImageProcessor
from modelopt.torch.utils.image_processor import (
BaseImageProcessor,
MllamaImageProcessor,
Qwen3OmniImageProcessor,
Qwen3OmniTextProcessor,
)
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
from modelopt.torch.utils.video_dataset_utils import (
Qwen3OmniVideoProcessor,
get_supported_video_datasets,
get_video_dataset_dataloader,
)
from modelopt.torch.utils.vlm_dataset_utils import (
get_supported_vlm_datasets,
get_vlm_dataset_dataloader,
)

RAND_SEED = 1234

@@ -80,6 +96,9 @@
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"qwen3_nvfp4_qkv_disabled": mtq.NVFP4_DEFAULT_CFG,
"qwen3_nvfp4_qkvo_disabled": mtq.NVFP4_DEFAULT_CFG,
"qwen3_first_and_last_n_disabled": mtq.NVFP4_DEFAULT_CFG,
}

KV_QUANT_CFG_CHOICES = {
@@ -280,11 +299,35 @@ def main(args):
use_seq_device_map=args.use_seq_device_map,
attn_implementation=args.attn_implementation,
)

quant_cfg = QUANT_CFG_CHOICES[args.qformat]
# Qwen3-specific quantizer disabling patterns (thinker.model.layers only).
# "qkvo_disabled" is matched explicitly because it does not contain the substring "qkv_disabled".
if "qkv_disabled" in args.qformat or "qkvo_disabled" in args.qformat:
# Disable q_proj, k_proj and v_proj quantizers
for proj in ["q_proj", "k_proj", "v_proj"]:
quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
"enable": False
}
if "qkvo_disabled" in args.qformat:
# Additionally disable the o_proj quantizer
quant_cfg["quant_cfg"]["*thinker.model.layers.*.self_attn.o_proj*"] = {
"enable": False
}
if "first_and_last_n_disabled" in args.qformat:
# Disable both first N and last N layers
total_layers = 48
n_layers_to_disable = 4
for i in range(n_layers_to_disable):
quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
for i in range(total_layers - n_layers_to_disable, total_layers):
quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
else:
assert args.qformat in QUANT_CFG_CHOICES, (
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]

if args.kv_cache_qformat != "none":
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
@@ -305,10 +348,19 @@ def main(args):
model_is_already_quantized = is_quantized(model)

model_type = get_model_type(model)
if model_type == "qwen3omni" and os.environ.get("DISABLE_TALKER", "0") == "1":
print("Disabling talker for Qwen3Omni model")
model.disable_talker()

Collaborator: I think we probably need to find a better way for configurations like this.
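One possible direction, sketched purely for discussion (the flag name is an assumption, not part of this PR), is an explicit CLI option instead of the environment variable:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable_talker",
    action="store_true",
    help="Disable the Qwen3Omni talker submodule before quantization",
)
args = parser.parse_args()

# Replacing the environment-variable check:
# if model_type == "qwen3omni" and args.disable_talker:
#     model.disable_talker()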

device = model.device
if hasattr(model, "model"):
device = model.model.device
# For multi-GPU models with device_map="auto", model.device may return 'meta' or 'cpu'
# since parameters are distributed. Force cuda:0 for input tensors.
if device is None or str(device) in ("meta", "cpu"):
device = "cuda"
print(f"Overriding device to {device}")

processor = None
tokenizer = None

@@ -317,7 +369,7 @@
# Detect if this is a Nemotron VL model using architecture-based detection
is_nemotron_vl_model = is_nemotron_vl(full_model)

if model_type == "mllama":
if model_type in ["mllama", "qwen3omni"]:
processor = get_processor(
args.pyt_ckpt_path,
model_type,
@@ -453,6 +505,56 @@ def main(args):
batch_size=args.batch_size,
num_samples=args.calib_size[0],
)
elif model_type == "qwen3omni":
assert len(args.calib_size) == 1, (
"qwen3omni only supports one dataset for calibration, can extend this in the future"
)

Collaborator: for this part, I think we may want to host it in a model specific python file/module, e.g. llm_ptq/models/qwen3omni.py.

@shengliangxu WDYT?

Contributor: We do not need to do it for now; I'll come up with a full design doc and then we can convert the whole repo afterwards. Even if we separate things out now, we may still refactor these anyway.
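For illustration only, a hypothetical layout for such a module (names are assumptions, not part of this PR):

# llm_ptq/models/qwen3omni.py (hypothetical)
def build_calib_dataloader(args, processor, model, device):
    """Select the video / VLM / text calibration dataloader for Qwen3-Omni."""
    ...


def generate_preview(model, batch, max_new_tokens=100):
    """Run generate() and unwrap the (text_ids, audio) tuple it may return."""
    ...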
assert processor is not None, "The processor must be set for qwen3omni model."
dataset_name = args.dataset[0] if args.dataset else "cnn_dailymail"
# Check if using video dataset (e.g., finevideo)
if dataset_name in get_supported_video_datasets():
video_processor = Qwen3OmniVideoProcessor(
processor.tokenizer if hasattr(processor, "tokenizer") else processor,
device=device,
dtype=model.dtype,
use_audio_in_video=True,
)
calib_dataloader = get_video_dataset_dataloader(
dataset_name=dataset_name,
processor=video_processor,
batch_size=args.batch_size,
num_samples=args.calib_size[0],
)
elif dataset_name in get_supported_vlm_datasets():
assert processor is not None and isinstance(processor, Qwen3OmniImageProcessor), (
"The Qwen3OmniImageProcessor must be set."
)
# Set the dtype for proper tensor conversion in collate_function
processor.dtype = model.dtype
calib_dataloader = get_vlm_dataset_dataloader(
dataset_name=dataset_name,
processor=processor,
batch_size=args.batch_size,
num_samples=args.calib_size[0],
)
else:
# Text-only datasets (e.g., cnn_dailymail)
# Use Qwen3OmniTextProcessor to apply proper conversation template
# See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking
text_processor = Qwen3OmniTextProcessor(
processor=processor.tokenizer,  # Pass the underlying tokenizer
device=device,
dtype=model.dtype,
)
calib_dataloader = get_qwen3omni_text_dataloader(
dataset_name=dataset_name,
processor=text_processor,
batch_size=args.batch_size,
num_samples=args.calib_size[0],
max_sample_length=args.calib_seq,
device=device,
)
print(f"Selected dataset for calibration: {dataset_name}")
elif model_type == "whisper":
assert processor is not None and isinstance(processor, WhisperProcessor), (
"The AutoProcessor must be set."
@@ -506,9 +608,10 @@

if not model_is_already_quantized or calibration_only:
# Only run single sample for preview
input_ids = next(iter(calib_dataloader))[
"input_features" if model_type == "whisper" else "input_ids"
][0:1]
calib_batch = next(iter(calib_dataloader))
input_ids = calib_batch["input_features" if model_type == "whisper" else "input_ids"][
0:1
]

# Generate preview before quantization
if is_nemotron_vl_model and tokenizer is not None:
@@ -520,6 +623,17 @@
"before quantization",
allow_fallback=True,
)
elif model_type == "qwen3omni":
# Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
# Pass full batch with all multimodal inputs
result = full_model.generate(**calib_batch, max_new_tokens=100)
if isinstance(result, tuple):
text_ids, _ = result
generated_ids_before_ptq = (
text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
)
else:
generated_ids_before_ptq = result
else:
# Standard generation for non-Nemotron VL models
generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
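The same tuple-unwrapping logic reappears after quantization below; a small helper, sketched here with an assumed name, would avoid the duplication:

def unwrap_qwen3omni_generation(result):
    # Qwen3Omni generate() may return a (text_ids, audio) tuple, and text_ids
    # may be a GenerateOutput object carrying .sequences.
    if isinstance(result, tuple):
        result = result[0]
    return result.sequences if hasattr(result, "sequences") else result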
@@ -537,12 +651,24 @@ def main(args):
language_model_lineage[-2].language_model = model

if args.verbose:
mtq.print_quant_summary(full_model)
with open("./quant_summary.txt", "w") as f, contextlib.redirect_stdout(f):
mtq.print_quant_summary(full_model)

# Run some samples
torch.cuda.empty_cache()
generated_ids_after_ptq = None
if model_type != "llama4" and not is_nemotron_vl_model:
if model_type == "qwen3omni":
# Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
# Pass full batch with all multimodal inputs
result = full_model.generate(**calib_batch, max_new_tokens=100)
if isinstance(result, tuple):
text_ids, _ = result
generated_ids_after_ptq = (
text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
)
else:
generated_ids_after_ptq = result
elif model_type != "llama4" and not is_nemotron_vl_model:
# Our fake quantizer may not be fully compatible with torch.compile.
generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
elif is_nemotron_vl_model and tokenizer is not None:
@@ -560,7 +686,8 @@
)

def input_decode(input_ids):
if processor is not None and isinstance(processor, MllamaImageProcessor):
# BaseImageProcessor covers MllamaImageProcessor and Qwen3OmniImageProcessor
if processor is not None and isinstance(processor, BaseImageProcessor):
return processor.tokenizer.batch_decode(input_ids)
elif processor is not None and isinstance(processor, WhisperProcessor):
return first_text
@@ -579,6 +706,12 @@ def output_decode(generated_ids, input_shape):
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
elif processor is not None and isinstance(processor, MllamaImageProcessor):
return processor.tokenizer.batch_decode(generated_ids[:, input_shape:])
elif processor is not None and isinstance(processor, Qwen3OmniImageProcessor):
return processor.tokenizer.batch_decode(
generated_ids[:, input_shape:],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
elif tokenizer is not None:
return tokenizer.batch_decode(generated_ids[:, input_shape:])
else:
@@ -612,6 +745,10 @@ def output_decode(generated_ids, input_shape):
assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton"
print(f"qformat: {args.qformat}. No quantization applied, export {device} model")

if model_type == "qwen3omni":
print("Export of Qwen3Omni model is not supported yet")
return

with torch.inference_mode():
if model_type is None:
print(f"Unknown model type {type(model).__name__}. Continue exporting...")
1 change: 1 addition & 0 deletions modelopt/torch/export/model_utils.py
@@ -29,6 +29,7 @@
"MPT": "mpt",
"Bloom": "bloom",
"ChatGLM": "chatglm",
"Qwen3OmniMoeForConditionalGeneration": "qwen3omni",
"QWen": "qwen",
"RecurrentGemma": "recurrentgemma",
"Gemma3": "gemma3",
18 changes: 18 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -576,6 +576,24 @@ def top_k(self, value):
except ImportError:
pass

# Uncomment to forward tokens to all MoE experts for full calibration.
# try:
# from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
# Qwen3OmniMoeTalkerTextSparseMoeBlock,
# Qwen3OmniMoeThinkerTextSparseMoeBlock,
# )

# if Qwen3OmniMoeTalkerTextSparseMoeBlock not in QuantModuleRegistry:
# QuantModuleRegistry.register(
# {Qwen3OmniMoeTalkerTextSparseMoeBlock: "hf.Qwen3OmniMoeTalkerTextSparseMoeBlock"}
# )(_QuantSparseMoe)
# if Qwen3OmniMoeThinkerTextSparseMoeBlock not in QuantModuleRegistry:
# QuantModuleRegistry.register(
# {Qwen3OmniMoeThinkerTextSparseMoeBlock: "hf.Qwen3OmniMoeThinkerTextSparseMoeBlock"}
# )(_QuantSparseMoe)
# except ImportError:
# pass


class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.
1 change: 1 addition & 0 deletions modelopt/torch/utils/__init__.py
@@ -26,4 +26,5 @@
from .perf import *
from .regex import *
from .tensor import *
from .video_dataset_utils import *
from .vlm_dataset_utils import *