From ef7c30435c107072045091100a6c877f10d28e7d Mon Sep 17 00:00:00 2001 From: realAsma Date: Wed, 4 Mar 2026 14:20:54 +0000 Subject: [PATCH 1/3] Refactor _QuantSparseMoe: config-driven token counting, NemotronH detection - Accept n_routed_experts alongside num_experts in _is_sparse_moe_block - Add layer_sync_moe_local_experts_amax to _QuantSparseMoe - Make token counting and force-all-token calibration config-driven (moe_count_expert_calib_tokens, moe_calib_experts_ratio) with lazy init; forward is zero-overhead pass-through when both are disabled Signed-off-by: realAsma Made-with: Cursor Deduplicate layer_sync_moe_local_experts_amax into shared sync_moe_experts_input_amax Signed-off-by: realAsma Made-with: Cursor --- modelopt/torch/quantization/config.py | 15 +- modelopt/torch/quantization/mode.py | 6 + .../nn/modules/tensor_quantizer.py | 2 +- .../torch/quantization/plugins/huggingface.py | 77 +++++---- .../torch/quantization/plugins/megatron.py | 28 +--- modelopt/torch/quantization/utils.py | 42 +++++ .../quantization/plugins/test_sparse_moe.py | 151 +++++++++--------- 7 files changed, 184 insertions(+), 137 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 80a2a6876..17614499d 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -416,7 +416,7 @@ }, **_default_disabled_quantizer_cfg, }, - "algorithm": "max", + "algorithm": {"method": "max", "kv_scales": "constant", "moe_calib_experts_ratio": 0.5}, } NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = { @@ -1126,7 +1126,18 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): description=( "If specified, we force forward tokens to % of experts during the calibration" " pass. This forward is for calibration purpose only and will not affect the" - " actual inference." + " actual inference. Not supported for all MoE architectures; currently works" + " with a few HuggingFace models such as Mixtral, Qwen3Moe, MiniMax." + ), + ) + + moe_count_expert_calib_tokens: bool = ModeloptField( + default=False, + title="Enable expert token counting during MoE calibration.", + description=( + "If True, counts how many tokens are routed to each expert during calibration." + " Not supported for all MoE architectures; currently works with a few HuggingFace" + " models such as Mixtral, Qwen3Moe, MiniMax." ), ) diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index e08efece9..1fbe65406 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -236,6 +236,12 @@ def wrapped_calib_func( if hasattr(module, "_moe_calib_experts_ratio"): module._moe_calib_experts_ratio = moe_calib_experts_ratio + moe_count_expert_calib_tokens = kwargs.pop("moe_count_expert_calib_tokens", False) + if moe_count_expert_calib_tokens: + for module in model.modules(): + if hasattr(module, "_moe_count_expert_calib_tokens"): + module._moe_count_expert_calib_tokens = True + if func is not None: if sequential: if forward_loop is None: diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 2caec2565..7db479f76 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1338,7 +1338,7 @@ class SequentialQuantizer(nn.Sequential): """ - _delegated_properties = ["fake_quant", "is_enabled"] + _delegated_properties = ["fake_quant", "is_enabled", "amax"] _delegated_methods = [ "reset_amax", "disable", diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 7a06f2214..fcfae0cfe 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -56,7 +56,7 @@ else: weight_dequant = None -from ..utils import replace_function +from ..utils import replace_function, sync_moe_experts_input_amax from .attention import register_attention_for_kv_quant from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin @@ -440,16 +440,24 @@ def backward(ctx, grad_output): class _QuantSparseMoe(QuantModule): - """Module to support special handling of token dispatching during calibration. + """Quantization wrapper for HuggingFace sparse MoE blocks. - During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate. - However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance - returns. + Supports ``layer_sync_moe_local_experts_amax`` to sync input quantizer amax across experts. - If calibration is not enabled, this module behaves as a normal MoELayer. + Optionally supports two config-driven features (disabled by default): + - ``_moe_calib_experts_ratio``: force-forward tokens to more experts during calibration. + - ``_moe_count_expert_calib_tokens``: count tokens routed to each expert during calibration. + + When both are disabled, forward is a direct pass-through with zero overhead. """ def _setup(self): + self._moe_calib_experts_ratio = None + self._moe_count_expert_calib_tokens = False + + def _init_token_counting(self): + """Lazy-init token counting infra (buffer + gate hook). Called once from forward.""" + self._token_counting_initialized = True num_experts = 0 if hasattr(self, "gate") and hasattr(self.gate, "num_experts"): num_experts = self.gate.num_experts @@ -458,14 +466,6 @@ def _setup(self): elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"): num_experts = self.experts.num_experts - self.register_buffer( - "expert_token_count", - torch.zeros(num_experts, dtype=torch.long, device=next(self.parameters()).device), - persistent=False, - ) - self._count_expert_tokens = False - self._moe_calib_experts_ratio = None - if num_experts == 0: warnings.warn( f"{self.__class__.__name__}: could not resolve num_experts; " @@ -473,6 +473,12 @@ def _setup(self): ) return + self.register_buffer( + "expert_token_count", + torch.zeros(num_experts, dtype=torch.long, device=next(self.parameters()).device), + persistent=False, + ) + self._count_expert_tokens = False if hasattr(self, "gate"): self.gate.register_forward_hook(self._gate_forward_hook) @@ -492,17 +498,20 @@ def _gate_forward_hook(self, module, input, output): self.expert_token_count += counts.to(self.expert_token_count.device) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if not self._moe_calib_experts_ratio and not self._moe_count_expert_calib_tokens: + return super().forward(hidden_states) + + if self._moe_count_expert_calib_tokens and not hasattr(self, "_token_counting_initialized"): + self._init_token_counting() + is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules()) - self._count_expert_tokens = is_calib + self._count_expert_tokens = is_calib and self._moe_count_expert_calib_tokens + if is_calib and self._moe_calib_experts_ratio: self._count_expert_tokens = True assert 0 < self._moe_calib_experts_ratio <= 1, ( "moe_calib_experts_ratio must be between 0 and 1" ) - # If any of the experts are in calibration mode, we will forward all tokens to - # self._moe_calib_experts_ratio % of the experts to improve the calibration coverage. - # This is used only for calibration, we need to re-calculate the actual outputs again using - # the original top_k if TRANSFORMERS_VERSION_GE_5_0: assert hasattr(self, "gate") and hasattr(self.gate, "top_k") original_top_k = self.gate.top_k @@ -512,7 +521,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: super().forward(hidden_states) self.gate.top_k = original_top_k else: - # Path for transformers < 5.0 original_top_k = self.top_k if hasattr(self, "num_experts"): self.top_k = max( @@ -528,12 +536,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: super().forward(hidden_states) self.top_k = original_top_k self._count_expert_tokens = False - else: - self._count_expert_tokens = True + output = super().forward(hidden_states) self._count_expert_tokens = False return output + def layer_sync_moe_local_experts_amax(self): + """Sync input_quantizer amax across experts so all share the same amax per quantizer.""" + sync_moe_experts_input_amax(self.experts) + class _QuantLlama4TextExperts(QuantModule): def _setup(self): @@ -1110,17 +1121,21 @@ def register_falcon_linears_on_the_fly(model): QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear) +def _has_num_experts(obj): + # n_routed_experts: NemotronH-style MoE + return hasattr(obj, "num_experts") or hasattr(obj, "n_routed_experts") + + def _is_sparse_moe_block(module): """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe. - All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.) - share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes - (``top_k`` and ``num_experts``), and an ``experts`` sub-module. + All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, + NemotronH, etc.) share a common structural pattern: a ``gate`` (TopKRouter) sub-module with + routing attributes (``top_k`` and ``num_experts`` or ``n_routed_experts``), and an ``experts`` + sub-module. This function detects that pattern instead of relying on class names, making it forward-compatible - with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but - use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom - ``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives. + with new MoE architectures. """ if not hasattr(module, "experts"): return False @@ -1128,13 +1143,11 @@ def _is_sparse_moe_block(module): # Primary: gate sub-module has topk/top_k + num_experts (standard TopKRouter pattern) if hasattr(module, "gate"): gate = module.gate - has_topk = hasattr(gate, "top_k") - has_num_experts = hasattr(gate, "num_experts") - if has_topk and has_num_experts: + if hasattr(gate, "top_k") and _has_num_experts(gate): return True # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next) - return hasattr(module, "top_k") and hasattr(module, "num_experts") + return hasattr(module, "top_k") and _has_num_experts(module) def register_sparse_moe_on_the_fly(model): diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index e84735ae9..03a6cc190 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -41,6 +41,7 @@ from modelopt.torch.utils.distributed import ParallelState from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer +from ..utils import sync_moe_experts_input_amax from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear @@ -577,7 +578,7 @@ def _setup(self): def layer_sync_moe_local_experts_amax(self): """Sync input quantizer amax across local experts in a SequentialMLP. - Ensures all experts have the same input quantizer amax.This function operates + Ensures all experts have the same input quantizer amax. This function operates on a single rank and does not require distributed sync. Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate(). @@ -586,32 +587,11 @@ def layer_sync_moe_local_experts_amax(self): Note: Because there are logic which calls collective communication based on whether amax is not None, - We need to guarantee that all experts must have amax. Otherwise, there will be deadlock + we need to guarantee that all experts must have amax. Otherwise, there will be deadlock when synchronizing over EP since some ranks may have amax None and not calling the collective communication. """ - # Collect amax from all local experts - amax_dict = {} - for expert in self.local_experts: - for name, module in expert.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and module.amax is not None - and "input_quantizer" in name - ): - stored_amax = amax_dict.get(name) - amax_tensor = module.amax.detach().clone() - amax_dict[name] = ( - amax_tensor - if stored_amax is None - else torch.maximum(stored_amax, amax_tensor) - ) - - # Apply synchronized amax values back to all local experts - for expert in self.local_experts: - for name, module in expert.named_modules(): - if isinstance(module, TensorQuantizer) and name in amax_dict: - module.amax = amax_dict[name].detach().clone() + sync_moe_experts_input_amax(self.local_experts) def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Override the default to enable singleton_local_shards. diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index fdd8c692d..2e6df4bcf 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -535,6 +535,48 @@ def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict): module.load_state_dict(quantizer_state_dict[key]) +def sync_moe_experts_input_amax(experts): + """Sync input_quantizer amax across MoE experts and fix missing weight amax. + + 1. Takes the element-wise max of each ``input_quantizer`` amax across all experts + and writes it back, so every expert shares the same input amax. + 2. For any ``weight_quantizer`` that is enabled but has ``amax is None`` (expert + received no tokens during calibration), runs a weight-only ``max_calibrate`` + to populate the missing amax. + """ + from .nn import TensorQuantizer + + amax_dict: dict[str, torch.Tensor] = {} + for expert in experts: + for name, module in expert.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and module.amax is not None + and "input_quantizer" in name + ): + stored_amax = amax_dict.get(name) + amax_tensor = module.amax.detach().clone() + amax_dict[name] = ( + amax_tensor + if stored_amax is None + else torch.maximum(stored_amax, amax_tensor) + ) + + for expert in experts: + for name, module in expert.named_modules(): + if isinstance(module, TensorQuantizer) and name in amax_dict: + module.amax = amax_dict[name].detach().clone() + + from .model_calib import max_calibrate + + for expert in experts: + for name, module in expert.named_modules(): + if name.endswith("weight_quantizer") and module.is_enabled and module.amax is None: + weight = expert.state_dict().get(name.replace("weight_quantizer", "weight")) + if weight is not None: + max_calibrate(module, lambda m, w=weight: m(w), distributed_sync=False) + + @contextmanager def patch_fsdp_mp_dtypes(): """Patch FSDP2 to handle mixed dtypes properly during quantization. diff --git a/tests/unit/torch/quantization/plugins/test_sparse_moe.py b/tests/unit/torch/quantization/plugins/test_sparse_moe.py index 6d548aa40..cf9d724d4 100644 --- a/tests/unit/torch/quantization/plugins/test_sparse_moe.py +++ b/tests/unit/torch/quantization/plugins/test_sparse_moe.py @@ -15,6 +15,8 @@ """Tests for _is_sparse_moe_block and _QuantSparseMoe.""" +import copy + import pytest import torch import torch.nn as nn @@ -147,15 +149,15 @@ def test_block_level_only_num_experts_returns_false(self): module.num_experts = 4 assert _is_sparse_moe_block(module) is False - def test_glm4_like_block_rejected(self): - """A module with n_routed_experts instead of num_experts should be rejected.""" + def test_n_routed_experts_accepted(self): + """A module with n_routed_experts (NemotronH-style) should be accepted.""" module = nn.Module() module.experts = nn.ModuleList([nn.Linear(8, 8)]) gate = nn.Module() gate.top_k = 2 - gate.n_routed_experts = 4 # different attr name + gate.n_routed_experts = 4 module.gate = gate - assert _is_sparse_moe_block(module) is False + assert _is_sparse_moe_block(module) is True # --------------------------------------------------------------------------- @@ -183,54 +185,29 @@ def test_register_sparse_moe_on_the_fly(self): register_sparse_moe_on_the_fly(model) assert QuantModuleRegistry.get(moe_type) is not None - def test_setup_creates_expert_token_count(self): - model = get_tiny_qwen3_moe() - moe_block = self._get_moe_block(model) - moe_type = type(moe_block) - - if QuantModuleRegistry.get(moe_type) is None: - register_sparse_moe_on_the_fly(model) - - converted = QuantModuleRegistry.convert(moe_block) - assert hasattr(converted, "expert_token_count") - if hasattr(moe_block, "gate") and hasattr(moe_block.gate, "num_experts"): - expected_num_experts = moe_block.gate.num_experts - elif hasattr(moe_block, "num_experts"): - expected_num_experts = moe_block.num_experts - elif hasattr(moe_block, "experts") and hasattr(moe_block.experts, "num_experts"): - expected_num_experts = moe_block.experts.num_experts - else: - expected_num_experts = 0 - assert converted.expert_token_count.shape == (expected_num_experts,) - assert converted.expert_token_count.dtype == torch.long - assert (converted.expert_token_count == 0).all() - - def test_setup_count_expert_tokens_default_false(self): + def test_setup_config_knobs_default(self): + """_setup should only initialize config knobs, no buffer or hook.""" model = get_tiny_qwen3_moe() moe_block = self._get_moe_block(model) - moe_type = type(moe_block) - - if QuantModuleRegistry.get(moe_type) is None: + if QuantModuleRegistry.get(type(moe_block)) is None: register_sparse_moe_on_the_fly(model) converted = QuantModuleRegistry.convert(moe_block) - assert converted._count_expert_tokens is False + assert converted._moe_calib_experts_ratio is None + assert converted._moe_count_expert_calib_tokens is False + assert not hasattr(converted, "expert_token_count") - def test_forward_no_calib_matches_original(self): - """When calibration is off, _QuantSparseMoe should produce the same output as the original.""" + def test_forward_default_config_passthrough(self): + """With default config (both features off), forward should be a direct pass-through.""" model = get_tiny_qwen3_moe() moe_block = self._get_moe_block(model) - moe_type = type(moe_block) - - if QuantModuleRegistry.get(moe_type) is None: + if QuantModuleRegistry.get(type(moe_block)) is None: register_sparse_moe_on_the_fly(model) ref_block = self._get_moe_block(get_tiny_qwen3_moe()) ref_block.load_state_dict(moe_block.state_dict()) - converted = QuantModuleRegistry.convert(moe_block) - torch.manual_seed(42) x = torch.randn(1, 4, 32) with torch.no_grad(): out_ref = ref_block(x) @@ -241,31 +218,13 @@ def test_forward_no_calib_matches_original(self): if isinstance(out_test, tuple): out_test = out_test[0] assert torch.allclose(out_ref, out_test, atol=1e-5) - - def test_forward_calib_sends_all_tokens_to_all_experts(self): - """During calibration, all experts should see tokens (expert_token_count all > 0).""" - model = get_tiny_qwen3_moe() - register_sparse_moe_on_the_fly(model) - - def calib_fn(model): - x = model.dummy_inputs["input_ids"] - model(x) - - mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_fn) - - for name, module in model.named_modules(): - if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0: - assert (module.expert_token_count > 0).all(), ( - f"Not all experts received tokens in {name}: {module.expert_token_count}" - ) + assert not hasattr(converted, "expert_token_count") def test_forward_calib_restores_top_k(self): - """After calibration forward, top_k should be restored to its original value.""" + """After calibration forward with moe_calib_experts_ratio, top_k should be restored.""" model = get_tiny_qwen3_moe() moe_block = self._get_moe_block(model) - moe_type = type(moe_block) - - if QuantModuleRegistry.get(moe_type) is None: + if QuantModuleRegistry.get(type(moe_block)) is None: register_sparse_moe_on_the_fly(model) if TRANSFORMERS_VERSION_GE_5_0: @@ -274,8 +233,9 @@ def test_forward_calib_restores_top_k(self): original_top_k = moe_block.top_k converted = QuantModuleRegistry.convert(moe_block) + converted._moe_calib_experts_ratio = 1.0 - # Simulate calibration mode: set _if_calib on a child TensorQuantizer + # Simulate calibration mode for m in converted.experts.modules(): if hasattr(m, "_if_calib"): m._if_calib = True @@ -290,21 +250,28 @@ def test_forward_calib_restores_top_k(self): else: assert converted.top_k == original_top_k - def test_gate_forward_hook_counts_tokens(self): - """Verify the gate forward hook correctly counts expert token assignments.""" + def test_token_counting_lazy_init(self): + """When moe_count_expert_calib_tokens is enabled, token counting infra is lazy-inited.""" model = get_tiny_qwen3_moe() moe_block = self._get_moe_block(model) - moe_type = type(moe_block) - - if QuantModuleRegistry.get(moe_type) is None: + if QuantModuleRegistry.get(type(moe_block)) is None: register_sparse_moe_on_the_fly(model) converted = QuantModuleRegistry.convert(moe_block) + converted._moe_count_expert_calib_tokens = True - # Reset counts and enable counting - converted.expert_token_count.zero_() - converted._count_expert_tokens = True + assert not hasattr(converted, "expert_token_count") + + x = torch.randn(1, 4, 32) + with torch.no_grad(): + converted(x) + + # Buffer and hook should now exist + assert hasattr(converted, "expert_token_count") + assert converted.expert_token_count.numel() > 0 + # Manually enable counting and call gate to verify hook works + converted._count_expert_tokens = True if TRANSFORMERS_VERSION_GE_5_0: hidden_size = converted.gate.weight.shape[1] top_k = converted.gate.top_k @@ -312,15 +279,43 @@ def test_gate_forward_hook_counts_tokens(self): hidden_size = converted.gate.in_features top_k = converted.top_k if hasattr(converted, "top_k") else converted.gate.top_k - x = torch.randn(8, hidden_size) + converted.expert_token_count.zero_() + tokens = torch.randn(8, hidden_size) with torch.no_grad(): - converted.gate(x) - total_assigned = converted.expert_token_count.sum().item() - assert total_assigned == 8 * top_k + converted.gate(tokens) + assert converted.expert_token_count.sum().item() == 8 * top_k - # Disable counting and verify counts don't change - converted._count_expert_tokens = False - prev_counts = converted.expert_token_count.clone() - with torch.no_grad(): - converted.gate(x) - assert torch.equal(converted.expert_token_count, prev_counts) + +def test_qwen3_moe_quantize_with_token_forcing_and_counting(): + """End-to-end: mtq.quantize a Qwen3MoE with INT8 + moe_calib_experts_ratio + token counting.""" + model = get_tiny_qwen3_moe() + + # Verify detection + moe_found = any(_is_sparse_moe_block(m) for m in model.modules()) + assert moe_found, "Qwen3MoE should be detected as a sparse MoE block" + + quant_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG) + quant_cfg["algorithm"] = { + "method": "max", + "moe_calib_experts_ratio": 1.0, + "moe_count_expert_calib_tokens": True, + } + + def calib_fn(model): + x = model.dummy_inputs["input_ids"] + for _ in range(2): + model(x) + + mtq.quantize(model, quant_cfg, calib_fn) + + # Verify token counting worked + for name, module in model.named_modules(): + if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0: + assert (module.expert_token_count > 0).all(), ( + f"Not all experts received tokens in {name}: {module.expert_token_count}" + ) + + # Verify model still runs + with torch.no_grad(): + out = model(model.dummy_inputs["input_ids"]) + assert out.logits is not None From 9fae2613d8a7bc945e9dd9d929093d670fe6e0cb Mon Sep 17 00:00:00 2001 From: realAsma Date: Fri, 6 Mar 2026 19:23:29 +0000 Subject: [PATCH 2/3] Revert NVFP4_DEFAULT_CFG algorithm to plain "max" Signed-off-by: realAsma Made-with: Cursor --- modelopt/torch/quantization/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 17614499d..6b4d14f76 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -416,7 +416,7 @@ }, **_default_disabled_quantizer_cfg, }, - "algorithm": {"method": "max", "kv_scales": "constant", "moe_calib_experts_ratio": 0.5}, + "algorithm": "max", } NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = { From bc965bd2c11f3133dd9ec680c89356a93d5b0899 Mon Sep 17 00:00:00 2001 From: realAsma Date: Fri, 6 Mar 2026 21:28:03 +0000 Subject: [PATCH 3/3] Restore deleted pre-existing comments in _QuantSparseMoe Signed-off-by: realAsma Made-with: Cursor --- modelopt/torch/quantization/plugins/huggingface.py | 5 +++++ modelopt/torch/quantization/plugins/megatron.py | 2 +- modelopt/torch/quantization/utils.py | 4 +--- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index fcfae0cfe..96d89ff86 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -507,6 +507,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules()) self._count_expert_tokens = is_calib and self._moe_count_expert_calib_tokens + # If any of the experts are in calibration mode, we will forward all tokens to + # self._moe_calib_experts_ratio % of the experts to improve the calibration coverage. + # This is used only for calibration, we need to re-calculate the actual outputs again using + # the original top_k if is_calib and self._moe_calib_experts_ratio: self._count_expert_tokens = True assert 0 < self._moe_calib_experts_ratio <= 1, ( @@ -521,6 +525,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: super().forward(hidden_states) self.gate.top_k = original_top_k else: + # Path for transformers < 5.0 original_top_k = self.top_k if hasattr(self, "num_experts"): self.top_k = max( diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 03a6cc190..8275f9c2a 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -41,9 +41,9 @@ from modelopt.torch.utils.distributed import ParallelState from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer -from ..utils import sync_moe_experts_input_amax from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper +from ..utils import sync_moe_experts_input_amax from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear try: diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 2e6df4bcf..39773dfca 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -557,9 +557,7 @@ def sync_moe_experts_input_amax(experts): stored_amax = amax_dict.get(name) amax_tensor = module.amax.detach().clone() amax_dict[name] = ( - amax_tensor - if stored_amax is None - else torch.maximum(stored_amax, amax_tensor) + amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor) ) for expert in experts: