Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,18 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
description=(
"If specified, we force forward tokens to % of experts during the calibration"
" pass. This forward is for calibration purpose only and will not affect the"
" actual inference."
" actual inference. Not supported for all MoE architectures; currently works"
" with a few HuggingFace models such as Mixtral, Qwen3Moe, MiniMax."
),
)

moe_count_expert_calib_tokens: bool = ModeloptField(
default=False,
title="Enable expert token counting during MoE calibration.",
description=(
"If True, counts how many tokens are routed to each expert during calibration."
" Not supported for all MoE architectures; currently works with a few HuggingFace"
" models such as Mixtral, Qwen3Moe, MiniMax."
),
)

Expand Down
6 changes: 6 additions & 0 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,12 @@ def wrapped_calib_func(
if hasattr(module, "_moe_calib_experts_ratio"):
module._moe_calib_experts_ratio = moe_calib_experts_ratio

moe_count_expert_calib_tokens = kwargs.pop("moe_count_expert_calib_tokens", False)
if moe_count_expert_calib_tokens:
for module in model.modules():
if hasattr(module, "_moe_count_expert_calib_tokens"):
module._moe_count_expert_calib_tokens = True

if func is not None:
if sequential:
if forward_loop is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,7 @@ class SequentialQuantizer(nn.Sequential):

"""

_delegated_properties = ["fake_quant", "is_enabled"]
_delegated_properties = ["fake_quant", "is_enabled", "amax"]
_delegated_methods = [
"reset_amax",
"disable",
Expand Down
80 changes: 49 additions & 31 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
else:
weight_dequant = None

from ..utils import replace_function
from ..utils import replace_function, sync_moe_experts_input_amax
from .attention import register_attention_for_kv_quant
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin

Expand Down Expand Up @@ -440,16 +440,24 @@ def backward(ctx, grad_output):


class _QuantSparseMoe(QuantModule):
"""Module to support special handling of token dispatching during calibration.
"""Quantization wrapper for HuggingFace sparse MoE blocks.

During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate.
However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance
returns.
Supports ``layer_sync_moe_local_experts_amax`` to sync input quantizer amax across experts.

If calibration is not enabled, this module behaves as a normal MoELayer.
Optionally supports two config-driven features (disabled by default):
- ``_moe_calib_experts_ratio``: force-forward tokens to more experts during calibration.
- ``_moe_count_expert_calib_tokens``: count tokens routed to each expert during calibration.

When both are disabled, forward is a direct pass-through with zero overhead.
"""

def _setup(self):
    """Initialize the config-driven calibration flags; both features are off by default."""
    # Flipped externally (e.g. by the calibration wrapper) when the corresponding
    # quantization-config options are set; when both stay off, forward() is a pass-through.
    self._moe_count_expert_calib_tokens = False
    self._moe_calib_experts_ratio = None

def _init_token_counting(self):
"""Lazy-init token counting infra (buffer + gate hook). Called once from forward."""
self._token_counting_initialized = True
num_experts = 0
if hasattr(self, "gate") and hasattr(self.gate, "num_experts"):
num_experts = self.gate.num_experts
Expand All @@ -458,21 +466,19 @@ def _setup(self):
elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
num_experts = self.experts.num_experts

self.register_buffer(
"expert_token_count",
torch.zeros(num_experts, dtype=torch.long, device=next(self.parameters()).device),
persistent=False,
)
self._count_expert_tokens = False
self._moe_calib_experts_ratio = None

if num_experts == 0:
warnings.warn(
f"{self.__class__.__name__}: could not resolve num_experts; "
"expert routing will not be tracked for this layer."
)
return

self.register_buffer(
"expert_token_count",
torch.zeros(num_experts, dtype=torch.long, device=next(self.parameters()).device),
persistent=False,
)
self._count_expert_tokens = False
if hasattr(self, "gate"):
self.gate.register_forward_hook(self._gate_forward_hook)

Expand All @@ -492,17 +498,24 @@ def _gate_forward_hook(self, module, input, output):
self.expert_token_count += counts.to(self.expert_token_count.device)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if not self._moe_calib_experts_ratio and not self._moe_count_expert_calib_tokens:
return super().forward(hidden_states)

if self._moe_count_expert_calib_tokens and not hasattr(self, "_token_counting_initialized"):
self._init_token_counting()

is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
self._count_expert_tokens = is_calib
self._count_expert_tokens = is_calib and self._moe_count_expert_calib_tokens

# If any of the experts are in calibration mode, we will forward all tokens to
# self._moe_calib_experts_ratio % of the experts to improve the calibration coverage.
# This is used only for calibration, we need to re-calculate the actual outputs again using
# the original top_k
if is_calib and self._moe_calib_experts_ratio:
self._count_expert_tokens = True
assert 0 < self._moe_calib_experts_ratio <= 1, (
"moe_calib_experts_ratio must be between 0 and 1"
)
# If any of the experts are in calibration mode, we will forward all tokens to
# self._moe_calib_experts_ratio % of the experts to improve the calibration coverage.
# This is used only for calibration, we need to re-calculate the actual outputs again using
# the original top_k
if TRANSFORMERS_VERSION_GE_5_0:
assert hasattr(self, "gate") and hasattr(self.gate, "top_k")
original_top_k = self.gate.top_k
Expand All @@ -528,12 +541,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
super().forward(hidden_states)
self.top_k = original_top_k
self._count_expert_tokens = False
else:
self._count_expert_tokens = True

output = super().forward(hidden_states)
self._count_expert_tokens = False
return output

def layer_sync_moe_local_experts_amax(self):
    """Make all experts share the same input_quantizer amax (synced per quantizer)."""
    expert_modules = self.experts
    sync_moe_experts_input_amax(expert_modules)


class _QuantLlama4TextExperts(QuantModule):
def _setup(self):
Expand Down Expand Up @@ -1110,31 +1126,33 @@ def register_falcon_linears_on_the_fly(model):
QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)


