diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py
index 3bd72d9de91..aa81f8213f7 100755
--- a/modelopt/torch/export/model_utils.py
+++ b/modelopt/torch/export/model_utils.py
@@ -107,7 +107,7 @@ def is_multimodal_model(model):
     config = model.config
 
     # Check for Nemotron-Parse encoder-decoder architecture
-    architectures = getattr(config, "architectures", [])
+    architectures = getattr(config, "architectures", []) or []
     is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
 
     return (
diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py
index d5bab9b4ece..ec262bcf094 100644
--- a/modelopt/torch/export/plugins/mcore_common.py
+++ b/modelopt/torch/export/plugins/mcore_common.py
@@ -52,6 +52,7 @@
     "LlamaForCausalLMEagle3Deep": eagle3_deep_llama_causal_lm_export,
     "Qwen3ForCausalLM": qwen3_causal_lm_export,
     "Qwen3MoeForCausalLM": qwen3_causal_lm_export,
+    "Qwen3_5MoeForConditionalGeneration": qwen3_causal_lm_export,
     "Qwen2ForCausalLM": qwen25_causal_lm_export,
     "GptOssForCausalLM": gptoss_causal_lm_export,
 }
@@ -64,6 +65,7 @@
     "NemotronHForCausalLM": nemotron_h_causal_lm_import,
     "Qwen3ForCausalLM": qwen3_causal_lm_import,
     "Qwen3MoeForCausalLM": qwen3_causal_lm_import,
+    "Qwen3_5MoeForConditionalGeneration": qwen3_causal_lm_import,
     "Qwen2ForCausalLM": qwen25_causal_lm_import,
     "GptOssForCausalLM": gptoss_causal_lm_import,
 }
diff --git a/modelopt/torch/export/plugins/mcore_qwen.py b/modelopt/torch/export/plugins/mcore_qwen.py
index 5c4ae0647d8..c26275e7517 100644
--- a/modelopt/torch/export/plugins/mcore_qwen.py
+++ b/modelopt/torch/export/plugins/mcore_qwen.py
@@ -24,6 +24,7 @@
     CustomModuleMapping,
     GatedMLPMerging,
     GatedMLPSlicing,
+    GroupedMLPSlicing,
     NameRemapping,
     QKVMerging,
     QKVSlicing,
@@ -68,6 +69,16 @@
     "router": NameRemapping("model.layers.{}.mlp.gate."),
     "local_experts.linear_fc1": GatedMLPSlicing("model.layers.{}.mlp.experts.{}."),
     "local_experts.linear_fc2": NameRemapping("model.layers.{}.mlp.experts.{}.down_proj."),
+    # Grouped experts (TEGroupedMLP: fused per-expert weights via grouped GEMM)
+    "experts.linear_fc1": GroupedMLPSlicing("model.layers.{}.mlp.experts.{}.up_proj"),
+    "experts.linear_fc2": GroupedMLPSlicing("model.layers.{}.mlp.experts.{}.down_proj"),
+    # Shared experts (Qwen3.5 MoE)
+    "shared_experts.linear_fc1": GatedMLPSlicing("model.layers.{}.mlp.shared_experts."),
+    "shared_experts.linear_fc2": NameRemapping("model.layers.{}.mlp.shared_experts.down_proj."),
+    # GatedDeltaNet (linear attention) — no QKV slicing, direct name remap
+    "gated_delta_net_in_proj": NameRemapping("model.layers.{}.linear_attn.in_proj."),
+    "gated_delta_net_out_norm": NameRemapping("model.layers.{}.linear_attn.out_norm."),
+    "gated_delta_net_out_proj": NameRemapping("model.layers.{}.linear_attn.out_proj."),
 }
 
 qwen25_causal_lm_import: dict[str, CustomModuleMapping] = {
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 4ceb51cd2c0..cc55545699d 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -665,6 +665,36 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
                 f"Unsupported quantizer with num_bits: {weight_quantizer.num_bits}"
             )
 
