Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions modelopt/torch/export/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
QuantizerAttrNames,
quantizer_attr_names,
reduce_block_amax,
representative_weight_quantizer,
weight_attr_names,
)
from modelopt.torch.utils import clear_cuda_cache
Expand Down Expand Up @@ -546,7 +547,7 @@ def _compute_kv_cache_dtype(

def get_weight_block_size(module: nn.Module, weight_name: str = "weight") -> int:
"""Returns the weight block size."""
weight_quantizer = getattr(module, quantizer_attr_names(weight_name).weight_quantizer, None)
weight_quantizer = representative_weight_quantizer(module, weight_name)

if weight_quantizer is None:
return 0
Expand All @@ -572,7 +573,11 @@ def get_quantization_format(module) -> str | None:
"""

def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames):
weight_quantizer = getattr(layer, quantizer_attr_names.weight_quantizer, None)
# Singular form first, plural ModuleList fallback (fused-experts).
# Strip the "_weight_quantizer" suffix to recover the weight attr name.
weight_attr = quantizer_attr_names.weight_quantizer
weight_name = weight_attr[: -len("_weight_quantizer")].rstrip("_") or "weight"
weight_quantizer = representative_weight_quantizer(layer, weight_name)
input_quantizer = getattr(layer, quantizer_attr_names.input_quantizer, None)

if weight_quantizer is None or not weight_quantizer.is_enabled:
Expand Down
19 changes: 11 additions & 8 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
QUANTIZATION_W4A8_NVFP4_FP8,
)
from .model_utils import get_language_model_from_vl, is_multimodal_model
from .moe_utils import _export_fused_experts
from .plugins import SpeculativeDecodingExporter, has_spec_opt
from .quant_utils import (
fuse_prequant_layernorm,
Expand Down Expand Up @@ -642,11 +643,20 @@ def _process_quantized_modules(
if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
continue

# Preprocessing: restore unpacked weight so the export path can read
# the live quantizer state. Falls through to the export branches below.
if hasattr(sub_module, "weight_packed") or (
"QuantFP8Linear" in type(sub_module).__name__ and sub_module.weight.element_size() <= 1
):
sub_module.unpack_weight()
if get_quantization_format(sub_module) != QUANTIZATION_NONE:

if hasattr(sub_module, "gate_up_proj_weight_quantizers"):
# _QuantFusedExperts uses plural `gate_up_proj_weight_quantizers` (ModuleList),
# which get_quantization_format's singular-weight_quantizer check misses. Handle
# it explicitly before the format gate so fused-experts get split + quantized.
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
_export_fused_experts(sub_module, dtype)
Comment thread
meenchen marked this conversation as resolved.
elif get_quantization_format(sub_module) != QUANTIZATION_NONE:
# Skip QuantMoELinear - it's handled separately in _reconstruct_fused_moe_linear
if type(sub_module).__name__ == "QuantMoELinear":
continue
Expand Down Expand Up @@ -677,13 +687,6 @@ def _process_quantized_modules(
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
for weight_name in ["gate_up_proj", "down_proj"]:
_export_quantized_weight(sub_module, dtype, weight_name)
elif hasattr(sub_module, "gate_up_proj_weight_quantizers"):
# Generic fused MoE experts (_QuantFusedExperts) with per-expert
# quantizer ModuleLists. Split into per-expert modules and export.
from modelopt.torch.export.moe_utils import _export_fused_experts

with fsdp2_aware_weight_update(model, sub_module, reshard=False):
_export_fused_experts(sub_module, dtype)


def _export_transformers_checkpoint(
Expand Down
34 changes: 33 additions & 1 deletion modelopt/torch/quantization/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Quantization conversion/restore utilities."""

import fnmatch
import re
import warnings
from collections.abc import Callable
from contextlib import contextmanager
Expand Down Expand Up @@ -286,6 +287,33 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType
set_quantizer_attributes_full(quant_model, quantizer_name, attributes, parent_class)


_FUSED_EXPERTS_QUANTIZER_LIST_RE = re.compile(
r"(weight_quantizers?|input_quantizers?)\.\d+(?=$|\.)"
)


def _normalize_fused_experts_quantizer_name(name: str) -> str:
"""Strip the per-expert index from per-expert quantizer ModuleList names.

Fused-experts modules register per-expert weight/input quantizers in a
``nn.ModuleList``; its children surface as dotted names like
``...gate_up_proj_weight_quantizers.0`` (plural) or — if a variant uses
singular naming — ``...gate_up_proj_weight_quantizer.0``. Neither matches
the singular-suffix wildcards (``*weight_quantizer``) used in the stock
configs, so the experts stay at their defaults.

Return a normalized name where either ``weight_quantizer[s]?.N`` or
``input_quantizer[s]?.N`` collapses to the singular form without the index
so the standard wildcards match.
"""

def _repl(m: re.Match) -> str:
base = m.group(1)
return base.removesuffix("s")

return _FUSED_EXPERTS_QUANTIZER_LIST_RE.sub(_repl, name)


def _match_quantizer(
wildcard_or_filter_func: str | Callable,
name: str,
Expand All @@ -296,7 +324,11 @@ def _match_quantizer(
if not isinstance(module, (TensorQuantizer, SequentialQuantizer)):
return False
if isinstance(wildcard_or_filter_func, str):
if not fnmatch.fnmatch(name, wildcard_or_filter_func):
normalized = _normalize_fused_experts_quantizer_name(name)
if not (
fnmatch.fnmatch(name, wildcard_or_filter_func)
or (normalized != name and fnmatch.fnmatch(normalized, wildcard_or_filter_func))
):
return False
elif callable(wildcard_or_filter_func):
if not wildcard_or_filter_func(name):
Expand Down
33 changes: 33 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,38 @@ def register_fused_experts_on_the_fly(model):
QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantFusedExperts)


def force_eager_experts_impl_on_the_fly(model):
    """Force HF fused-experts modules onto the eager ``F.linear``-based forward.

    HF transformers 5.0+ decorates fused-experts forwards with
    ``@use_experts_implementation``, which may dispatch to ``torch._grouped_mm``
    or ``torch.bmm`` backends. Those backends bypass ``F.linear`` and so bypass
    ``_QuantFusedExperts``'s input/weight quantizer hooks — calibration silently
    does nothing, no ``input_scale`` / ``amax`` is collected, and the exported
    checkpoint produces garbage at inference.

    Sets ``config._experts_implementation = "eager"`` on the model config (and
    recursively on ``text_config`` / ``vision_config`` / ``audio_config`` /
    ``speech_config``) whenever a fused-experts module is present.
    """
    # No fused-experts module anywhere in the model -> nothing to override.
    fused_present = any(_is_fused_experts_module(mod) for mod in model.modules())
    if not fused_present:
        return

    nested_cfg_attrs = ("text_config", "vision_config", "audio_config", "speech_config")

    # Iterative walk over the config and its nested sub-configs (replaces the
    # recursive helper in spirit; visit order does not matter here).
    pending = [getattr(model, "config", None)]
    while pending:
        cfg = pending.pop()
        if cfg is None:
            continue
        if hasattr(cfg, "_experts_implementation"):
            cfg._experts_implementation = "eager"
        pending.extend(
            getattr(cfg, attr) for attr in nested_cfg_attrs if hasattr(cfg, attr)
        )


def _is_supported_hf_model(model):
"""Check if the model a valid model for transformers quantization specific support."""
supported_models = [transformers.PreTrainedModel]
Expand Down Expand Up @@ -1646,6 +1678,7 @@ def _reconstruct_fused_moe_linear(model: nn.Module) -> None:
register_dbrx_moe_on_the_fly,
register_step3p5_moe_on_the_fly,
register_fused_experts_on_the_fly,
force_eager_experts_impl_on_the_fly,
register_sparse_moe_on_the_fly,
register_hf_attentions_on_the_fly,
convert_hf_parallel_linears_on_the_fly,
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/quantization/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"reduce_amax",
"reduce_sum",
"replace_function",
"representative_weight_quantizer",
"update_quant_cfg_with_kv_cache_quant",
"weight_attr_names",
]
56 changes: 42 additions & 14 deletions modelopt/torch/quantization/utils/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,27 +202,55 @@ def reduce_sum(input, axis=None, keepdims=True):
return output


def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]":
"""Get the weight param attribute names in a converted module, non-recursive.
def representative_weight_quantizer(module: nn.Module, weight_name: str = "weight"):
"""Return the representative weight quantizer for ``weight_name`` on ``module``.

