diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 4ceb51cd2c..76f304a478 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -42,6 +42,7 @@ QuantizerAttrNames, quantizer_attr_names, reduce_block_amax, + representative_weight_quantizer, weight_attr_names, ) from modelopt.torch.utils import clear_cuda_cache @@ -546,7 +547,7 @@ def _compute_kv_cache_dtype( def get_weight_block_size(module: nn.Module, weight_name: str = "weight") -> int: """Returns the weight block size.""" - weight_quantizer = getattr(module, quantizer_attr_names(weight_name).weight_quantizer, None) + weight_quantizer = representative_weight_quantizer(module, weight_name) if weight_quantizer is None: return 0 @@ -572,7 +573,11 @@ def get_quantization_format(module) -> str | None: """ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames): - weight_quantizer = getattr(layer, quantizer_attr_names.weight_quantizer, None) + # Singular form first, plural ModuleList fallback (fused-experts). + # Strip the "_weight_quantizer" suffix to recover the weight attr name. + weight_attr = quantizer_attr_names.weight_quantizer + weight_name = weight_attr[: -len("_weight_quantizer")].rstrip("_") or "weight" + weight_quantizer = representative_weight_quantizer(layer, weight_name) input_quantizer = getattr(layer, quantizer_attr_names.input_quantizer, None) if weight_quantizer is None or not weight_quantizer.is_enabled: diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index af936a3002..a76783ac17 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -88,6 +88,7 @@ QUANTIZATION_W4A8_NVFP4_FP8, ) from .model_utils import get_language_model_from_vl, is_multimodal_model +from .moe_utils import _export_fused_experts from .plugins import SpeculativeDecodingExporter, has_spec_opt from .quant_utils import ( fuse_prequant_layernorm, @@ -642,11 +643,20 @@ def _process_quantized_modules( if is_modelopt_qlora and (hasattr(sub_module, "base_layer")): continue + # Preprocessing: restore unpacked weight so the export path can read + # the live quantizer state. Falls through to the export branches below. if hasattr(sub_module, "weight_packed") or ( "QuantFP8Linear" in type(sub_module).__name__ and sub_module.weight.element_size() <= 1 ): sub_module.unpack_weight() - if get_quantization_format(sub_module) != QUANTIZATION_NONE: + + if hasattr(sub_module, "gate_up_proj_weight_quantizers"): + # _QuantFusedExperts uses plural `gate_up_proj_weight_quantizers` (ModuleList), + # which get_quantization_format's singular-weight_quantizer check misses. Handle + # it explicitly before the format gate so fused-experts get split + quantized. + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_fused_experts(sub_module, dtype) + elif get_quantization_format(sub_module) != QUANTIZATION_NONE: # Skip QuantMoELinear - it's handled separately in _reconstruct_fused_moe_linear if type(sub_module).__name__ == "QuantMoELinear": continue @@ -677,13 +687,6 @@ def _process_quantized_modules( with fsdp2_aware_weight_update(model, sub_module, reshard=False): for weight_name in ["gate_up_proj", "down_proj"]: _export_quantized_weight(sub_module, dtype, weight_name) - elif hasattr(sub_module, "gate_up_proj_weight_quantizers"): - # Generic fused MoE experts (_QuantFusedExperts) with per-expert - # quantizer ModuleLists. Split into per-expert modules and export. - from modelopt.torch.export.moe_utils import _export_fused_experts - - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_fused_experts(sub_module, dtype) def _export_transformers_checkpoint( diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 55f7fdf6fc..3f97f8380b 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -16,6 +16,7 @@ """Quantization conversion/restore utilities.""" import fnmatch +import re import warnings from collections.abc import Callable from contextlib import contextmanager @@ -286,6 +287,33 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType set_quantizer_attributes_full(quant_model, quantizer_name, attributes, parent_class) +_FUSED_EXPERTS_QUANTIZER_LIST_RE = re.compile( + r"(weight_quantizers?|input_quantizers?)\.\d+(?=$|\.)" +) + + +def _normalize_fused_experts_quantizer_name(name: str) -> str: + """Strip the per-expert index from per-expert quantizer ModuleList names. + + Fused-experts modules register per-expert weight/input quantizers in a + ``nn.ModuleList``; its children surface as dotted names like + ``...gate_up_proj_weight_quantizers.0`` (plural) or — if a variant uses + singular naming — ``...gate_up_proj_weight_quantizer.0``. Neither matches + the singular-suffix wildcards (``*weight_quantizer``) used in the stock + configs, so the experts stay at their defaults. + + Return a normalized name where either ``weight_quantizer[s]?.N`` or + ``input_quantizer[s]?.N`` collapses to the singular form without the index + so the standard wildcards match. + """ + + def _repl(m: re.Match) -> str: + base = m.group(1) + return base.removesuffix("s") + + return _FUSED_EXPERTS_QUANTIZER_LIST_RE.sub(_repl, name) + + def _match_quantizer( wildcard_or_filter_func: str | Callable, name: str, @@ -296,7 +324,11 @@ def _match_quantizer( if not isinstance(module, (TensorQuantizer, SequentialQuantizer)): return False if isinstance(wildcard_or_filter_func, str): - if not fnmatch.fnmatch(name, wildcard_or_filter_func): + normalized = _normalize_fused_experts_quantizer_name(name) + if not ( + fnmatch.fnmatch(name, wildcard_or_filter_func) + or (normalized != name and fnmatch.fnmatch(normalized, wildcard_or_filter_func)) + ): return False elif callable(wildcard_or_filter_func): if not wildcard_or_filter_func(name): diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 59bcd215bb..5549f20698 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -881,6 +881,33 @@ def forward(self, *args, **kwargs): self._down_proj_linear = False return super().forward(*args, **kwargs) + def fold_weight(self, keep_attrs: bool = False): + """Fold per-expert weight quantizers into the fused 3-D weights. + + The base ``fold_weight`` only handles singular ``*_weight_quantizer`` + attributes. Fused experts use ``nn.ModuleList`` of per-expert quantizers + (``gate_up_proj_weight_quantizers``, ``down_proj_weight_quantizers``), + which would otherwise be skipped, leaving ``_amax`` on every quantizer. + """ + for weight_name, quantizers_name in ( + ("gate_up_proj", "gate_up_proj_weight_quantizers"), + ("down_proj", "down_proj_weight_quantizers"), + ): + weight = getattr(self, weight_name, None) + quantizers = getattr(self, quantizers_name, None) + if weight is None or quantizers is None: + continue + for idx, q in enumerate(quantizers): + if not (isinstance(q, TensorQuantizer) and q.fake_quant): + continue + slice_ = weight.data[idx] + slice_.copy_(q(slice_.float()).to(weight.dtype)) + q.disable() + if not keep_attrs: + for attr_name in ("_pre_quant_scale", "_amax"): + if hasattr(q, attr_name): + delattr(q, attr_name) + class _QuantDbrxFFN(_QuantSparseSequentialMoe): @property @@ -1419,6 +1446,38 @@ def register_fused_experts_on_the_fly(model): QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantFusedExperts) +def force_eager_experts_impl_on_the_fly(model): + """Force HF fused-experts modules onto the eager ``F.linear``-based forward. + + HF transformers 5.0+ decorates fused-experts forwards with + ``@use_experts_implementation``, which may dispatch to ``torch._grouped_mm`` + or ``torch.bmm`` backends. Those backends bypass ``F.linear`` and so bypass + ``_QuantFusedExperts``'s input/weight quantizer hooks — calibration silently + does nothing, no ``input_scale`` / ``amax`` is collected, and the exported + checkpoint produces garbage at inference. + + Sets ``config._experts_implementation = "eager"`` on the model config (and + recursively on ``text_config`` / ``vision_config`` / ``audio_config`` / + ``speech_config``) whenever a fused-experts module is present. + """ + if not any(_is_fused_experts_module(m) for m in model.modules()): + return + + nested_cfg_attrs = ("text_config", "vision_config", "audio_config", "speech_config") + + def _force(cfg): + if cfg is None: + return + if hasattr(cfg, "_experts_implementation"): + cfg._experts_implementation = "eager" + for sub in nested_cfg_attrs: + if hasattr(cfg, sub): + _force(getattr(cfg, sub)) + + if hasattr(model, "config"): + _force(model.config) + + def _is_supported_hf_model(model): """Check if the model a valid model for transformers quantization specific support.""" supported_models = [transformers.PreTrainedModel] @@ -1646,6 +1705,7 @@ def _reconstruct_fused_moe_linear(model: nn.Module) -> None: register_dbrx_moe_on_the_fly, register_step3p5_moe_on_the_fly, register_fused_experts_on_the_fly, + force_eager_experts_impl_on_the_fly, register_sparse_moe_on_the_fly, register_hf_attentions_on_the_fly, convert_hf_parallel_linears_on_the_fly, diff --git a/modelopt/torch/quantization/utils/__init__.py b/modelopt/torch/quantization/utils/__init__.py index dfc23c42ee..dc6daa0084 100644 --- a/modelopt/torch/quantization/utils/__init__.py +++ b/modelopt/torch/quantization/utils/__init__.py @@ -30,6 +30,7 @@ "reduce_amax", "reduce_sum", "replace_function", + "representative_weight_quantizer", "update_quant_cfg_with_kv_cache_quant", "weight_attr_names", ] diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 29661e18f5..1a177e04dc 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -202,27 +202,57 @@ def reduce_sum(input, axis=None, keepdims=True): return output -def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]": - """Get the weight param attribute names in a converted module, non-recursive. +def representative_weight_quantizer(module: nn.Module, weight_name: str = "weight"): + """Return the representative weight quantizer for ``weight_name`` on ``module``. + + Handles two layouts: + + - singular ``_weight_quantizer`` — standard ``nn.Linear`` / ``_QuantLinear``. + - plural ``_weight_quantizers`` (``nn.ModuleList``) — fused-experts modules + (``_QuantFusedExperts``) hold one ``TensorQuantizer`` per expert. Per-expert + formats are identical, so the first element is representative. - We consider the following two cases for each weight param attribute: - - The standard weight attribute (e.g. nn.Linear). - - The custom `weight_attr_name`. (e.g. Llama4TextExperts has weight attributes `gate_up_proj` and `down_proj`) + Returns ``None`` if no matching quantizer is found. """ from ..nn import SequentialQuantizer, TensorQuantizer - # the standard weight and quantizer case - weight = getattr(module, "weight", None) - weight_quantizer = getattr(module, "weight_quantizer", None) - if weight is not None and isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)): - yield "weight" + singular = quantizer_attr_names(weight_name).weight_quantizer + q = getattr(module, singular, None) + if isinstance(q, (TensorQuantizer, SequentialQuantizer)): + return q - # other weight and quantizer case + plural = getattr(module, singular + "s", None) + if isinstance(plural, nn.ModuleList) and len(plural) > 0: + first = plural[0] + if isinstance(first, (TensorQuantizer, SequentialQuantizer)): + return first + return None + + +def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]": + """Get the weight param attribute names in a converted module, non-recursive. + + Covers three layouts: + + - standard ``nn.Linear``: ``weight`` + ``weight_quantizer``. + - custom per-weight quantizer (e.g. ``Llama4TextExperts`` with ``gate_up_proj`` + + ``gate_up_proj_weight_quantizer``). + - fused-experts ``nn.ModuleList`` quantizers (``_QuantFusedExperts`` with + ``gate_up_proj`` + ``gate_up_proj_weight_quantizers`` plural list). + """ + # standard: "weight" + "weight_quantizer" (singular) or "weight_quantizers" (plural) + if getattr(module, "weight", None) is not None: + if representative_weight_quantizer(module, "weight") is not None: + yield "weight" + + # per-parameter custom attr names for name, _ in module.named_parameters(recurse=False): + if name == "weight": + continue weight = getattr(module, name, None) - weight_quantizer = getattr(module, f"{name}_weight_quantizer", None) - if isinstance(weight, nn.Parameter) and isinstance( - weight_quantizer, (TensorQuantizer, SequentialQuantizer) + if ( + isinstance(weight, nn.Parameter) + and representative_weight_quantizer(module, name) is not None ): yield name diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml index 220d062232..7c55703963 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml @@ -20,7 +20,9 @@ quantize: algorithm: method: max # Max calibration is fast and does not typically need checkpointing. - layerwise: true + # layerwise=false required for VLMs where the decoder layers are nested under + # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). + layerwise: false quant_cfg: - quantizer_name: '*' enable: false diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index 7e77bf1151..2943582774 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -22,11 +22,13 @@ pytest.importorskip("transformers") +from modelopt.torch.quantization.conversion import _normalize_fused_experts_quantizer_name from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.plugins.huggingface import ( _is_fused_experts_module, _is_sparse_sequaential_moe_block, _QuantFusedExperts, + force_eager_experts_impl_on_the_fly, register_fused_experts_on_the_fly, register_sparse_moe_on_the_fly, ) @@ -297,3 +299,316 @@ def test_export_creates_per_expert_submodules(self): if QuantModuleRegistry.get(expert_type) is not None: QuantModuleRegistry.unregister(expert_type) + + +# --------------------------------------------------------------------------- +# Tests for force_eager_experts_impl_on_the_fly +# --------------------------------------------------------------------------- +class _StubConfig: + """Minimal stand-in for HF PretrainedConfig with optional nested sub-configs.""" + + def __init__(self, impl=None, **nested): + if impl is not None: + self._experts_implementation = impl + for key, value in nested.items(): + setattr(self, key, value) + + +class _TinyMoEModelWithConfig(_TinyMoEModel): + def __init__(self, config): + super().__init__() + self.config = config + + +class _NonMoEModelWithConfig(nn.Module): + def __init__(self, config): + super().__init__() + self.linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM) + self.config = config + + +class TestForceEagerExpertsImpl: + def test_sets_eager_on_moe_model(self): + """Non-eager backend on an MoE model gets flipped to eager.""" + cfg = _StubConfig(impl="kernels") + model = _TinyMoEModelWithConfig(cfg) + force_eager_experts_impl_on_the_fly(model) + assert cfg._experts_implementation == "eager" + + def test_recurses_into_nested_configs(self): + """VLM-style nested text_config / vision_config are also flipped.""" + text_cfg = _StubConfig(impl="grouped_mm") + vision_cfg = _StubConfig(impl="bmm") + root_cfg = _StubConfig(text_config=text_cfg, vision_config=vision_cfg) + model = _TinyMoEModelWithConfig(root_cfg) + force_eager_experts_impl_on_the_fly(model) + assert text_cfg._experts_implementation == "eager" + assert vision_cfg._experts_implementation == "eager" + + def test_skips_model_without_fused_experts(self): + """Non-MoE models must not have their config silently mutated.""" + cfg = _StubConfig(impl="kernels") + model = _NonMoEModelWithConfig(cfg) + force_eager_experts_impl_on_the_fly(model) + assert cfg._experts_implementation == "kernels" + + def test_no_crash_when_config_missing(self): + """Model without a ``config`` attribute must not raise.""" + force_eager_experts_impl_on_the_fly(_TinyMoEModel()) # no-op, no error + + def test_no_crash_when_impl_attr_missing(self): + """Config without ``_experts_implementation`` must not raise.""" + cfg = _StubConfig() # no impl attr + model = _TinyMoEModelWithConfig(cfg) + force_eager_experts_impl_on_the_fly(model) + assert not hasattr(cfg, "_experts_implementation") + + def test_leaves_eager_value_unchanged(self): + cfg = _StubConfig(impl="eager") + model = _TinyMoEModelWithConfig(cfg) + force_eager_experts_impl_on_the_fly(model) + assert cfg._experts_implementation == "eager" + + +# --------------------------------------------------------------------------- +# End-to-end PTQ calibration test — guards the full fused-experts path: +# register_fused_experts_on_the_fly → _QuantFusedExperts.{_setup, forward} → +# plural ModuleList name normalization in conversion._match_quantizer → +# TensorQuantizer amax collection via the F.linear hook. +# If any link breaks, quantizer `amax` stays None and this test fails. +# --------------------------------------------------------------------------- +class TestFusedExpertsCalibration: + @staticmethod + def _cleanup_registry(mod_type): + if QuantModuleRegistry.get(mod_type) is not None: + QuantModuleRegistry.unregister(mod_type) + + def test_calibration_populates_all_expert_quantizers(self): + """After PTQ, every input/weight quantizer on the fused-experts module has amax set.""" + import modelopt.torch.quantization as mtq + + model = _TinyMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + quant_cfg = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*gate_up_proj_input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + { + "quantizer_name": "*down_proj_input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + { + "quantizer_name": "*gate_up_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + { + "quantizer_name": "*down_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + ], + "algorithm": "max", + } + + def forward_loop(m): + torch.manual_seed(0) + for _ in range(2): + x = torch.randn(1, 4, HIDDEN_DIM) + m(x) + + mtq.quantize(model, quant_cfg, forward_loop=forward_loop) + + experts = model.moe.experts + assert experts.gate_up_proj_input_quantizer.amax is not None, ( + "Shared gate_up_proj input quantizer was not calibrated — " + "F.linear hook likely bypassed by non-eager experts_implementation." + ) + assert experts.down_proj_input_quantizer.amax is not None, ( + "Shared down_proj input quantizer was not calibrated." + ) + for idx in range(NUM_EXPERTS): + assert experts.gate_up_proj_weight_quantizers[idx].amax is not None, ( + f"gate_up_proj_weight_quantizers[{idx}].amax is None — " + "plural ModuleList name normalization in _match_quantizer likely broken." + ) + assert experts.down_proj_weight_quantizers[idx].amax is not None, ( + f"down_proj_weight_quantizers[{idx}].amax is None." + ) + + self._cleanup_registry(expert_type) + + +# --------------------------------------------------------------------------- +# Tests for export enumeration — guards the bug where fused-experts were +# silently skipped by get_quant_config because their weight quantizers live +# on a plural nn.ModuleList instead of the singular *_weight_quantizer attr. +# Missed enumeration → experts don't appear in quantized_layers → +# quantization_formats has only 1 entry from the non-expert modules → +# quant_algo lands on that format instead of "MIXED_PRECISION". +# --------------------------------------------------------------------------- +class _MixedPrecisionModel(nn.Module): + """A model with both a fused-experts block AND a standard Linear, so a + mixed-precision recipe should produce two distinct format groups.""" + + def __init__(self): + super().__init__() + self.moe = _SyntheticSparseMoeBlock() + self.dense = nn.Linear(HIDDEN_DIM, HIDDEN_DIM) + + def forward(self, x): + return self.dense(self.moe(x)) + + +class TestMixedPrecisionExport: + @staticmethod + def _cleanup_registry(mod_type): + if QuantModuleRegistry.get(mod_type) is not None: + QuantModuleRegistry.unregister(mod_type) + + def test_weight_attr_names_yields_fused_expert_params(self): + """weight_attr_names must yield gate_up_proj / down_proj on fused experts + even though their quantizers are a plural ModuleList, not singular.""" + from modelopt.torch.quantization.utils.core_utils import weight_attr_names + + model = _TinyMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + register_fused_experts_on_the_fly(model) + converted = QuantModuleRegistry.convert(model.moe.experts) + + yielded = list(weight_attr_names(converted)) + assert set(yielded) == {"gate_up_proj", "down_proj"}, ( + f"Expected both fused weight attrs, got {yielded}. " + "Likely regression in representative_weight_quantizer plural fallback." + ) + + self._cleanup_registry(expert_type) + + def test_mixed_precision_config_export(self): + """Mixed-precision recipe (experts FP8 + dense Linear FP8 per-channel) should + show both modules in quantized_layers. Using two distinct formats would + trigger MIXED_PRECISION; using same-format still exercises enumeration.""" + import modelopt.torch.quantization as mtq + from modelopt.torch.export.quant_utils import get_quant_config + + model = _MixedPrecisionModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + # FP8 per-tensor for experts; FP8 per-channel for dense — two distinct + # format strings in quantization_formats, so quant_algo must become + # MIXED_PRECISION. + quant_cfg = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*gate_up_proj_input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_name": "*down_proj_input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_name": "*gate_up_proj_weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_name": "*down_proj_weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_name": "*dense.input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_name": "*dense.weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": 0}, # per-channel → FP8_PC_PT + }, + ], + "algorithm": "max", + } + + def forward_loop(m): + torch.manual_seed(0) + for _ in range(2): + x = torch.randn(1, 4, HIDDEN_DIM) + m(x) + + mtq.quantize(model, quant_cfg, forward_loop=forward_loop) + + cfg = get_quant_config(model) + q = cfg["quantization"] + + # The fused-experts module MUST appear in quantized_layers. This is the + # central guard: regressions of weight_attr_names plural fallback would + # make experts disappear here. + layer_names = set(q.get("quantized_layers", {}).keys()) + assert any("moe.experts" in n for n in layer_names), ( + f"Fused-experts module missing from quantized_layers: {layer_names}. " + "weight_attr_names likely not yielding plural-ModuleList weight attrs." + ) + assert any(n.endswith("dense") for n in layer_names), ( + f"Dense Linear missing from quantized_layers: {layer_names}." + ) + + # Two distinct formats → MIXED_PRECISION at top level. + assert q["quant_algo"] == "MIXED_PRECISION", ( + f"Expected MIXED_PRECISION (fused-experts FP8 per-tensor + dense " + f"FP8 per-channel), got quant_algo={q['quant_algo']}. " + f"quantized_layers={q.get('quantized_layers')}" + ) + + self._cleanup_registry(expert_type) + + +# --------------------------------------------------------------------------- +# Tests for the fused-experts quantizer-name normalizer used by +# conversion._match_quantizer. Covers both plural (actual _QuantFusedExperts +# layout) and singular (defensive: future variants may name the ModuleList +# without the trailing `s`) forms. +# --------------------------------------------------------------------------- +class TestNormalizeFusedExpertsQuantizerName: + def test_plural_weight_quantizers_stripped(self): + assert ( + _normalize_fused_experts_quantizer_name("moe.experts.gate_up_proj_weight_quantizers.7") + == "moe.experts.gate_up_proj_weight_quantizer" + ) + + def test_plural_input_quantizers_stripped(self): + assert ( + _normalize_fused_experts_quantizer_name("moe.experts.down_proj_input_quantizers.3") + == "moe.experts.down_proj_input_quantizer" + ) + + def test_singular_weight_quantizer_with_index_stripped(self): + """Defensive: handle variants that name the ModuleList singular.""" + assert ( + _normalize_fused_experts_quantizer_name("moe.experts.gate_up_proj_weight_quantizer.2") + == "moe.experts.gate_up_proj_weight_quantizer" + ) + + def test_singular_input_quantizer_with_index_stripped(self): + assert ( + _normalize_fused_experts_quantizer_name("moe.experts.down_proj_input_quantizer.0") + == "moe.experts.down_proj_input_quantizer" + ) + + def test_non_indexed_name_unchanged(self): + """Plain singular names (no index) must be passed through untouched.""" + assert ( + _normalize_fused_experts_quantizer_name("moe.experts.gate_up_proj_weight_quantizer") + == "moe.experts.gate_up_proj_weight_quantizer" + ) + + def test_unrelated_dotted_number_unchanged(self): + """Dotted numbers that aren't inside a quantizer-list context are left alone.""" + assert ( + _normalize_fused_experts_quantizer_name("moe.layers.3.gate.weight") + == "moe.layers.3.gate.weight" + )