+    # Handle _QuantFusedExperts modules (e.g. Qwen3.x MoE) which use plural
+    # ModuleList quantizers (gate_up_proj_weight_quantizers, down_proj_weight_quantizers)
+    # instead of singular weight_quantizer attributes.
+    # The quantization format is determined at module setup time, not per-expert.
+    # Check any quantizer in the list (even disabled ones) to determine the format,
+    # since calibration may not have activated all experts.
+    for quantizer_list_name in ["gate_up_proj_weight_quantizers", "down_proj_weight_quantizers"]:
+        quantizer_list = getattr(module, quantizer_list_name, None)
+        if quantizer_list is not None and len(quantizer_list) > 0:
+            # Check any quantizer — enabled or not — for format config.
+            # Prefer enabled ones first, but fall back to any if none are enabled.
+            q = None
+            for candidate in quantizer_list:
+                if hasattr(candidate, "is_enabled") and candidate.is_enabled:
+                    q = candidate
+                    break
+            if q is None:
+                q = quantizer_list[0]
+
+            num_bits = getattr(q, "num_bits", None)
+            block_sizes = getattr(q, "block_sizes", None)
+            scale_bits = (
+                block_sizes.get("scale_bits", (8, 0))
+                if isinstance(block_sizes, dict) and "scale_bits" in block_sizes
+                else (8, 0)
+            )
+            if num_bits == (2, 1) and scale_bits == (4, 3):
+                return QUANTIZATION_NVFP4
+            # Add other expert quantization format checks here as needed
+
     for weight_name in weight_attr_names(module):
         quantization = _get_quantization_from_layer(module, quantizer_attr_names(weight_name))
         if quantization != QUANTIZATION_NONE:
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index af936a3002a..c0145d16eaa 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -658,6 +658,16 @@ def _process_quantized_modules(
                 raise AssertionError(
                     f"Failed to export module '{name}' (type={type(sub_module).__name__}): {e}"
                 ) from e
+        elif hasattr(sub_module, "gate_up_proj_weight_quantizers"):
+            # Generic fused MoE experts (_QuantFusedExperts) with per-expert
+            # quantizer ModuleLists. Split into per-expert modules and export.
+            # NOTE: This check must come before type-name checks (e.g. Llama4,
+            # GptOss) because _QuantFusedExperts wrapping renames quantizers
+            # to plural ModuleLists (e.g. gate_up_proj_weight_quantizers).
+            from modelopt.torch.export.moe_utils import _export_fused_experts
+
+            with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+                _export_fused_experts(sub_module, dtype)
         elif (
             "Llama4TextExperts" in type(sub_module).__name__
             or "GptOssExperts" in type(sub_module).__name__
@@ -677,13 +687,6 @@ def _process_quantized_modules(
                 with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                     for weight_name in ["gate_up_proj", "down_proj"]:
                         _export_quantized_weight(sub_module, dtype, weight_name)
-        elif hasattr(sub_module, "gate_up_proj_weight_quantizers"):
-            # Generic fused MoE experts (_QuantFusedExperts) with per-expert
-            # quantizer ModuleLists. Split into per-expert modules and export.
-            from modelopt.torch.export.moe_utils import _export_fused_experts
-
-            with fsdp2_aware_weight_update(model, sub_module, reshard=False):
-                _export_fused_experts(sub_module, dtype)
 
 
 def _export_transformers_checkpoint(
diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py
index 62053e549c8..3b997932010 100644
--- a/modelopt/torch/export/unified_export_megatron.py
+++ b/modelopt/torch/export/unified_export_megatron.py
@@ -1,3 +1,4 @@
+import re
 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -73,6 +74,8 @@
     from megatron.core.models.gpt import GPTModel
     from megatron.core.models.mamba import MambaModel
     from megatron.core.models.multimodal.llava_model import LLaVAModel
+    from megatron.core.models.hybrid.hybrid_model import HybridModel
+    from megatron.bridge.models.qwen_vl import Qwen3VLModel
     from megatron.core.parallel_state import (
         get_pipeline_model_parallel_rank,
         get_pipeline_model_parallel_world_size,
@@ -121,7 +124,7 @@ def __init__(
         moe_router_dtype: str | None = None,
     ):
         """Create a GPTModel exporter instance."""
-        if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)):
+        if not isinstance(model, (GPTModel, MambaModel, HybridModel, LLaVAModel, Qwen3VLModel)):
             raise ValueError("Input to GPTModelExport must be a megatron.core.models.GPTModel!")
 
         self._state_dict = OrderedDict()
@@ -269,7 +272,9 @@ def save_pretrained(
         is_last_stage_main_rank = pp_rank == pp_size - 1 and tp_rank == 0
 
         # Main export process
+        print("[export] About to build layer_state_dicts...", flush=True)
         layer_state_dicts = self.layer_state_dicts
+        print(f"[export] Built {len(layer_state_dicts)} layer state dicts", flush=True)
         quantization_format = self._get_quantization_format(self.model)
         quantization = None
 
@@ -363,14 +368,63 @@
         with open(config_json_file, "w") as f:
             json.dump(config_dict, f, indent=4)
 
-        # save_safetensors(state_dict, save_directory)
+        # Each EP rank writes to its own subdirectory to avoid OOM from gathering
+        if torch.distributed.is_initialized():
+            world_size = torch.distributed.get_world_size()
+            rank = torch.distributed.get_rank()
+        else:
+            world_size = 1
+            rank = 0
+
+        rank_save_dir = save_directory + "_rank" + str(rank)
+        os.makedirs(rank_save_dir, exist_ok=True)
+
+        # Each rank writes its own layer shards
         save_safetensors_by_layer_index(
             layer_state_dicts=layer_state_dicts,
             total_layers=self.model.config.num_layers,
-            save_directory=save_directory,
+            save_directory=rank_save_dir,
             name_template="model-{:05d}-of-{:05d}",
         )
 
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
+
+        # Rank 0 merges per-shard safetensors from all rank dirs
+        if rank == 0 and world_size > 1:
+            print("[export] Merging shard files from all ranks...", flush=True)
+            from safetensors import safe_open as _safe_open
+            from safetensors.torch import save_file as _save_file
+            for layer_idx in range(self.model.config.num_layers):
+                shard_name = "model-{:05d}-of-{:05d}".format(layer_idx + 1, self.model.config.num_layers)
+                ckpt_name = shard_name + ".safetensors"
+                meta_name = shard_name + ".json"
+                merged_dict = {}
+                for r in range(world_size):
+                    rdir = save_directory + "_rank" + str(r)
+                    fpath = os.path.join(rdir, ckpt_name)
+                    if os.path.exists(fpath):
+                        with _safe_open(fpath, framework="pt") as f:
+                            for k in f.keys():
+                                merged_dict[k] = f.get_tensor(k)
+                # Write merged shard
+                os.makedirs(save_directory, exist_ok=True)
+                _save_file(merged_dict, os.path.join(save_directory, ckpt_name), metadata={"format": "pt"})
+                # Build metadata
+                weight_map = {}
+                total_size = 0
+                for k, v in merged_dict.items():
+                    weight_map[k] = ckpt_name
+                    total_size += v.numel() * v.element_size()
+                with open(os.path.join(save_directory, meta_name), "w") as f:
+                    json.dump({"metadata": {"total_size": total_size}, "weight_map": weight_map}, f, indent=4)
+            print(f"[export] Merged {len(merged_dict)} keys per layer across {world_size} ranks", flush=True)
+        elif rank == 0:
+            # Single rank, just rename dir
+            import shutil
+            if os.path.exists(save_directory + "_rank0"):
+                shutil.move(save_directory + "_rank0", save_directory)
+
     @property
     def state_dict(self):
         """Return the real quantized state_dict of the base model."""
@@ -392,13 +446,17 @@ def extra_state_dict(self):
         return self._state_dict
 
     def _get_state_dict(self):
+        print("[export] _get_state_dict called", flush=True)
         model = self.model
+        import time as _time
+        _start = _time.time()
 
         # Embedding
         if hasattr(model, "embedding"):
             self.rules["word_embeddings"](model.embedding.word_embeddings)
 
         # Decoder layers
+        print(f"[export] Iterating {len(model.decoder.layers)} decoder layers", flush=True)
         for layer in model.decoder.layers:
             layer_id = layer.layer_number - 1
             if isinstance(layer, MambaLayer):
@@ -434,6 +492,19 @@ def _get_fused_norm_weight(self, module):
         return getattr(module, "layer_norm_weight", None)
 
     def _get_transformer_layer_state_dict(self, layer, layer_id):
+        if layer_id == 0:
+            print(f"[diag] layer.mlp type: {type(layer.mlp).__name__}", flush=True)
+            print(f"[diag] mlp attrs: {[a for a in dir(layer.mlp) if not a.startswith('_')][:25]}", flush=True)
+            print(f"[diag] hasattr mlp.experts: {hasattr(layer.mlp, 'experts')}", flush=True)
+            if hasattr(layer.mlp, 'experts'):
+                print(f"[diag] experts type: {type(layer.mlp.experts).__name__}", flush=True)
+                print(f"[diag] hasattr local_experts: {hasattr(layer.mlp.experts, 'local_experts')}", flush=True)
+                if hasattr(layer.mlp.experts, 'local_experts'):
+                    print(f"[diag] num local_experts: {len(layer.mlp.experts.local_experts)}", flush=True)
+            print(f"[diag] hasattr shared_experts: {hasattr(layer.mlp, 'shared_experts')}", flush=True)
+            if hasattr(layer.mlp, 'config'):
+                print(f"[diag] mlp.config.num_experts: {getattr(layer.mlp.config, 'num_experts', 'N/A')}", flush=True)
+
         if not isinstance(layer.input_layernorm, IdentityOp):
             self.rules["input_layernorm"](layer.input_layernorm, layer_id)
         elif (
@@ -460,8 +531,23 @@
             self.rules["linear_kv_layernorm"](layer.self_attention.kv_layernorm, layer_id)
             self.rules["linear_kv_up_proj"](layer.self_attention.linear_kv_up_proj, layer_id)
             self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id)
+        elif "GatedDeltaNet" in str(type(layer.self_attention)):
+            # GatedDeltaNet (linear attention) has in_proj, out_norm, out_proj
+            # instead of linear_qkv, q_layernorm, etc.
+            # Use dedicated GDN rules if available (no QKV slicing), else skip.
+            if "gated_delta_net_in_proj" in self.rules:
+                self.rules["gated_delta_net_in_proj"](layer.self_attention.in_proj, layer_id)
+            if hasattr(layer.self_attention, "out_norm") and not isinstance(
+                layer.self_attention.out_norm, IdentityOp
+            ):
+                if "gated_delta_net_out_norm" in self.rules:
+                    self.rules["gated_delta_net_out_norm"](layer.self_attention.out_norm, layer_id)
+            if "gated_delta_net_out_proj" in self.rules:
+                self.rules["gated_delta_net_out_proj"](layer.self_attention.out_proj, layer_id)
+            else:
+                self.rules["linear_proj"](layer.self_attention.out_proj, layer_id)
         else:
-            if layer.self_attention.q_layernorm is not None and not isinstance(
+            if hasattr(layer.self_attention, "q_layernorm") and layer.self_attention.q_layernorm is not None and not isinstance(
                 layer.self_attention.q_layernorm, (IdentityOp, L2Norm)
             ):
                 self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id)
@@ -473,7 +559,7 @@
             ):  # KV cache quant export
                 self.rules["core_attention"](layer.self_attention.core_attention, layer_id)
                 self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id)
-            if getattr(layer.self_attention.core_attention, "softmax_offset", None) is not None:
+            if hasattr(layer.self_attention, "core_attention") and getattr(layer.self_attention.core_attention, "softmax_offset", None) is not None:
                 self.rules["softmax_offset"](
                     layer.self_attention.core_attention.softmax_offset, layer_id
                 )
@@ -503,8 +589,17 @@
                     layer.mlp.shared_experts.linear_fc2, layer_id
                 )
             if hasattr(layer.mlp.experts, "local_experts"):
+                # With expert parallelism, local_experts are indexed 0..N-1 per rank,
+                # but the global expert ID needs the EP rank offset.
+                from megatron.core.parallel_state import get_expert_model_parallel_rank, get_expert_model_parallel_world_size
+                ep_rank = get_expert_model_parallel_rank()
+                ep_size = get_expert_model_parallel_world_size()
+                num_local = len(layer.mlp.experts.local_experts)
+                print(f"[export] layer {layer_id}: {num_local} local_experts, ep_rank={ep_rank}, ep_size={ep_size}", flush=True)
                 if not self.rules.get("use_packed_local_experts", False):
-                    for expert_id, expert in enumerate(layer.mlp.experts.local_experts):
+                    for local_id, expert in enumerate(layer.mlp.experts.local_experts):
+                        expert_id = ep_rank * num_local + local_id
+                        print(f"[export] expert {local_id} -> global {expert_id}, linear_fc1={type(expert.linear_fc1).__name__}", flush=True)
                         self.rules["local_experts.linear_fc1"](
                             expert.linear_fc1, layer_id, expert_id
                         )
@@ -522,10 +617,26 @@
             elif "experts.linear_fc1" in self.rules:
                 # TEGroupedMLP: experts use fused grouped GEMM with a single
                 # linear_fc1/linear_fc2 for all experts (no local_experts attribute).
-                # Uses "experts.linear_fc1" rule (GroupedMLPMerging) instead of
-                # "local_experts.linear_fc1" which expects per-expert iteration.
-                self.rules["experts.linear_fc1"](layer.mlp.experts.linear_fc1, layer_id)
-                self.rules["experts.linear_fc2"](layer.mlp.experts.linear_fc2, layer_id)
+                # Call _grouped_mlp_slicing directly because the lambda-based dispatch
+                # cannot handle two-placeholder prefixes (layer_id + expert_id).
+                raw_mappings = all_mcore_hf_export_mapping[self.arch]
+                fc1_prefix = raw_mappings["experts.linear_fc1"].target_name_or_prefix
+                fc2_prefix = raw_mappings["experts.linear_fc2"].target_name_or_prefix
+                # Fill only the first {} (layer_id), leave second {} for expert_id in _grouped_mlp_slicing
+                fc1_prefix_partial = re.sub(r'\{\}', str(layer_id), fc1_prefix, count=1)
+                fc2_prefix_partial = re.sub(r'\{\}', str(layer_id), fc2_prefix, count=1)
+                # With EP>1, each rank only has a subset of experts. Offset the expert IDs
+                # by ep_rank * num_local_experts so all ranks write to non-overlapping keys.
+                from megatron.core.parallel_state import get_expert_model_parallel_rank
+                ep_rank = get_expert_model_parallel_rank()
+                expert_offset = ep_rank * layer.mlp.experts.linear_fc1.num_gemms
+                print(f"[export] layer {layer_id}: TEGroupedMLP, ep_rank={ep_rank}, expert_offset={expert_offset}", flush=True)
+                self._grouped_mlp_slicing(
+                    layer.mlp.experts.linear_fc1, fc1_prefix_partial, expert_offset=expert_offset
+                )
+                self._grouped_mlp_slicing(
+                    layer.mlp.experts.linear_fc2, fc2_prefix_partial, expert_offset=expert_offset
+                )
             else:
                 self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id)
                 self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id)