Handles two layouts:
- singular ``<name>_weight_quantizer`` — standard ``nn.Linear`` / ``_QuantLinear``.
- plural ``<name>_weight_quantizers`` (``nn.ModuleList``) — fused-experts modules
(``_QuantFusedExperts``) hold one ``TensorQuantizer`` per expert. Per-expert
formats are identical, so the first element is representative.

We consider the following two cases for each weight param attribute:
- The standard weight attribute (e.g. nn.Linear).
- The custom `weight_attr_name`. (e.g. Llama4TextExperts has weight attributes `gate_up_proj` and `down_proj`)
Returns ``None`` if no matching quantizer is found.
"""
from ..nn import SequentialQuantizer, TensorQuantizer

# the standard weight and quantizer case
weight = getattr(module, "weight", None)
weight_quantizer = getattr(module, "weight_quantizer", None)
if weight is not None and isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
yield "weight"
singular = quantizer_attr_names(weight_name).weight_quantizer
q = getattr(module, singular, None)
if isinstance(q, (TensorQuantizer, SequentialQuantizer)):
return q

# other weight and quantizer case
plural = getattr(module, singular + "s", None)
if isinstance(plural, nn.ModuleList) and len(plural) > 0:
first = plural[0]
if isinstance(first, (TensorQuantizer, SequentialQuantizer)):
return first
return None


def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]":
"""Get the weight param attribute names in a converted module, non-recursive.

Covers three layouts:
- standard ``nn.Linear``: ``weight`` + ``weight_quantizer``.
- custom per-weight quantizer (e.g. ``Llama4TextExperts`` with ``gate_up_proj`` +
``gate_up_proj_weight_quantizer``).
- fused-experts ``nn.ModuleList`` quantizers (``_QuantFusedExperts`` with
``gate_up_proj`` + ``gate_up_proj_weight_quantizers`` plural list).
"""
# standard: "weight" + "weight_quantizer" (singular) or "weight_quantizers" (plural)
if getattr(module, "weight", None) is not None:
if representative_weight_quantizer(module, "weight") is not None:
yield "weight"

# per-parameter custom attr names
for name, _ in module.named_parameters(recurse=False):
if name == "weight":
continue
weight = getattr(module, name, None)
weight_quantizer = getattr(module, f"{name}_weight_quantizer", None)
if isinstance(weight, nn.Parameter) and isinstance(
weight_quantizer, (TensorQuantizer, SequentialQuantizer)
if (
isinstance(weight, nn.Parameter)
and representative_weight_quantizer(module, name) is not None
):
yield name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ quantize:
algorithm:
method: max
# Max calibration is fast and does not typically need checkpointing.
layerwise: true
# layerwise=false required for VLMs where the decoder layers are nested under
# `model.language_model.layers` (layerwise_calibrate can't find them otherwise).
layerwise: false
quant_cfg:
- quantizer_name: '*'
enable: false
Expand Down
Loading
Loading