From 3a3af92cb0365db40989e6e56b134ccb4d183e01 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Thu, 11 Dec 2025 06:02:06 +0000 Subject: [PATCH 01/10] Add support for Qwen3-Omni-30B-A3B-Thinking --- examples/llm_ptq/example_utils.py | 15 ++- examples/llm_ptq/hf_ptq.py | 66 +++++++++++- modelopt/torch/export/model_utils.py | 1 + modelopt/torch/utils/dataset_utils.py | 47 +++++++-- modelopt/torch/utils/image_processor.py | 116 ++++++++++++++++++++++ modelopt/torch/utils/vlm_dataset_utils.py | 4 +- 6 files changed, 231 insertions(+), 18 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index ce3fb0853..dd1958ac9 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -33,7 +33,7 @@ snapshot_download = None import modelopt.torch.quantization as mtq -from modelopt.torch.utils.image_processor import MllamaImageProcessor +from modelopt.torch.utils.image_processor import MllamaImageProcessor, Qwen3OmniImageProcessor SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] @@ -240,6 +240,19 @@ def get_processor( ) return MllamaImageProcessor(processor, device) + elif model_type == "qwen3omni": + processor = AutoProcessor.from_pretrained( + ckpt_path, + padding_side="left", + **model_kwargs, + ) + if processor.tokenizer.pad_token is None: + processor.tokenizer.pad_token = processor.tokenizer.eos_token + assert processor.tokenizer.pad_token is not None, ( + f"Pad token for {ckpt_path} cannot be set!" + ) + + return Qwen3OmniImageProcessor(processor, device) def get_dtype(dtype): diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 57f0b5a89..40e8ee42e 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -59,7 +59,11 @@ get_max_batch_size, get_supported_datasets, ) -from modelopt.torch.utils.image_processor import MllamaImageProcessor +from modelopt.torch.utils.image_processor import ( + BaseImageProcessor, + MllamaImageProcessor, + Qwen3OmniImageProcessor, +) from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader @@ -305,10 +309,16 @@ def main(args): model_is_already_quantized = is_quantized(model) model_type = get_model_type(model) + if model_type == "qwen3omni": + model.disable_talker() device = model.device if hasattr(model, "model"): device = model.model.device + # For multi-GPU models with device_map="auto", model.device may return 'meta' or 'cpu' + # since parameters are distributed. Force cuda:0 for input tensors. + if device is None or str(device) in ("meta", "cpu"): + device = "cuda" processor = None tokenizer = None @@ -317,7 +327,7 @@ def main(args): # Detect if this is a Nemotron VL model using architecture-based detection is_nemotron_vl_model = is_nemotron_vl(full_model) - if model_type == "mllama": + if model_type in ["mllama", "qwen3omni"]: processor = get_processor( args.pyt_ckpt_path, model_type, @@ -453,6 +463,19 @@ def main(args): batch_size=args.batch_size, num_samples=args.calib_size[0], ) + elif model_type == "qwen3omni": + assert processor is not None and isinstance(processor, Qwen3OmniImageProcessor), ( + "The Qwen3OmniImageProcessor must be set." 
+ ) + assert len(args.calib_size) == 1, ( + "qwen3omni only supports one dataset for calibration, can extend this in the future" + ) + calib_dataloader = get_vlm_dataset_dataloader( + dataset_name=args.dataset[0] if args.dataset else "scienceqa", + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + ) elif model_type == "whisper": assert processor is not None and isinstance(processor, WhisperProcessor), ( "The AutoProcessor must be set." @@ -520,6 +543,16 @@ def main(args): "before quantization", allow_fallback=True, ) + elif model_type == "qwen3omni": + # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences + result = full_model.generate(input_ids, max_new_tokens=100) + if isinstance(result, tuple): + text_ids, _ = result + generated_ids_before_ptq = ( + text_ids.sequences if hasattr(text_ids, "sequences") else text_ids + ) + else: + generated_ids_before_ptq = result else: # Standard generation for non-Nemotron VL models generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) @@ -536,13 +569,29 @@ def main(args): print("Updating full_model with quantized language_model...") language_model_lineage[-2].language_model = model + # if args.verbose: + # mtq.print_quant_summary(full_model) + + import contextlib + if args.verbose: - mtq.print_quant_summary(full_model) + with open("./quant_summary.txt", "w") as f, contextlib.redirect_stdout(f): + mtq.print_quant_summary(full_model) # Run some samples torch.cuda.empty_cache() generated_ids_after_ptq = None - if model_type != "llama4" and not is_nemotron_vl_model: + if model_type == "qwen3omni": + # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences + result = full_model.generate(input_ids, max_new_tokens=100) + if isinstance(result, tuple): + text_ids, _ = result + generated_ids_after_ptq = ( + text_ids.sequences if hasattr(text_ids, "sequences") else text_ids + ) + else: + generated_ids_after_ptq = result + elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. 
generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100) elif is_nemotron_vl_model and tokenizer is not None: @@ -560,7 +609,8 @@ def main(args): ) def input_decode(input_ids): - if processor is not None and isinstance(processor, MllamaImageProcessor): + # BaseImageProcessor covers MllamaImageProcessor and Qwen3OmniImageProcessor + if processor is not None and isinstance(processor, BaseImageProcessor): return processor.tokenizer.batch_decode(input_ids) elif processor is not None and isinstance(processor, WhisperProcessor): return first_text @@ -579,6 +629,12 @@ def output_decode(generated_ids, input_shape): return tokenizer.batch_decode(generated_ids, skip_special_tokens=True) elif processor is not None and isinstance(processor, MllamaImageProcessor): return processor.tokenizer.batch_decode(generated_ids[:, input_shape:]) + elif processor is not None and isinstance(processor, Qwen3OmniImageProcessor): + return processor.tokenizer.batch_decode( + generated_ids[:, input_shape:], + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) elif tokenizer is not None: return tokenizer.batch_decode(generated_ids[:, input_shape:]) else: diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 706a01048..bedbc6d82 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -29,6 +29,7 @@ "MPT": "mpt", "Bloom": "bloom", "ChatGLM": "chatglm", + "Qwen3OmniMoeForConditionalGeneration": "qwen3omni", "QWen": "qwen", "RecurrentGemma": "recurrentgemma", "Gemma3": "gemma3", diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index d4cf5049d..141fdaacb 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -283,7 +283,8 @@ def _get_free_gpu_mem(): free_mem_before, max_allocated_before = _get_free_gpu_mem() is_enc_dec = model_type_is_enc_dec(model) - infer_method = model.generate if is_enc_dec else model.forward + requires_generate = _model_requires_generate(model) + infer_method = model.generate if (is_enc_dec or requires_generate) else model.forward if sample_input_single_batch is None: sample_input_single_batch = ( @@ -349,11 +350,15 @@ def _process_batch(batch_data, infer_method, max_working_batch_size=None): Returns: The maximum batch size that worked successfully """ - assert all(torch.is_tensor(data) or data is None for data in batch_data.values()), ( - "batch_data values must be tensors" + # Separate tensor values from scalar parameters (like max_new_tokens) + tensor_data = {k: v for k, v in batch_data.items() if torch.is_tensor(v) or v is None} + scalar_data = {k: v for k, v in batch_data.items() if not torch.is_tensor(v) and v is not None} + + assert all(torch.is_tensor(data) or data is None for data in tensor_data.values()), ( + "tensor_data values must be tensors" ) # Get the batch size of current data - batch_size = batch_data[next(iter(batch_data.keys()))].shape[0] + batch_size = tensor_data[next(iter(tensor_data.keys()))].shape[0] # If we know a smaller batch size works, preemptively split if max_working_batch_size is not None and batch_size > max_working_batch_size: @@ -361,11 +366,13 @@ def _process_batch(batch_data, infer_method, max_working_batch_size=None): for i in range(0, batch_size, max_working_batch_size): end_idx = min(i + max_working_batch_size, batch_size) split_data = {} - for key in batch_data: - if batch_data[key] is None: + for key in tensor_data: + if tensor_data[key] is None: split_data[key] 
= None else: - split_data[key] = batch_data[key][i:end_idx, ...] + split_data[key] = tensor_data[key][i:end_idx, ...] + # Add back scalar data (non-tensor params like max_new_tokens) + split_data.update(scalar_data) max_working_batch_size = _process_batch( split_data, infer_method, max_working_batch_size @@ -392,8 +399,11 @@ def _process_batch(batch_data, infer_method, max_working_batch_size=None): # Split the batch in half mid = (batch_size + 1) // 2 warn(f"CUDA out of memory with batch size {batch_size}, trying with batch size {mid}") - split_data_1 = {key: batch_data[key][:mid, ...] for key in batch_data} - split_data_2 = {key: batch_data[key][mid:, ...] for key in batch_data} + split_data_1 = {key: tensor_data[key][:mid, ...] for key in tensor_data} + split_data_2 = {key: tensor_data[key][mid:, ...] for key in tensor_data} + # Add back scalar data (non-tensor params like max_new_tokens) + split_data_1.update(scalar_data) + split_data_2.update(scalar_data) # Recursively process each half and track max working batch size max_working_batch_size = _process_batch(split_data_1, infer_method) @@ -412,10 +422,15 @@ def _forward_loop(model: torch.nn.Module, dataloader: DataLoader) -> None: """ with torch.no_grad(): is_enc_dec = model_type_is_enc_dec(model) - infer_method = model.generate if is_enc_dec else model.forward + requires_generate = _model_requires_generate(model) + use_generate = is_enc_dec or requires_generate + infer_method = model.generate if use_generate else model.forward max_working_batch_size = None # Initialize max working batch size as None for _, data in enumerate(tqdm(dataloader)): + # For generate(), add max_new_tokens to prevent indefinite generation during calibration + if use_generate: + data["max_new_tokens"] = 1 # Process batch and update max working batch size max_working_batch_size = _process_batch(data, infer_method, max_working_batch_size) @@ -493,3 +508,15 @@ def create_forward_loop( def model_type_is_enc_dec(model): enc_dec_model_list = ["t5", "bart", "whisper"] return any(model_name in model.__class__.__name__.lower() for model_name in enc_dec_model_list) + + +def _model_requires_generate(model): + """Check if model requires generate() instead of forward() for calibration. + + Some conditional generation models (like Qwen3-Omni) don't have a standard + forward(input_ids, ...) signature and need to use generate() for calibration. 
+ """ + # Models that require generate() for calibration instead of forward() + generate_model_list = ["qwen3omni"] + model_name = model.__class__.__name__.lower() + return any(name in model_name for name in generate_model_list) diff --git a/modelopt/torch/utils/image_processor.py b/modelopt/torch/utils/image_processor.py index 87960d54d..4ed4b363a 100644 --- a/modelopt/torch/utils/image_processor.py +++ b/modelopt/torch/utils/image_processor.py @@ -25,6 +25,9 @@ class BaseImageProcessor: def __init__(self, tokenizer, device="auto"): """Constructor.""" self.tokenizer = tokenizer + # Handle invalid device values that can come from multi-GPU models with device_map="auto" + if device is None or str(device) in ("auto", "meta", "cpu"): + device = "cuda" self.device = device def __call__(self, **kwargs): @@ -110,3 +113,116 @@ def collate_function(self, batch): ).to(self.device) return batch[0] + + +class Qwen3OmniImageProcessor(BaseImageProcessor): + """Image processor for Qwen3-Omni multimodal model.""" + + def __init__(self, tokenizer, device="auto", use_audio_in_video=False): + """Constructor.""" + super().__init__(tokenizer, device) + self.use_audio_in_video = use_audio_in_video + # Try to import qwen_omni_utils for multimodal processing + try: + from qwen_omni_utils import process_mm_info + + self.process_mm_info = process_mm_info + except ImportError: + raise ImportError( + "qwen_omni_utils is required for Qwen3OmniImageProcessor. " + "Please install it from https://github.com/QwenLM/Qwen3-Omni" + ) + + def preprocess_function(self, examples): + """Preprocess function for Qwen3-Omni.""" + question = examples.get("question", "Describe this image.") + + # Build conversation in Qwen format + content = [] + if examples.get("image") is not None: + content.append({"type": "image", "image": examples["image"]}) + if examples.get("audio") is not None: + content.append({"type": "audio", "audio": examples["audio"]}) + if examples.get("video") is not None: + content.append({"type": "video", "video": examples["video"]}) + content.append({"type": "text", "text": question}) + + conversation = [{"role": "user", "content": content}] + + # Apply chat template (tokenize=False to get string) + text = self.tokenizer.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=False + ) + + # Extract multimodal info using qwen_omni_utils + audios, images, videos = self.process_mm_info( + conversation, use_audio_in_video=self.use_audio_in_video + ) + + # Process inputs with the processor + values = self.tokenizer( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=self.use_audio_in_video, + ) + + # Define all possible keys to ensure consistent schema for Arrow serialization + all_keys = [ + "input_ids", + "attention_mask", + "pixel_values", + "image_grid_thw", + "audio_features", + "audio_feature_lens", + "video_grid_thw", + ] + + # Convert tensors to lists for Arrow serialization compatibility + # Tensor conversion back happens in collate_function + result = dict.fromkeys(all_keys) # Initialize all keys to None + for key, val in values.items(): + if val is not None and hasattr(val, "tolist"): + result[key] = val.tolist() + elif val is not None: + result[key] = val + + return result + + def collate_function(self, batch): + """Collate function to process inputs during data loading.""" + result = {} + + # Take first item from batch (batch_size handling) + first = batch[0] + + # Convert lists to tensors and move to device + if 
"input_ids" in first and first["input_ids"] is not None: + result["input_ids"] = torch.LongTensor(first["input_ids"]).to(self.device) + if "attention_mask" in first and first["attention_mask"] is not None: + result["attention_mask"] = torch.LongTensor(first["attention_mask"]).to(self.device) + + # Handle pixel values for images + if first.get("pixel_values") is not None: + result["pixel_values"] = torch.tensor(first["pixel_values"]).to(self.device) + + # Handle image grid thw (tile height width info) + if first.get("image_grid_thw") is not None: + result["image_grid_thw"] = torch.LongTensor(first["image_grid_thw"]).to(self.device) + + # Handle audio features if present + if first.get("audio_feature_lens") is not None: + result["audio_feature_lens"] = torch.LongTensor(first["audio_feature_lens"]).to( + self.device + ) + if first.get("audio_features") is not None: + result["audio_features"] = torch.tensor(first["audio_features"]).to(self.device) + + # Handle video features if present + if first.get("video_grid_thw") is not None: + result["video_grid_thw"] = torch.LongTensor(first["video_grid_thw"]).to(self.device) + + return result diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 1d9f59484..84644b744 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -19,7 +19,7 @@ from torch.utils.data import DataLoader -from .image_processor import MllamaImageProcessor +from .image_processor import BaseImageProcessor # Use dict to store the config for each dataset. # If we want to export more options to user like target languages, we need more standardized approach like dataclass. @@ -75,7 +75,7 @@ def get_supported_vlm_datasets() -> list[str]: def get_vlm_dataset_dataloader( dataset_name: str = "scienceqa", - processor: MllamaImageProcessor = None, + processor: BaseImageProcessor = None, batch_size: int = 1, num_samples: int = 512, ) -> DataLoader: From 7857728b642b085f244335f08f4bf0bd6209f1ee Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Sat, 13 Dec 2025 01:37:15 +0000 Subject: [PATCH 02/10] Add the finevideo dataset for calibration Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 55 +++- modelopt/torch/utils/__init__.py | 1 + modelopt/torch/utils/video_dataset_utils.py | 292 ++++++++++++++++++++ 3 files changed, 334 insertions(+), 14 deletions(-) create mode 100644 modelopt/torch/utils/video_dataset_utils.py diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 40e8ee42e..47d186a57 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -66,6 +66,11 @@ ) from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader +from modelopt.torch.utils.video_dataset_utils import ( + Qwen3OmniVideoProcessor, + get_supported_video_datasets, + get_video_dataset_dataloader, +) from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader RAND_SEED = 1234 @@ -464,18 +469,37 @@ def main(args): num_samples=args.calib_size[0], ) elif model_type == "qwen3omni": - assert processor is not None and isinstance(processor, Qwen3OmniImageProcessor), ( - "The Qwen3OmniImageProcessor must be set." 
- ) assert len(args.calib_size) == 1, ( "qwen3omni only supports one dataset for calibration, can extend this in the future" ) - calib_dataloader = get_vlm_dataset_dataloader( - dataset_name=args.dataset[0] if args.dataset else "scienceqa", - processor=processor, - batch_size=args.batch_size, - num_samples=args.calib_size[0], - ) + assert processor is not None, "The processor must be set for qwen3omni model." + dataset_name = args.dataset[0] if args.dataset else "scienceqa" + # Check if using video dataset (e.g., finevideo) + if dataset_name in get_supported_video_datasets(): + video_processor = Qwen3OmniVideoProcessor( + processor.tokenizer if hasattr(processor, "tokenizer") else processor, + device=device, + dtype=model.dtype, + use_audio_in_video=True, + ) + calib_dataloader = get_video_dataset_dataloader( + dataset_name=dataset_name, + processor=video_processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + ) + else: + assert processor is not None and isinstance(processor, Qwen3OmniImageProcessor), ( + "The Qwen3OmniImageProcessor must be set." + ) + # Set the dtype for proper tensor conversion in collate_function + processor.dtype = model.dtype + calib_dataloader = get_vlm_dataset_dataloader( + dataset_name=dataset_name, + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + ) elif model_type == "whisper": assert processor is not None and isinstance(processor, WhisperProcessor), ( "The AutoProcessor must be set." @@ -529,9 +553,10 @@ def main(args): if not model_is_already_quantized or calibration_only: # Only run single sample for preview - input_ids = next(iter(calib_dataloader))[ - "input_features" if model_type == "whisper" else "input_ids" - ][0:1] + calib_batch = next(iter(calib_dataloader)) + input_ids = calib_batch["input_features" if model_type == "whisper" else "input_ids"][ + 0:1 + ] # Generate preview before quantization if is_nemotron_vl_model and tokenizer is not None: @@ -545,7 +570,8 @@ def main(args): ) elif model_type == "qwen3omni": # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences - result = full_model.generate(input_ids, max_new_tokens=100) + # Pass full batch with all multimodal inputs + result = full_model.generate(**calib_batch, max_new_tokens=100) if isinstance(result, tuple): text_ids, _ = result generated_ids_before_ptq = ( @@ -583,7 +609,8 @@ def main(args): generated_ids_after_ptq = None if model_type == "qwen3omni": # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences - result = full_model.generate(input_ids, max_new_tokens=100) + # Pass full batch with all multimodal inputs + result = full_model.generate(**calib_batch, max_new_tokens=100) if isinstance(result, tuple): text_ids, _ = result generated_ids_after_ptq = ( diff --git a/modelopt/torch/utils/__init__.py b/modelopt/torch/utils/__init__.py index 3ae385ac6..b909609c4 100644 --- a/modelopt/torch/utils/__init__.py +++ b/modelopt/torch/utils/__init__.py @@ -26,4 +26,5 @@ from .perf import * from .regex import * from .tensor import * +from .video_dataset_utils import * from .vlm_dataset_utils import * diff --git a/modelopt/torch/utils/video_dataset_utils.py b/modelopt/torch/utils/video_dataset_utils.py new file mode 100644 index 000000000..6ae5c2d2a --- /dev/null +++ b/modelopt/torch/utils/video_dataset_utils.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for getting samples and forward loop function for video datasets.""" + +import os +import tempfile +from typing import Any + +import torch +from torch.utils.data import DataLoader + +from .image_processor import BaseImageProcessor + +# Use dict to store the config for each dataset. +SUPPORTED_VIDEO_DATASET_CONFIG: dict[str, dict[str, Any]] = { + "finevideo": { + "config": {"path": "HuggingFaceFV/finevideo", "split": "train", "streaming": True} + }, +} + +__all__ = [ + "Qwen3OmniVideoProcessor", + "get_supported_video_datasets", + "get_video_dataset_dataloader", +] + + +def _get_video_dataset(dataset_name: str, num_samples: int): + """Load a portion of train dataset with the dataset name and a given size. + + Args: + dataset_name: Name of the dataset to load. + num_samples: Number of samples to load from the dataset. + + Returns: + A hugging face Dataset. + """ + if dataset_name in SUPPORTED_VIDEO_DATASET_CONFIG: + from datasets import Dataset, load_dataset + + config = SUPPORTED_VIDEO_DATASET_CONFIG[dataset_name]["config"] + is_streaming = config.get("streaming", False) + + dataset = load_dataset(**config) + + if is_streaming: + # For streaming datasets, use take() and convert to list then Dataset + samples = list(dataset.take(num_samples)) + return Dataset.from_list(samples) + else: + return dataset.select(range(num_samples)) + else: + raise NotImplementedError( + f"dataset {dataset_name} is not supported. Please use one of the following:" + f" {get_supported_video_datasets()}." + ) + + +def get_supported_video_datasets() -> list[str]: + """Retrieves a list of video datasets supported. + + Returns: + A list of strings, where each string is the name of a supported dataset. + + Example usage: + + .. code-block:: python + + from modelopt.torch.utils import get_supported_video_datasets + + print("Supported video datasets:", get_supported_video_datasets()) + """ + return list(SUPPORTED_VIDEO_DATASET_CONFIG.keys()) + + +def get_video_dataset_dataloader( + dataset_name: str = "finevideo", + processor: "Qwen3OmniVideoProcessor" = None, + batch_size: int = 1, + num_samples: int = 512, +) -> DataLoader: + """Get a dataloader with the dataset name and processor of the target model. + + Args: + dataset_name: Name of the dataset to load. + processor: Processor used for encoding video and text data. + batch_size: Batch size of the returned dataloader. + num_samples: Number of samples from the dataset. + + Returns: + An instance of dataloader. + """ + assert processor is not None, "Please provide a valid processor." 
+ + dataset = _get_video_dataset(dataset_name, num_samples=num_samples) + # Apply the preprocessing function to the dataset + processed_dataset = dataset.map( + processor.preprocess_function, batched=False, remove_columns=dataset.column_names + ) + + # Create DataLoader with the custom collate function + return DataLoader( + processed_dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=processor.collate_function, + ) + + +class Qwen3OmniVideoProcessor(BaseImageProcessor): + """Video processor for Qwen3-Omni multimodal model with finevideo dataset support.""" + + def __init__(self, tokenizer, device="cuda", dtype=None, use_audio_in_video=True): + """Constructor. + + Args: + tokenizer: The Qwen3OmniMoeProcessor for tokenizing and processing inputs. + device: Device to move tensors to. + dtype: dtype for float tensors (e.g., torch.bfloat16). If None, uses default. + use_audio_in_video: Whether to extract and use audio from video files. + """ + super().__init__(tokenizer, device) + self.dtype = dtype + self.use_audio_in_video = use_audio_in_video + self._temp_dir = tempfile.mkdtemp(prefix="qwen3omni_video_") + self._video_counter = 0 + # Try to import qwen_omni_utils for multimodal processing + try: + from qwen_omni_utils import process_mm_info + + self.process_mm_info = process_mm_info + except ImportError: + raise ImportError( + "qwen_omni_utils is required for Qwen3OmniVideoProcessor. " + "Please install it from https://github.com/QwenLM/Qwen3-Omni" + ) + + def _save_video_bytes_to_file(self, video_bytes: bytes) -> str: + """Save video bytes to a temporary file and return the path. + + Args: + video_bytes: Raw video bytes (e.g., from finevideo's 'mp4' field). + + Returns: + Path to the temporary video file. + """ + video_path = os.path.join(self._temp_dir, f"video_{self._video_counter}.mp4") + self._video_counter += 1 + with open(video_path, "wb") as f: + f.write(video_bytes) + return video_path + + def preprocess_function(self, examples): + """Preprocess function for Qwen3-Omni with video support. + + Handles both standard video paths and raw video bytes (finevideo format). + """ + # Get question/prompt - finevideo has metadata in 'json' field + if "json" in examples and examples["json"] is not None: + metadata = examples["json"] + # Try to get a meaningful question from metadata + category = metadata.get("content_fine_category", "") + question = f"/no_think Describe what is happening in this video in detail. 
Category hint: {category}" + else: + question = examples.get("question", "/no_think Describe this video in detail.") + + # Build conversation in Qwen format + content = [] + + # Handle video - check for raw bytes (finevideo format) or path + video_path = None + if examples.get("mp4") is not None: + # finevideo format: raw video bytes in 'mp4' field + video_path = self._save_video_bytes_to_file(examples["mp4"]) + elif examples.get("video") is not None: + # Standard format: video path or URL + video_path = examples["video"] + + if video_path is not None: + content.append({"type": "video", "video": video_path}) + + content.append({"type": "text", "text": question}) + + conversation = [{"role": "user", "content": content}] + + # Apply chat template (tokenize=False to get string) + text = self.tokenizer.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=False + ) + + # Extract multimodal info using qwen_omni_utils + audios, images, videos = self.process_mm_info( + conversation, use_audio_in_video=self.use_audio_in_video + ) + + # Process inputs with the processor + values = self.tokenizer( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=self.use_audio_in_video, + ) + # Define all possible keys to ensure consistent schema for Arrow serialization + all_keys = [ + "input_ids", + "attention_mask", + "pixel_values_videos", + "video_grid_thw", + "video_second_per_grid", + "feature_attention_mask", + "input_features", + ] + + # Convert tensors to lists for Arrow serialization compatibility + # Tensor conversion back happens in collate_function + result = dict.fromkeys(all_keys) # Initialize all keys to None + for key, val in values.items(): + if val is not None and hasattr(val, "tolist"): + result[key] = val.tolist() + elif val is not None: + result[key] = val + + return result + + def collate_function(self, batch): + """Collate function to process inputs during data loading.""" + result = {} + + # Take first item from batch (batch_size handling) + first = batch[0] + + # Convert lists to tensors and move to device + if first.get("input_ids") is not None: + result["input_ids"] = torch.LongTensor(first["input_ids"]).to(self.device) + if first.get("attention_mask") is not None: + result["attention_mask"] = torch.LongTensor(first["attention_mask"]).to(self.device) + + # Handle pixel values for video frames + if first.get("pixel_values_videos") is not None: + pv = torch.tensor(first["pixel_values_videos"]) + if self.dtype is not None: + pv = pv.to(self.dtype) + result["pixel_values_videos"] = pv.to(self.device) + + # Handle video grid thw (tile height width info) + if first.get("video_grid_thw") is not None: + result["video_grid_thw"] = torch.LongTensor(first["video_grid_thw"]).to(self.device) + + # Handle video second per grid (temporal info for rope) + if first.get("video_second_per_grid") is not None: + result["video_second_per_grid"] = torch.tensor(first["video_second_per_grid"]).to( + self.device + ) + + # Handle audio features if present + if first.get("feature_attention_mask") is not None: + result["feature_attention_mask"] = torch.LongTensor(first["feature_attention_mask"]).to( + self.device + ) + if first.get("input_features") is not None: + inp_feat = torch.tensor(first["input_features"]) + if self.dtype is not None: + inp_feat = inp_feat.to(self.dtype) + result["input_features"] = inp_feat.to(self.device) + + # Pass use_audio_in_video flag to model.generate() for Qwen3Omni + 
result["use_audio_in_video"] = self.use_audio_in_video + + return result + + def cleanup(self): + """Clean up temporary video files.""" + import shutil + + if os.path.exists(self._temp_dir): + shutil.rmtree(self._temp_dir) From ae1346998a2baf679e09ed71397ffcc2a00c43e1 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Tue, 16 Dec 2025 20:26:07 +0000 Subject: [PATCH 03/10] Add option to disable talker Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 47d186a57..0a35db0f1 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import os import random import time import warnings @@ -314,7 +315,7 @@ def main(args): model_is_already_quantized = is_quantized(model) model_type = get_model_type(model) - if model_type == "qwen3omni": + if model_type == "qwen3omni" and os.environ.get("DISABLE_TALKER", "0") == "1": model.disable_talker() device = model.device From 5746ea0cd3c387a0f2eef936e260b4ca485e0af9 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Tue, 16 Dec 2025 21:57:05 +0000 Subject: [PATCH 04/10] Add quantization configs for the model Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 60 +++++++++++++++++++++++++-- modelopt/torch/utils/dataset_utils.py | 2 +- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 0a35db0f1..c0b4fa85c 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -72,7 +72,10 @@ get_supported_video_datasets, get_video_dataset_dataloader, ) -from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader +from modelopt.torch.utils.vlm_dataset_utils import ( + get_supported_vlm_datasets, + get_vlm_dataset_dataloader, +) RAND_SEED = 1234 @@ -90,6 +93,11 @@ "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG, "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, + "qwen3_nvfp4_qkv_disabled": mtq.NVFP4_DEFAULT_CFG, + "qwen3_nvfp4_qkvo_disabled": mtq.NVFP4_DEFAULT_CFG, + "qwen3_nvfp4_first_n_disabled": mtq.NVFP4_DEFAULT_CFG, + "qwen3_nvfp4_last_n_disabled": mtq.NVFP4_DEFAULT_CFG, + "qwen3_first_and_last_n_disabled": mtq.NVFP4_DEFAULT_CFG, } KV_QUANT_CFG_CHOICES = { @@ -295,6 +303,40 @@ def main(args): f"Quantization format is not supported for low memory mode. 
Supported formats: {QUANT_CFG_CHOICES.keys()}" ) quant_cfg = QUANT_CFG_CHOICES[args.qformat] + + # Qwen3 specific quantizer disabling patterns (thinker.model.layers only) + if args.qformat == "qwen3_nvfp4_qkv_disabled": + # Disable q_proj, k_proj, v_proj quantizers + for proj in ["q_proj", "k_proj", "v_proj"]: + quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = { + "enable": False + } + elif args.qformat == "qwen3_nvfp4_qkvo_disabled": + # Disable q_proj, k_proj, v_proj, o_proj quantizers + for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: + quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = { + "enable": False + } + elif args.qformat == "qwen3_nvfp4_first_n_disabled": + # Disable first N layers (e.g., layers 0-7) + n_layers_to_disable = 8 + for i in range(n_layers_to_disable): + quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} + elif args.qformat == "qwen3_nvfp4_last_n_disabled": + # Disable last N layers (e.g., layers 40-47 for 48 total layers) + total_layers = 48 + n_layers_to_disable = 8 + for i in range(total_layers - n_layers_to_disable, total_layers): + quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} + elif args.qformat == "qwen3_first_and_last_n_disabled": + # Disable both first N and last N layers + total_layers = 48 + n_layers_to_disable = 4 + for i in range(n_layers_to_disable): + quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} + for i in range(total_layers - n_layers_to_disable, total_layers): + quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} + if args.kv_cache_qformat != "none": quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] @@ -316,6 +358,7 @@ def main(args): model_type = get_model_type(model) if model_type == "qwen3omni" and os.environ.get("DISABLE_TALKER", "0") == "1": + print("Disabling talker for Qwen3Omni model") model.disable_talker() device = model.device @@ -474,7 +517,7 @@ def main(args): "qwen3omni only supports one dataset for calibration, can extend this in the future" ) assert processor is not None, "The processor must be set for qwen3omni model." - dataset_name = args.dataset[0] if args.dataset else "scienceqa" + dataset_name = args.dataset[0] if args.dataset else "cnn_dailymail" # Check if using video dataset (e.g., finevideo) if dataset_name in get_supported_video_datasets(): video_processor = Qwen3OmniVideoProcessor( @@ -489,7 +532,7 @@ def main(args): batch_size=args.batch_size, num_samples=args.calib_size[0], ) - else: + elif dataset_name in get_supported_vlm_datasets(): assert processor is not None and isinstance(processor, Qwen3OmniImageProcessor), ( "The Qwen3OmniImageProcessor must be set." ) @@ -501,6 +544,17 @@ def main(args): batch_size=args.batch_size, num_samples=args.calib_size[0], ) + else: + # Text-only datasets (e.g., cnn_dailymail) + qwen3omni_tokenizer = processor.tokenizer.tokenizer + calib_dataloader = get_dataset_dataloader( + dataset_name=dataset_name, + tokenizer=qwen3omni_tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + device=device, + ) + print(f"Selected dataset for calibration: {dataset_name}") elif model_type == "whisper": assert processor is not None and isinstance(processor, WhisperProcessor), ( "The AutoProcessor must be set." 
diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 141fdaacb..a3271e075 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -74,7 +74,7 @@ }, "cnn_dailymail": { "config": {"path": "cnn_dailymail", "name": "3.0.0", "split": ["train"]}, - "preprocess": lambda sample: sample["article"], + "preprocess": lambda sample: "/no_think " + sample["article"], }, "pile": { "config": {"path": "monology/pile-uncopyrighted", "name": "v1.0", "split": ["train"]}, From 9ec99a0ec7ff22538da14cfe5dc92223a240fba9 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Tue, 16 Dec 2025 23:44:58 +0000 Subject: [PATCH 05/10] Register Qwen3 thinker and talker sparse moe blocks in quant module Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- .../torch/quantization/plugins/huggingface.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index a3fa6ef1a..867f0ed4e 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -576,6 +576,23 @@ def top_k(self, value): except ImportError: pass +try: + from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeTalkerTextSparseMoeBlock, + Qwen3OmniMoeThinkerTextSparseMoeBlock, + ) + + if Qwen3OmniMoeTalkerTextSparseMoeBlock not in QuantModuleRegistry: + QuantModuleRegistry.register( + {Qwen3OmniMoeTalkerTextSparseMoeBlock: "hf.Qwen3OmniMoeTalkerTextSparseMoeBlock"} + )(_QuantSparseMoe) + if Qwen3OmniMoeThinkerTextSparseMoeBlock not in QuantModuleRegistry: + QuantModuleRegistry.register( + {Qwen3OmniMoeThinkerTextSparseMoeBlock: "hf.Qwen3OmniMoeThinkerTextSparseMoeBlock"} + )(_QuantSparseMoe) +except ImportError: + pass + class _QuantGptOssExperts(_QuantFunctionalMixin): """Quantized wrapper for `transformers.GptOssExperts`. 
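The qwen3_* formats introduced above all start from mtq.NVFP4_DEFAULT_CFG and then disable selected quantizers by wildcard pattern. A minimal standalone sketch of that composition, assuming only the config-dict structure and mtq entry points visible in these hunks (the model and forward_loop are placeholders built as in hf_ptq.py):

    import copy
    import modelopt.torch.quantization as mtq

    # Start from the stock NVFP4 recipe and disable the thinker attention projections,
    # mirroring the "qwen3_nvfp4_qkv_disabled" branch added to hf_ptq.py.
    quant_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
    for proj in ("q_proj", "k_proj", "v_proj"):
        quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {"enable": False}

    # model = mtq.quantize(model, quant_cfg, forward_loop)  # forward_loop as built in hf_ptq.py
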
From 156f7ee70b5a2f35ba27484a0a98df3d1d770471 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Tue, 16 Dec 2025 23:52:54 +0000 Subject: [PATCH 06/10] remove first_n and last_n configs Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index c0b4fa85c..7273de5c7 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -95,8 +95,6 @@ "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, "qwen3_nvfp4_qkv_disabled": mtq.NVFP4_DEFAULT_CFG, "qwen3_nvfp4_qkvo_disabled": mtq.NVFP4_DEFAULT_CFG, - "qwen3_nvfp4_first_n_disabled": mtq.NVFP4_DEFAULT_CFG, - "qwen3_nvfp4_last_n_disabled": mtq.NVFP4_DEFAULT_CFG, "qwen3_first_and_last_n_disabled": mtq.NVFP4_DEFAULT_CFG, } @@ -317,17 +315,6 @@ def main(args): quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = { "enable": False } - elif args.qformat == "qwen3_nvfp4_first_n_disabled": - # Disable first N layers (e.g., layers 0-7) - n_layers_to_disable = 8 - for i in range(n_layers_to_disable): - quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} - elif args.qformat == "qwen3_nvfp4_last_n_disabled": - # Disable last N layers (e.g., layers 40-47 for 48 total layers) - total_layers = 48 - n_layers_to_disable = 8 - for i in range(total_layers - n_layers_to_disable, total_layers): - quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} elif args.qformat == "qwen3_first_and_last_n_disabled": # Disable both first N and last N layers total_layers = 48 From f4ca2857b7518b4e297bb4e15b5808eb2ff708e8 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 17 Dec 2025 00:02:35 +0000 Subject: [PATCH 07/10] Update quantization modes to stack on top of one another Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 7273de5c7..c82ac2ef2 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -303,19 +303,19 @@ def main(args): quant_cfg = QUANT_CFG_CHOICES[args.qformat] # Qwen3 specific quantizer disabling patterns (thinker.model.layers only) - if args.qformat == "qwen3_nvfp4_qkv_disabled": + if "qkv_disabled" in args.qformat: # Disable q_proj, k_proj, v_proj quantizers for proj in ["q_proj", "k_proj", "v_proj"]: quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = { "enable": False } - elif args.qformat == "qwen3_nvfp4_qkvo_disabled": + if "qkvo_disabled" in args.qformat: # Disable q_proj, k_proj, v_proj, o_proj quantizers - for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: + for proj in ["o_proj"]: quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = { "enable": False } - elif args.qformat == "qwen3_first_and_last_n_disabled": + if "first_and_last_n_disabled" in args.qformat: # Disable both first N and last N layers total_layers = 48 n_layers_to_disable = 4 From c5f2fcea92adead4056a72938a3814d84c332823 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 17 Dec 2025 03:37:37 +0000 Subject: [PATCH 08/10] Add a text processor for text datasets Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 15 ++++- 
modelopt/torch/utils/dataset_utils.py | 85 ++++++++++++++++++++++++- modelopt/torch/utils/image_processor.py | 64 +++++++++++++++++++ 3 files changed, 160 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index c82ac2ef2..0dd7d9ac3 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -58,12 +58,14 @@ create_forward_loop, get_dataset_dataloader, get_max_batch_size, + get_qwen3omni_text_dataloader, get_supported_datasets, ) from modelopt.torch.utils.image_processor import ( BaseImageProcessor, MllamaImageProcessor, Qwen3OmniImageProcessor, + Qwen3OmniTextProcessor, ) from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader @@ -533,12 +535,19 @@ def main(args): ) else: # Text-only datasets (e.g., cnn_dailymail) - qwen3omni_tokenizer = processor.tokenizer.tokenizer - calib_dataloader = get_dataset_dataloader( + # Use Qwen3OmniTextProcessor to apply proper conversation template + # See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking + text_processor = Qwen3OmniTextProcessor( + processor=processor.tokenizer, # Pass the underlying HF processor + device=device, + dtype=model.dtype, + ) + calib_dataloader = get_qwen3omni_text_dataloader( dataset_name=dataset_name, - tokenizer=qwen3omni_tokenizer, + processor=text_processor, batch_size=args.batch_size, num_samples=args.calib_size[0], + max_sample_length=args.calib_seq, device=device, ) print(f"Selected dataset for calibration: {dataset_name}") diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index a3271e075..af9c83be4 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -74,7 +74,7 @@ }, "cnn_dailymail": { "config": {"path": "cnn_dailymail", "name": "3.0.0", "split": ["train"]}, - "preprocess": lambda sample: "/no_think " + sample["article"], + "preprocess": lambda sample: sample["article"], }, "pile": { "config": {"path": "monology/pile-uncopyrighted", "name": "v1.0", "split": ["train"]}, @@ -98,6 +98,7 @@ "create_forward_loop", "get_dataset_dataloader", "get_max_batch_size", + "get_qwen3omni_text_dataloader", "get_supported_datasets", ] @@ -243,6 +244,88 @@ def get_dataset_dataloader( return calib_dataloader +def get_qwen3omni_text_dataloader( + dataset_name: str | list[str] = "cnn_dailymail", + processor=None, + batch_size: int = 1, + num_samples: int | list[int] = 512, + max_sample_length: int = 512, + device: str | None = None, +) -> DataLoader: + """Get a text-only dataloader for Qwen3-Omni with proper conversation template applied. + + This function applies the Qwen3-Omni chat template to text samples before tokenization, + which is required for proper calibration of Qwen3-Omni models with text-only datasets. + + See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking + + Args: + dataset_name: Name of the dataset(s) to load. + processor: Qwen3OmniTextProcessor instance wrapping the Qwen3OmniMoeProcessor. + batch_size: Batch size of the returned dataloader. + num_samples: Number of samples from the dataset. + max_sample_length: Maximum length of a sample (for truncation). + device: Target device for the returned dataloader. + + Returns: + A DataLoader with properly formatted inputs for Qwen3-Omni. + """ + assert processor is not None, "Please provide a Qwen3OmniTextProcessor." 
+ + if isinstance(num_samples, int): + num_samples = [num_samples] + + if isinstance(dataset_name, str): + dataset_name = [dataset_name] + + assert len(dataset_name) == len(num_samples), ( + "dataset_name and num_samples must be the same length" + ) + + # Get raw text samples + all_samples = [] + for ds_name, num_sample in zip(dataset_name, num_samples): + samples = _get_dataset_samples(ds_name, num_sample) + all_samples.extend(samples) + + # Preprocess each sample with the conversation template + processed_samples = [] + for text in all_samples: + # Apply conversation template and tokenize + values = processor.preprocess_function(text) + + # Convert to lists for dataset compatibility + sample_dict = {} + for key, val in values.items(): + if val is not None and hasattr(val, "tolist"): + sample_dict[key] = val.tolist() + elif val is not None: + sample_dict[key] = val + processed_samples.append(sample_dict) + + # Create dataset + class _Qwen3OmniTextDataset(torch.utils.data.Dataset): + def __init__(self, samples): + self.samples = samples + + def __getitem__(self, idx): + return self.samples[idx] + + def __len__(self): + return len(self.samples) + + dataset = _Qwen3OmniTextDataset(processed_samples) + + calib_dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=processor.collate_function, + ) + + return calib_dataloader + + def get_supported_datasets() -> list[str]: """Retrieves a list of datasets supported. diff --git a/modelopt/torch/utils/image_processor.py b/modelopt/torch/utils/image_processor.py index 4ed4b363a..9489ebefe 100644 --- a/modelopt/torch/utils/image_processor.py +++ b/modelopt/torch/utils/image_processor.py @@ -115,6 +115,70 @@ def collate_function(self, batch): return batch[0] +class Qwen3OmniTextProcessor(BaseImageProcessor): + """Text-only processor for Qwen3-Omni that applies proper conversation template. + + This processor wraps raw text in the Qwen3-Omni conversation format and applies + the chat template before tokenization. Use this for text-only calibration datasets. + + See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking + """ + + def __init__(self, processor, device="auto", dtype=None): + """Constructor. + + Args: + processor: The Qwen3OmniMoeProcessor (from AutoProcessor.from_pretrained). + device: Device to move tensors to. + dtype: dtype for float tensors (e.g., torch.bfloat16). If None, uses default. + """ + super().__init__(processor, device) + self.dtype = dtype + + def preprocess_function(self, text: str) -> dict: + """Preprocess a single text sample by applying conversation template. + + Args: + text: Raw text string from dataset. + + Returns: + Dictionary with tokenized inputs. 
+ """ + # Build conversation in Qwen format (text-only) + conversation = [ + {"role": "user", "content": [{"type": "text", "text": "/no_think " + text}]} + ] + + # Apply chat template (tokenize=False to get formatted string) + formatted_text = self.tokenizer.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=False + ) + + # Tokenize with the processor (no multimodal inputs) + values = self.tokenizer( + text=formatted_text, + audio=None, + images=None, + videos=None, + return_tensors="pt", + padding=True, + ) + + return values + + def collate_function(self, batch): + """Collate function to process text inputs during data loading.""" + result = {} + first = batch[0] + + if "input_ids" in first and first["input_ids"] is not None: + result["input_ids"] = torch.LongTensor(first["input_ids"]).to(self.device) + if "attention_mask" in first and first["attention_mask"] is not None: + result["attention_mask"] = torch.LongTensor(first["attention_mask"]).to(self.device) + + return result + + class Qwen3OmniImageProcessor(BaseImageProcessor): """Image processor for Qwen3-Omni multimodal model.""" From e4b374aae71d730adab31c55984c2bb2ae1de4a6 Mon Sep 17 00:00:00 2001 From: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> Date: Tue, 16 Dec 2025 22:52:56 -0800 Subject: [PATCH 09/10] Disable Qwen3OmniMoe class registration Comment out import and registration of Qwen3OmniMoe classes. Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> --- .../torch/quantization/plugins/huggingface.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 867f0ed4e..5de8c17ce 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -576,22 +576,23 @@ def top_k(self, value): except ImportError: pass -try: - from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( - Qwen3OmniMoeTalkerTextSparseMoeBlock, - Qwen3OmniMoeThinkerTextSparseMoeBlock, - ) - - if Qwen3OmniMoeTalkerTextSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register( - {Qwen3OmniMoeTalkerTextSparseMoeBlock: "hf.Qwen3OmniMoeTalkerTextSparseMoeBlock"} - )(_QuantSparseMoe) - if Qwen3OmniMoeThinkerTextSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register( - {Qwen3OmniMoeThinkerTextSparseMoeBlock: "hf.Qwen3OmniMoeThinkerTextSparseMoeBlock"} - )(_QuantSparseMoe) -except ImportError: - pass +# Uncomment to forward tokens to all MoE experts for full calibration. 
+# try: +# from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( +# Qwen3OmniMoeTalkerTextSparseMoeBlock, +# Qwen3OmniMoeThinkerTextSparseMoeBlock, +# ) + +# if Qwen3OmniMoeTalkerTextSparseMoeBlock not in QuantModuleRegistry: +# QuantModuleRegistry.register( +# {Qwen3OmniMoeTalkerTextSparseMoeBlock: "hf.Qwen3OmniMoeTalkerTextSparseMoeBlock"} +# )(_QuantSparseMoe) +# if Qwen3OmniMoeThinkerTextSparseMoeBlock not in QuantModuleRegistry: +# QuantModuleRegistry.register( +# {Qwen3OmniMoeThinkerTextSparseMoeBlock: "hf.Qwen3OmniMoeThinkerTextSparseMoeBlock"} +# )(_QuantSparseMoe) +# except ImportError: +# pass class _QuantGptOssExperts(_QuantFunctionalMixin): From 0c4b38fca0399b22b1d51ada4b5369ec515e6d82 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 17 Dec 2025 08:22:08 +0000 Subject: [PATCH 10/10] Update logic to disable quantizers Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 15 +++++++++++++++ examples/llm_ptq/hf_ptq.py | 23 +++++++++++++---------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index dd1958ac9..3cf5c840c 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -174,6 +174,21 @@ def build_quant_cfg( quant_cfg["quant_cfg"]["*image*"] = {"enable": False} quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + # Qwen3 specific quantizer disabling patterns (thinker.model.layers only) + if "qkv_disabled" in qformat: + quant_cfg = copy.deepcopy(quant_cfg) # Don't modify global config + for proj in ["q_proj", "k_proj", "v_proj"]: + quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = { + "enable": False + } + if "qkvo_disabled" in qformat: + if "qkv_disabled" not in qformat: # Avoid double deepcopy + quant_cfg = copy.deepcopy(quant_cfg) + for proj in ["o_proj"]: + quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = { + "enable": False + } + return quant_cfg diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 0dd7d9ac3..6c3a7dcb0 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import contextlib import os import random import time @@ -298,12 +299,8 @@ def main(args): use_seq_device_map=args.use_seq_device_map, attn_implementation=args.attn_implementation, ) - else: - assert args.qformat in QUANT_CFG_CHOICES, ( - f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}" - ) - quant_cfg = QUANT_CFG_CHOICES[args.qformat] + quant_cfg = QUANT_CFG_CHOICES[args.qformat] # Qwen3 specific quantizer disabling patterns (thinker.model.layers only) if "qkv_disabled" in args.qformat: # Disable q_proj, k_proj, v_proj quantizers @@ -325,6 +322,11 @@ def main(args): quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} for i in range(total_layers - n_layers_to_disable, total_layers): quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} + else: + assert args.qformat in QUANT_CFG_CHOICES, ( + f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}" + ) + quant_cfg = QUANT_CFG_CHOICES[args.qformat] if args.kv_cache_qformat != "none": quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( @@ -357,6 +359,8 @@ def main(args): # since parameters are distributed. 
Force cuda:0 for input tensors. if device is None or str(device) in ("meta", "cpu"): device = "cuda" + print(f"Overriding device to {device}") + processor = None tokenizer = None @@ -646,11 +650,6 @@ def main(args): print("Updating full_model with quantized language_model...") language_model_lineage[-2].language_model = model - # if args.verbose: - # mtq.print_quant_summary(full_model) - - import contextlib - if args.verbose: with open("./quant_summary.txt", "w") as f, contextlib.redirect_stdout(f): mtq.print_quant_summary(full_model) @@ -746,6 +745,10 @@ def output_decode(generated_ids, input_shape): assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton" print(f"qformat: {args.qformat}. No quantization applied, export {device} model") + if model_type == "qwen3omni": + print("Export of Qwen3Omni model is not supported yet") + return + with torch.inference_mode(): if model_type is None: print(f"Unknown model type {type(model).__name__}. Continue exporting...")
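
For reference, the generate() return-value handling that several hunks repeat can be exercised on its own. A hedged smoke-test sketch, assuming Qwen3OmniMoeForConditionalGeneration is importable from transformers under that name (the patch only records the class name in model_utils.py) and using the checkpoint id from the PR title:

    import torch
    from transformers import AutoProcessor, Qwen3OmniMoeForConditionalGeneration

    ckpt = "Qwen/Qwen3-Omni-30B-A3B-Thinking"
    model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
        ckpt, torch_dtype=torch.bfloat16, device_map="auto"
    )
    model.disable_talker()  # same call hf_ptq.py gates behind DISABLE_TALKER=1
    processor = AutoProcessor.from_pretrained(ckpt, padding_side="left")

    inputs = processor(text="Describe this model in one sentence.", return_tensors="pt").to("cuda")
    result = model.generate(**inputs, max_new_tokens=32)

    # Qwen3-Omni returns a (text_ids, audio) tuple; text_ids may expose .sequences
    text_ids = result[0] if isinstance(result, tuple) else result
    ids = text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
    print(processor.tokenizer.batch_decode(ids, skip_special_tokens=True))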