@@ -951,7 +1062,7 @@ def _gated_mlp_slicing(
         self._state_dict[gate_proj_key] = val.detach().clone()
         self._state_dict[up_proj_key] = val.detach().clone()
 
-    def _grouped_mlp_slicing(self, module, prefix, parallel_config=None):
+    def _grouped_mlp_slicing(self, module, prefix, parallel_config=None, expert_offset=0):
         """Export TEGroupedMLP weights by splitting per-expert weights into individual HF weights.
 
         TEGroupedMLP (via TEGroupedLinear) stores weights as weight0, weight1, ..., weight{N-1}
@@ -981,9 +1092,10 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None):
         state_dict = module.state_dict()
 
-        for expert_id in range(num_experts):
+        for local_expert_id in range(num_experts):
+            expert_id = expert_offset + local_expert_id
             expert_prefix = prefix.format(expert_id) + "."
-            weight_key = f"weight{expert_id}"
+            weight_key = f"weight{local_expert_id}"
 
             if weight_key not in state_dict:
                 raise ValueError(f"Missing expected TEGroupedMLP expert weight: {weight_key}")
@@ -1008,7 +1120,8 @@
         for key, val in name_to_value.items():
             if key == "output_scale":
                 continue
-            for expert_id in range(num_experts):
+            for local_expert_id in range(num_experts):
+                expert_id = expert_offset + local_expert_id
                 expert_prefix = prefix.format(expert_id) + "."
                 self._state_dict[expert_prefix + key] = val.detach().clone()
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 186ff1c7edd..850c1eeb423 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -578,6 +578,12 @@ def _nvfp4_selective_quant_cfg(
        quant_cfg.append(
            {"quantizer_name": f"{pattern}weight_quantizer", "cfg": copy.deepcopy(quantizer)}
        )
+        # Also match plural ModuleList quantizers used by _QuantFusedExperts
+        # (e.g. gate_up_proj_weight_quantizers.N) for fused MoE architectures.
+        for suffix in ["gate_up_proj_weight_quantizers", "down_proj_weight_quantizers"]:
+            quant_cfg.append(
+                {"quantizer_name": f"{pattern}{suffix}*", "cfg": copy.deepcopy(quantizer)}
+            )
         if not weight_only:
             quant_cfg.append(
                 {"quantizer_name": f"{pattern}input_quantizer", "cfg": copy.deepcopy(quantizer)}
             )
diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
index 6ff31424c77..8c2b9cfb0d7 100644
--- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
+++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -286,6 +286,14 @@ def quantize(
                 input, block_size, weights_scaling_factor_2
             )
 
+        # Handle empty tensors (e.g. from TP/EP sharding where this rank has no slice)
+        if input.numel() == 0:
+            return (
+                cls(input_shape, input_dtype, input),
+                torch.zeros(*input.shape[:-1], device=input.device, dtype=torch.float8_e4m3fn),
+                torch.zeros(1, device=input.device, dtype=torch.float32),
+            )
+
         # Reshape the weight and scale factors
         original_shape = input.shape
         input = input.view((*tuple(input.shape[:-1]), -1, block_size))