Add support for Qwen3-Omni-30B-A3B-Thinking #677
base: main
@@ -14,6 +14,8 @@
# limitations under the License.

import argparse
import contextlib
import os
import random
import time
import warnings
@@ -57,12 +59,26 @@
    create_forward_loop,
    get_dataset_dataloader,
    get_max_batch_size,
    get_qwen3omni_text_dataloader,
    get_supported_datasets,
)
from modelopt.torch.utils.image_processor import MllamaImageProcessor
from modelopt.torch.utils.image_processor import (
    BaseImageProcessor,
    MllamaImageProcessor,
    Qwen3OmniImageProcessor,
    Qwen3OmniTextProcessor,
)
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
from modelopt.torch.utils.video_dataset_utils import (
    Qwen3OmniVideoProcessor,
    get_supported_video_datasets,
    get_video_dataset_dataloader,
)
from modelopt.torch.utils.vlm_dataset_utils import (
    get_supported_vlm_datasets,
    get_vlm_dataset_dataloader,
)

RAND_SEED = 1234
@@ -80,6 +96,9 @@
    "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
    "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
    "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
    "qwen3_nvfp4_qkv_disabled": mtq.NVFP4_DEFAULT_CFG,
    "qwen3_nvfp4_qkvo_disabled": mtq.NVFP4_DEFAULT_CFG,
    "qwen3_first_and_last_n_disabled": mtq.NVFP4_DEFAULT_CFG,
}

KV_QUANT_CFG_CHOICES = {
@@ -280,11 +299,35 @@ def main(args):
        use_seq_device_map=args.use_seq_device_map,
        attn_implementation=args.attn_implementation,
    )

    quant_cfg = QUANT_CFG_CHOICES[args.qformat]
    # Qwen3 specific quantizer disabling patterns (thinker.model.layers only)
    if "qkv_disabled" in args.qformat:
        # Disable q_proj, k_proj, v_proj quantizers
        for proj in ["q_proj", "k_proj", "v_proj"]:
            quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
                "enable": False
            }
    if "qkvo_disabled" in args.qformat:
        # Disable q_proj, k_proj, v_proj, o_proj quantizers
        for proj in ["o_proj"]:
            quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
                "enable": False
            }
    if "first_and_last_n_disabled" in args.qformat:
        # Disable both first N and last N layers
        total_layers = 48
        n_layers_to_disable = 4
        for i in range(n_layers_to_disable):
            quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
        for i in range(total_layers - n_layers_to_disable, total_layers):
            quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
else:
    assert args.qformat in QUANT_CFG_CHOICES, (
        f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
    )
    quant_cfg = QUANT_CFG_CHOICES[args.qformat]

if args.kv_cache_qformat != "none":
    quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
        quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
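Two details worth flagging in the hunk above: the three new qformat names all alias the same `mtq.NVFP4_DEFAULT_CFG` object, so the in-place edits to `quant_cfg["quant_cfg"]` mutate that shared dict unless a copy is taken elsewhere; and `"qkv_disabled"` is not a substring of `"qkvo_disabled"`, so the qkvo variant as written appears to disable only `o_proj` despite its comment naming q/k/v/o. Below is a minimal sketch of the same wildcard-disabling pattern applied to a private copy of the config; it assumes only the conventions already visible in the hunk plus the standard `modelopt.torch.quantization` import.

```python
import copy

import modelopt.torch.quantization as mtq

# Work on a copy so the shared NVFP4_DEFAULT_CFG is left untouched.
quant_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)

# Disable q/k/v (and optionally o) projection quantizers in the thinker decoder layers.
for proj in ("q_proj", "k_proj", "v_proj", "o_proj"):
    quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {"enable": False}

# Skip the first and last N decoder layers entirely (48 layers total in the thinker).
total_layers, n_layers_to_disable = 48, 4
skip = list(range(n_layers_to_disable)) + list(range(total_layers - n_layers_to_disable, total_layers))
for i in skip:
    quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
```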
@@ -305,10 +348,19 @@ def main(args):
model_is_already_quantized = is_quantized(model)

model_type = get_model_type(model)
if model_type == "qwen3omni" and os.environ.get("DISABLE_TALKER", "0") == "1":
    print("Disabling talker for Qwen3Omni model")
    model.disable_talker()

device = model.device
if hasattr(model, "model"):
    device = model.model.device
# For multi-GPU models with device_map="auto", model.device may return 'meta' or 'cpu'
# since parameters are distributed. Force cuda:0 for input tensors.
if device is None or str(device) in ("meta", "cpu"):
    device = "cuda"
    print(f"Overriding device to {device}")

processor = None
tokenizer = None
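As a side note, the device probing above could be captured in a tiny helper; this is only a sketch of a possible refactor, and the helper name is hypothetical, not part of the PR.

```python
import torch


def resolve_input_device(model):
    """Hypothetical helper mirroring the inline logic above for picking an input device."""
    device = model.model.device if hasattr(model, "model") else model.device
    # With device_map="auto" the reported device can be 'meta' or 'cpu' even though the
    # weights are spread across GPUs, so fall back to CUDA for the calibration inputs.
    if device is None or str(device) in ("meta", "cpu"):
        device = torch.device("cuda")
    return device
```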
@@ -317,7 +369,7 @@
# Detect if this is a Nemotron VL model using architecture-based detection
is_nemotron_vl_model = is_nemotron_vl(full_model)

if model_type == "mllama":
if model_type in ["mllama", "qwen3omni"]:
    processor = get_processor(
        args.pyt_ckpt_path,
        model_type,

@@ -453,6 +505,56 @@
    batch_size=args.batch_size,
    num_samples=args.calib_size[0],
)
elif model_type == "qwen3omni":
    assert len(args.calib_size) == 1, (
Collaborator:
For this part, I think we may want to host it in a model-specific Python file/module, e.g. llm_ptq/models/qwen3omni.py. @shengliangxu WDYT?

Contributor:
We do not need to do that for now; I'll come up with a full design doc, and then we can convert the whole repo afterwards. Even if we separate things out now, we may still end up refactoring these anyway.
        "qwen3omni only supports one dataset for calibration, can extend this in the future"
    )
    assert processor is not None, "The processor must be set for qwen3omni model."
    dataset_name = args.dataset[0] if args.dataset else "cnn_dailymail"
    # Check if using video dataset (e.g., finevideo)
    if dataset_name in get_supported_video_datasets():
        video_processor = Qwen3OmniVideoProcessor(
            processor.tokenizer if hasattr(processor, "tokenizer") else processor,
            device=device,
            dtype=model.dtype,
            use_audio_in_video=True,
        )
        calib_dataloader = get_video_dataset_dataloader(
            dataset_name=dataset_name,
            processor=video_processor,
            batch_size=args.batch_size,
            num_samples=args.calib_size[0],
        )
    elif dataset_name in get_supported_vlm_datasets():
        assert processor is not None and isinstance(processor, Qwen3OmniImageProcessor), (
            "The Qwen3OmniImageProcessor must be set."
        )
        # Set the dtype for proper tensor conversion in collate_function
        processor.dtype = model.dtype
        calib_dataloader = get_vlm_dataset_dataloader(
            dataset_name=dataset_name,
            processor=processor,
            batch_size=args.batch_size,
            num_samples=args.calib_size[0],
        )
    else:
        # Text-only datasets (e.g., cnn_dailymail)
        # Use Qwen3OmniTextProcessor to apply proper conversation template
        # See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking
        text_processor = Qwen3OmniTextProcessor(
            processor=processor.tokenizer,  # Pass the underlying HF processor
            device=device,
            dtype=model.dtype,
        )
        calib_dataloader = get_qwen3omni_text_dataloader(
            dataset_name=dataset_name,
            processor=text_processor,
            batch_size=args.batch_size,
            num_samples=args.calib_size[0],
            max_sample_length=args.calib_seq,
            device=device,
        )
    print(f"Selected dataset for calibration: {dataset_name}")
elif model_type == "whisper":
    assert processor is not None and isinstance(processor, WhisperProcessor), (
        "The AutoProcessor must be set."
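Relating back to the collaborator comment earlier in this hunk: if the qwen3omni branching were moved into a model-specific module, the dataset routing could look roughly like the sketch below. Everything here (module path, function name, signature) is an assumption used for illustration only, not part of this PR.

```python
# Hypothetical llm_ptq/models/qwen3omni.py -- an illustration of the reviewer's suggestion.
from modelopt.torch.utils.image_processor import Qwen3OmniImageProcessor, Qwen3OmniTextProcessor
from modelopt.torch.utils.video_dataset_utils import (
    Qwen3OmniVideoProcessor,
    get_supported_video_datasets,
    get_video_dataset_dataloader,
)
from modelopt.torch.utils.vlm_dataset_utils import (
    get_supported_vlm_datasets,
    get_vlm_dataset_dataloader,
)


def build_qwen3omni_calib_dataloader(
    dataset_name, processor, model, batch_size, num_samples, device, text_dataloader_fn
):
    """Route one calibration dataset to the matching Qwen3Omni dataloader.

    text_dataloader_fn is expected to be get_qwen3omni_text_dataloader; its source module
    name is truncated in the diff above, so it is passed in here rather than imported.
    """
    if dataset_name in get_supported_video_datasets():
        video_processor = Qwen3OmniVideoProcessor(
            processor.tokenizer if hasattr(processor, "tokenizer") else processor,
            device=device,
            dtype=model.dtype,
            use_audio_in_video=True,
        )
        return get_video_dataset_dataloader(
            dataset_name=dataset_name,
            processor=video_processor,
            batch_size=batch_size,
            num_samples=num_samples,
        )
    if dataset_name in get_supported_vlm_datasets():
        assert isinstance(processor, Qwen3OmniImageProcessor), "The Qwen3OmniImageProcessor must be set."
        processor.dtype = model.dtype  # needed by the collate function
        return get_vlm_dataset_dataloader(
            dataset_name=dataset_name,
            processor=processor,
            batch_size=batch_size,
            num_samples=num_samples,
        )
    # Text-only datasets such as cnn_dailymail.
    text_processor = Qwen3OmniTextProcessor(processor=processor.tokenizer, device=device, dtype=model.dtype)
    return text_dataloader_fn(
        dataset_name=dataset_name,
        processor=text_processor,
        batch_size=batch_size,
        num_samples=num_samples,
        device=device,
    )
```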
@@ -506,9 +608,10 @@ def main(args):
if not model_is_already_quantized or calibration_only:
    # Only run single sample for preview
    input_ids = next(iter(calib_dataloader))[
        "input_features" if model_type == "whisper" else "input_ids"
    ][0:1]
    calib_batch = next(iter(calib_dataloader))
    input_ids = calib_batch["input_features" if model_type == "whisper" else "input_ids"][
        0:1
    ]

    # Generate preview before quantization
    if is_nemotron_vl_model and tokenizer is not None:
@@ -520,6 +623,17 @@
        "before quantization",
        allow_fallback=True,
    )
elif model_type == "qwen3omni":
    # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
    # Pass full batch with all multimodal inputs
    result = full_model.generate(**calib_batch, max_new_tokens=100)
    if isinstance(result, tuple):
        text_ids, _ = result
        generated_ids_before_ptq = (
            text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
        )
    else:
        generated_ids_before_ptq = result
else:
    # Standard generation for non-Nemotron VL models
    generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
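Since the (text_ids, audio) unwrapping shown here is repeated verbatim for the post-quantization preview further down, a small shared helper could remove the duplication. This is only a sketch; the helper name and placement are hypothetical.

```python
def unwrap_qwen3omni_generate(result):
    """Return plain token ids from Qwen3Omni's generate() output.

    Qwen3Omni may return a (text_ids, audio) tuple, where text_ids can be a
    GenerateOutput carrying a .sequences attribute; other return shapes pass through.
    """
    if isinstance(result, tuple):
        text_ids, _audio = result
        return text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
    return result
```

Both preview sites would then reduce to `generated_ids = unwrap_qwen3omni_generate(full_model.generate(**calib_batch, max_new_tokens=100))`.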
@@ -537,12 +651,24 @@
        language_model_lineage[-2].language_model = model

if args.verbose:
    mtq.print_quant_summary(full_model)
    with open("./quant_summary.txt", "w") as f, contextlib.redirect_stdout(f):
        mtq.print_quant_summary(full_model)

# Run some samples
torch.cuda.empty_cache()
generated_ids_after_ptq = None
if model_type != "llama4" and not is_nemotron_vl_model:
if model_type == "qwen3omni":
    # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
    # Pass full batch with all multimodal inputs
    result = full_model.generate(**calib_batch, max_new_tokens=100)
    if isinstance(result, tuple):
        text_ids, _ = result
        generated_ids_after_ptq = (
            text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
        )
    else:
        generated_ids_after_ptq = result
elif model_type != "llama4" and not is_nemotron_vl_model:
    # Our fake quantizer may not be fully compatible with torch.compile.
    generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
elif is_nemotron_vl_model and tokenizer is not None:
@@ -560,7 +686,8 @@
)

def input_decode(input_ids):
    if processor is not None and isinstance(processor, MllamaImageProcessor):
    # BaseImageProcessor covers MllamaImageProcessor and Qwen3OmniImageProcessor
    if processor is not None and isinstance(processor, BaseImageProcessor):
        return processor.tokenizer.batch_decode(input_ids)
    elif processor is not None and isinstance(processor, WhisperProcessor):
        return first_text
@@ -579,6 +706,12 @@ def output_decode(generated_ids, input_shape):
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
elif processor is not None and isinstance(processor, MllamaImageProcessor):
    return processor.tokenizer.batch_decode(generated_ids[:, input_shape:])
elif processor is not None and isinstance(processor, Qwen3OmniImageProcessor):
    return processor.tokenizer.batch_decode(
        generated_ids[:, input_shape:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
elif tokenizer is not None:
    return tokenizer.batch_decode(generated_ids[:, input_shape:])
else:
@@ -612,6 +745,10 @@ def output_decode(generated_ids, input_shape):
assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton"
print(f"qformat: {args.qformat}. No quantization applied, export {device} model")

if model_type == "qwen3omni":
    print("Export of Qwen3Omni model is not supported yet")
    return

with torch.inference_mode():
    if model_type is None:
        print(f"Unknown model type {type(model).__name__}. Continue exporting...")
Review comment:
I think we probably need to find a better way for configurations like this.
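One possible shape for that "better way", sketched purely as an assumption (the flag name and wiring are not part of this PR): keep the qformat list unchanged and accept explicit disable patterns on the command line, which the script then folds into the chosen quant config.

```python
# Hypothetical alternative to encoding disable patterns in the qformat string.
import argparse
import copy

import modelopt.torch.quantization as mtq

parser = argparse.ArgumentParser()
parser.add_argument("--qformat", default="nvfp4")
parser.add_argument(
    "--quant-disable-pattern",
    action="append",
    default=[],
    help="Glob-style module pattern whose quantizers are disabled, e.g. "
    "'*thinker.model.layers.*.self_attn.q_proj*'. May be given multiple times.",
)
args = parser.parse_args()

# The real script would look the base config up from args.qformat; a copy keeps the
# module-level default config untouched.
quant_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
for pattern in args.quant_disable_pattern:
    quant_cfg["quant_cfg"][pattern] = {"enable": False}
```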