
Fix weight-only quantization and export for TEGroupedMLP (MoE models) #971

Open
jQizhang wants to merge 2 commits into NVIDIA:main from jQizhang:weight_only_te_fix

Conversation


@jQizhang jQizhang commented Mar 4, 2026

1. What does this PR do?

This PR fixes a critical issue where weight-only quantization fails for MoE models utilizing TEGroupedMLP (e.g., Qwen3-30B-A3B).

The Problem:

In TEGroupedMLP, weights are stored per-expert as weight0, weight1, ..., weightN, and the standard self.weight attribute is deleted during _QuantTEGroupedLinear._setup.
The existing weight_only_quantize logic expects to find a self.weight associated with the quantizer. Because it could not find these "hidden" per-expert weights, the weight_quantizer failed to calibrate, leaving the quantizer without an _amax attribute. This causes the following crash during export/inference:

File ".../modelopt/torch/quantization/qtensor/nvfp4_tensor.py", line 59, in get_weights_scaling_factor_2_from_quantizer
assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute amax"

The Solution:

  1. Calibration Interface: Introduced iter_weights_for_calibration in the QuantModule base class.
  2. MoE Support: Overrode this method in _QuantTEGroupedLinear to yield all per-expert weights (weight0...weightN) that share the same quantizer. This ensures the calibrator "sees" all expert weights and calculates a valid _amax.
  3. Exporter Fix: Updated GPTModelExporter to correctly handle the structure of TEGroupedMLP during HuggingFace format conversion, ensuring MoE checkpoints can be exported after quantization.
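
The calibration interface described in steps 1 and 2 can be sketched as follows. This is an illustrative, simplified version (plain Python lists stand in for tensors; class names and the num_gemms attribute are assumptions based on the PR description, not the exact ModelOpt code):

```python
# Sketch: a base class exposes an iterator over weights to calibrate, and the
# grouped-linear subclass overrides it to yield the per-expert weights
# (weight0..weightN) that the default self.weight path cannot see.

class QuantModuleSketch:
    """Base behavior: yield the standard `weight` attribute if present."""

    def iter_weights_for_calibration(self):
        weight = getattr(self, "weight", None)
        if weight is not None:
            yield weight


class QuantTEGroupedLinearSketch(QuantModuleSketch):
    """Grouped linear: `weight` was deleted in _setup; experts live in weight0..weightN."""

    def __init__(self, num_experts, expert_weights):
        self.num_gemms = num_experts  # assumed attribute name for the expert count
        for i, w in enumerate(expert_weights):
            setattr(self, f"weight{i}", w)

    def iter_weights_for_calibration(self):
        for i in range(self.num_gemms):
            yield getattr(self, f"weight{i}")


# The calibrator can now compute a single shared amax across all experts:
module = QuantTEGroupedLinearSketch(3, [[0.5, -0.7], [0.2, 0.9], [-1.1, 0.3]])
amax = max(abs(x) for w in module.iter_weights_for_calibration() for x in w)
print(amax)  # 1.1
```

Because every expert weight flows through the same quantizer, the calibrator sees all of them and produces one valid _amax, which is what the export-time assertion checks for.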

2. Type of change

  • Bug fix

3. Usage / Reproduction

This issue is reproducible when running weight-only quantization on MoE models like Qwen3-30B-A3B:

# Step 1: Quantization
torchrun --nproc_per_node 8 examples/quantization/quantize.py \
    --hf-model-id Qwen/Qwen3-30B-A3B \
    --export-quant-cfg nvfp4 \
    --tp 2 \
    --ep 8 \
    --weight-only \
    --megatron-save-path ./qwen3_30b_nvfp4

# Step 2: Export
torchrun --nproc_per_node 2 examples/quantization/export.py \
    --hf-model-id Qwen/Qwen3-30B-A3B \
    --megatron-load-path ./qwen3_30b_nvfp4 \
    --export-dir ./qwen3_30b_nvfp4_hf \
    --pp 2 \
    --dtype bfloat16

4. Testing & Verification

  • Models tested: Qwen3-8B (dense), Qwen3-30B-A3B (MoE).
  • Quantization: NVFP4/FP8 weight-only quantization.
  • Verification:
    - Confirmed that QuantTEGroupedMLP now shows calculated _amax values in the quantization statistics table instead of remaining dynamic.
    - Successful export to HuggingFace format without AttributeError.
    - Validated that the change does not regress the dense-model (Qwen3-8B) quantization flow.
  • After the fix, the amax of the expert weights is calculated correctly:
                                  Quantization Statistics                                   
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━┓
┃ Parameter Name                                                      ┃ Shape ┃  Max Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━┩
│ decoder.layers.0.self_attention.linear_proj.weight_quantizer._amax  │ ()    │ 7.5781e-01 │
│ decoder.layers.0.self_attention.linear_qkv.weight_quantizer._amax   │ ()    │ 2.8711e-01 │
│ decoder.layers.0.mlp.experts.linear_fc1.weight_quantizer._amax      │ ()    │ 7.1094e-01 │
│ decoder.layers.0.mlp.experts.linear_fc2.weight_quantizer._amax      │ ()    │ 8.6719e-01 │
│ decoder.layers.1.self_attention.linear_proj.weight_quantizer._amax  │ ()    │ 5.8594e-01 │
│ decoder.layers.1.self_attention.linear_qkv.weight_quantizer._amax   │ ()    │ 7.4219e-01 │
│ decoder.layers.1.mlp.experts.linear_fc1.weight_quantizer._amax      │ ()    │ 7.2266e-01 │
│ decoder.layers.1.mlp.experts.linear_fc2.weight_quantizer._amax      │ ()    │ 1.9922e+00 │
│ decoder.layers.2.self_attention.linear_proj.weight_quantizer._amax  │ ()    │ 1.0859e+00 │
│ decoder.layers.2.self_attention.linear_qkv.weight_quantizer._amax   │ ()    │ 1.7812e+00 │
│ decoder.layers.2.mlp.experts.linear_fc1.weight_quantizer._amax      │ ()    │ 7.3047e-01 │
│ decoder.layers.2.mlp.experts.linear_fc2.weight_quantizer._amax      │ ()    │ 1.9219e+00 │

Summary by CodeRabbit

  • New Features

    • Added export support for grouped-expert layers so individual experts are exported for downstream use.
  • Improvements

    • Improved weight-only quantization calibration with a new per-weight iteration hook for quantization modules, including grouped-linear paths.
    • More robust detection of real weight attributes during calibration to avoid spurious calibration steps.
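
The "more robust detection" item can be illustrated with a minimal sketch (the real ModelOpt helper weight_attr_names may differ; this only shows the guard described above):

```python
# Sketch: yield a weight attribute name only when the module actually holds a
# value under that name, so modules whose `weight` was deleted (as in
# TEGroupedLinear) do not trigger a spurious calibration step.
def weight_attr_names(module):
    for name in ("weight",):  # the real helper may check more names
        if getattr(module, name, None) is not None:
            yield name


class Dummy:
    pass


m = Dummy()
print(list(weight_attr_names(m)))  # [] -> no spurious calibration step
m.weight = [1.0, 2.0]
print(list(weight_attr_names(m)))  # ['weight']
```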

@jQizhang jQizhang requested review from a team as code owners March 4, 2026 15:47
@jQizhang jQizhang requested review from jingyu-ml and sychen52 March 4, 2026 15:47

copy-pr-bot bot commented Mar 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Mar 4, 2026

No actionable comments were generated in the recent review. 🎉


📥 Commits

Reviewing files that changed from the base of the PR and between 77cf3d2 and 463d639.

📒 Files selected for processing (2)
  • modelopt/torch/export/unified_export_megatron.py
  • modelopt/torch/quantization/model_calib.py

Walkthrough

Adds a weight-calibration iterator API and uses it in calibration flow; introduces an internal proxy and export path to emit per-expert weights from grouped-linear/TEGroupedMLP when local_experts are absent.

Changes

Cohort: Weight calibration iterator & utils
Files: modelopt/torch/quantization/nn/modules/quant_module.py, modelopt/torch/quantization/plugins/transformer_engine.py, modelopt/torch/quantization/utils.py, modelopt/torch/quantization/model_calib.py
Summary: Adds iter_weights_for_calibration() to QuantModule and the TE grouped-linear quantization class; updates weight_attr_names to yield only when weight exists; refactors weight_only_quantize to use the new iterator path for QuantModule instances while preserving prior behavior for others.

Cohort: Grouped-linear expert export
Files: modelopt/torch/export/unified_export_megatron.py
Summary: Adds _GroupedLinearExpertProxy and _export_grouped_mlp_experts() to GPTModelExporter; switches export paths to emit per-expert proxies for TEGroupedMLP layers when local_experts is absent, mirroring the transformer/eagle export branches. Also imports copy.
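
The proxy idea can be sketched like this. It is a hypothetical illustration of the mechanism described above, not the actual _GroupedLinearExpertProxy implementation (class, function, and attribute names are assumptions):

```python
# Sketch: wrap a grouped-linear layer so each expert appears as its own
# object with a plain `weight` attribute, matching the per-expert layout a
# HuggingFace-format checkpoint expects.
class GroupedLinearExpertProxy:
    def __init__(self, grouped_linear, expert_idx):
        self._module = grouped_linear
        self._idx = expert_idx

    @property
    def weight(self):
        # Redirect `weight` to the hidden per-expert attribute weight<idx>.
        return getattr(self._module, f"weight{self._idx}")


def export_grouped_mlp_experts(grouped_linear):
    """Yield (expert_index, proxy) so the exporter can treat each expert as an
    ordinary linear layer, even when local_experts is absent."""
    for i in range(grouped_linear.num_gemms):  # num_gemms: assumed expert count
        yield i, GroupedLinearExpertProxy(grouped_linear, i)


class DummyGrouped:
    num_gemms = 2
    weight0 = [1.0]
    weight1 = [2.0]


for idx, proxy in export_grouped_mlp_experts(DummyGrouped()):
    print(idx, proxy.weight)
```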

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes


Suggested reviewers

  • realAsma
  • jenchen13
  • ChenhanYu
  • yueshen2016
  • cjluo-nv
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 44.44%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (3 passed)

  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title accurately summarizes the main change: fixing weight-only quantization and export for TEGroupedMLP in MoE models, which directly corresponds to the multi-file changes across the quantization and export modules.
  • Security Anti-Patterns ✅ Passed: the pull request does not introduce security anti-patterns. The trust_remote_code parameter is properly exposed as caller-configurable with a safe default of False, not hardcoded. No torch.load, numpy.load, eval, exec, or nosec comments introduced.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 74-84: The QuantModule branch captures weights by calling
module.iter_weights_for_calibration() outside the
enable_weight_access_and_writeback context, so remapped/sharded/offloaded
weights may be stale; move the call into the context so iteration happens while
enable_weight_access_and_writeback(module, model) is active (i.e., enter the
with block first, then call module.iter_weights_for_calibration() and call
weight_quantizer(weight) inside that context). Keep the else branch behavior for
weight_attr_names/quantizer_attr_names unchanged.
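
The fix the comment asks for can be sketched as follows. This is a minimal stand-in, assuming the context-manager shape of enable_weight_access_and_writeback from its name and the comment above; it only demonstrates "enter the context first, then iterate":

```python
# Sketch: enter the weight-access context before iterating, so sharded or
# offloaded weights are materialized while the quantizer observes them.
from contextlib import contextmanager


@contextmanager
def enable_weight_access_and_writeback(module, model):
    module.materialized = True  # stand-in for gathering/restoring real weights
    try:
        yield
    finally:
        module.materialized = False


def calibrate_quant_module(module, model, weight_quantizer):
    # Iterate inside the context so every yielded weight is live; calling the
    # iterator outside (the bug flagged above) could yield stale weights.
    with enable_weight_access_and_writeback(module, model):
        for weight in module.iter_weights_for_calibration():
            weight_quantizer(weight)


class DummyModule:
    def __init__(self):
        self.materialized = False
        self.weight0, self.weight1 = [0.5], [1.5]

    def iter_weights_for_calibration(self):
        assert self.materialized, "weights accessed outside the context"
        yield self.weight0
        yield self.weight1


seen = []
calibrate_quant_module(DummyModule(), model=None, weight_quantizer=seen.append)
print(seen)  # [[0.5], [1.5]]
```

Because generators run lazily, the assertion inside iter_weights_for_calibration only passes when iteration happens under the active context, which is exactly the ordering the review comment requires.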


📥 Commits

Reviewing files that changed from the base of the PR and between a34d613 and 77cf3d2.

📒 Files selected for processing (5)
  • modelopt/torch/export/unified_export_megatron.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/nn/modules/quant_module.py
  • modelopt/torch/quantization/plugins/transformer_engine.py
  • modelopt/torch/quantization/utils.py