def _has_num_experts(obj):
# n_routed_experts: NemotronH-style MoE
return hasattr(obj, "num_experts") or hasattr(obj, "n_routed_experts")


def _is_sparse_moe_block(module):
    """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe.

    All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax,
    NemotronH, etc.) share a common structural pattern: a ``gate`` (TopKRouter) sub-module with
    routing attributes (``top_k`` and ``num_experts`` or ``n_routed_experts``), and an
    ``experts`` sub-module.

    This function detects that pattern instead of relying on class names, making it
    forward-compatible with new MoE architectures.
    """
    # Every sparse MoE block has an ``experts`` container; bail out fast otherwise.
    if not hasattr(module, "experts"):
        return False

    # Primary: gate sub-module exposes top_k + an expert-count attribute
    # (standard TopKRouter pattern).
    if hasattr(module, "gate"):
        gate = module.gate
        if hasattr(gate, "top_k") and _has_num_experts(gate):
            return True

    # Fallback: top_k + expert count on the block itself (older transformers, e.g. v4.x Qwen3Next).
    return hasattr(module, "top_k") and _has_num_experts(module)


def register_sparse_moe_on_the_fly(model):
Expand Down
28 changes: 4 additions & 24 deletions modelopt/torch/quantization/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
from ..utils import sync_moe_experts_input_amax
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear

try:
Expand Down Expand Up @@ -577,7 +578,7 @@ def _setup(self):
def layer_sync_moe_local_experts_amax(self):
    """Sync input quantizer amax across local experts in a SequentialMLP.

    Ensures all experts have the same input quantizer amax. This function operates
    on a single rank and does not require distributed sync.

    Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().

    Note:
        Because there is logic which calls collective communication based on whether amax is
        not None, we need to guarantee that all experts must have amax. Otherwise, there will
        be deadlock when synchronizing over EP since some ranks may have amax None and not
        calling the collective communication.
    """
    # Delegates to the shared helper so the Megatron and HuggingFace MoE paths stay consistent.
    sync_moe_experts_input_amax(self.local_experts)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Override the default to enable singleton_local_shards.
Expand Down
40 changes: 40 additions & 0 deletions modelopt/torch/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,46 @@ def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict):
module.load_state_dict(quantizer_state_dict[key])


def sync_moe_experts_input_amax(experts):
    """Sync input_quantizer amax across MoE experts and fix missing weight amax.

    1. Takes the element-wise max of each ``input_quantizer`` amax across all experts
       and writes it back, so every expert shares the same input amax.
    2. For any ``weight_quantizer`` that is enabled but has ``amax is None`` (expert
       received no tokens during calibration), runs a weight-only ``max_calibrate``
       to populate the missing amax.

    Args:
        experts: Iterable of expert modules (e.g. an ``nn.ModuleList``).
    """
    # Imported lazily to avoid a circular import at module load time.
    from .nn import TensorQuantizer

    # Pass 1: collect the element-wise max of every input_quantizer amax across experts.
    amax_dict: dict[str, torch.Tensor] = {}
    for expert in experts:
        for name, module in expert.named_modules():
            if (
                isinstance(module, TensorQuantizer)
                and module.amax is not None
                and "input_quantizer" in name
            ):
                amax_tensor = module.amax.detach().clone()
                stored_amax = amax_dict.get(name)
                amax_dict[name] = (
                    amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor)
                )

    # Pass 2: write the synced amax back to every expert.
    for expert in experts:
        for name, module in expert.named_modules():
            if isinstance(module, TensorQuantizer) and name in amax_dict:
                module.amax = amax_dict[name].detach().clone()

    # Imported lazily: model_calib imports from this module, so a top-level import would cycle.
    from .model_calib import max_calibrate

    # Pass 3: calibrate weight quantizers that never saw data (amax is None).
    for expert in experts:
        # state_dict() walks the whole expert; build it at most once per expert
        # (previously it was rebuilt for every matching quantizer).
        expert_state = None
        for name, module in expert.named_modules():
            if name.endswith("weight_quantizer") and module.is_enabled and module.amax is None:
                if expert_state is None:
                    expert_state = expert.state_dict()
                weight = expert_state.get(name.replace("weight_quantizer", "weight"))
                if weight is not None:
                    max_calibrate(module, lambda m, w=weight: m(w), distributed_sync=False)


@contextmanager
def patch_fsdp_mp_dtypes():
"""Patch FSDP2 to handle mixed dtypes properly during quantization.
Expand Down
Loading