
fixes for fused moe (qwen3.6, GLM5.1) + MSE calibration #1382

Open

Fridah-nv wants to merge 4 commits into main from fridah/fused-moe-MSE-fix

Conversation

Fridah-nv (Contributor) commented May 2, 2026

What does this PR do?

Type of change: Bug fix

Fixes several issues with NVFP4 MSE calibration and export for fused MoE expert modules (_QuantFusedExperts — used by Qwen3.6, GLM-5.1, and other HF transformers 5.0+ models that store expert weights as 3-D nn.Parameters).

  • Bug 1 — MSE weight calibration runs 0 iterations for fused experts (model_calib.py)

The weight-quantizer discovery loop in mse_calibrate used the singular attribute name gate_up_proj_weight_quantizer to look up quantizers, but _QuantFusedExperts stores them in a plural nn.ModuleList named gate_up_proj_weight_quantizers. All 20,480 expert quantizers were silently skipped, resulting in "MSE weight calibration: 0it" and no MSE-optimized scales.

Fix: add a second discovery pass that detects the plural {param}_weight_quantizers ModuleLists and enqueues each per-expert quantizer under a (param_name, expert_idx) tuple identifier; the weight-lookup step then unpacks the tuple to extract the per-expert weight slice (sketched below).
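
A minimal sketch of the discovery logic, assuming the _QuantFusedExperts layout described above; the helper name and the final weight-slice lookup are illustrative, not the exact implementation:

    import torch.nn as nn

    def collect_expert_weight_quantizers(parent_module):
        """Collect one (module, (param_name, expert_idx), quantizer) entry per expert."""
        entries = []
        for param_name, param in parent_module.named_parameters(recurse=False):
            # Plural ModuleList used by fused experts, e.g. "gate_up_proj_weight_quantizers".
            qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
            if not isinstance(qlist, nn.ModuleList):
                continue
            for expert_idx, wq in enumerate(qlist):
                if getattr(wq, "is_enabled", False):
                    entries.append((parent_module, (param_name, expert_idx), wq))
        return entries

    # Later, the tuple identifier selects the per-expert slice of the 3-D fused parameter:
    #   param_name, expert_idx = identifier
    #   weight_slice = getattr(module, param_name)[expert_idx]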

  • Bug 2 — Zero weight scales in exported checkpoint (nvfp4_tensor.py)

Per-block weight scales can silently underflow to 0 when cast to FP8 E4M3FN. The existing scale == 0 guard only catches exact float32 zeros; values in (0, 2^-9) pass through and become 0 after the FP8 cast. This affects both the dynamic recompute path (get_weights_scaling_factor) and the static calibrated path (get_weights_scaling_factor_from_quantizer).

Fix: clamp per-block scales to 2^-9 (smallest positive FP8 E4M3FN subnormal) before the FP8 cast in both paths.
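
A minimal repro of the underflow and the clamp, with illustrative values rather than the actual code paths:

    import torch

    FP8_E4M3FN_MIN = 2 ** -9  # 0.001953125, smallest positive FP8 E4M3FN subnormal

    per_block_scale = torch.tensor([1e-4, 5e-4, 2e-2])     # float32 per-block scales
    underflowed = per_block_scale.to(torch.float8_e4m3fn)  # tiny entries become exactly 0
    assert (underflowed.float() == 0).any()

    patched = per_block_scale.clamp(min=FP8_E4M3FN_MIN).to(torch.float8_e4m3fn)
    assert (patched.float() > 0).all()                      # no zero scales after the clamp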

  • Bug 3 — Zero/corrupt amax for uncalibrated experts at export (moe_utils.py)

Experts that receive no tokens during calibration have _amax = 0 or uninitialized values. The existing scalar fallback used 1e-4 which itself underflows to 0 in FP8 E4M3FN (1e-4 < 2^-9 ≈ 0.00195). Additionally, the per-block fallback tensor had shape (H*W, 1) instead of (H, W), causing a shape mismatch that silently bypassed the fallback and fell through to the bad scalar. Finally, a stale zero global_amax from an uncalibrated expert was not recomputed, causing division-by-zero in the FP8 scale formula.

Fix: reshape the per-block fallback correctly; raise the clamp floor to 2e-3; always recompute global_amax from the current (possibly patched) per-block _amax.
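
A hedged sketch of the repaired fallback path; the function name, block-size handling, and the (H, W // block_size) layout shown here are assumptions for illustration, not the exact export code:

    import torch

    def repair_per_block_amax(amax, weight_slice, block_size=16, floor=2e-3):
        """Patch invalid per-block amax entries from the weight and recompute global_amax."""
        H, W = weight_slice.shape
        # Weight-derived per-block fallback in a per-block layout (assumed here to match amax's numel).
        fallback = weight_slice.abs().reshape(H, W // block_size, block_size).amax(dim=-1)
        fallback = fallback.clamp(min=floor)  # 2e-3 stays above the FP8 E4M3FN minimum of 2**-9
        invalid = ~torch.isfinite(amax) | (amax <= 0)
        amax = torch.where(invalid, fallback.reshape_as(amax), amax)
        # Always recompute global_amax from the (possibly patched) per-block values.
        return amax, amax.max()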

Additional fixes:

  • moe_utils.py: safe CPU extraction of _amax before deepcopy to avoid async CUDA errors from corrupt bfloat16 amax storage on under-calibrated experts.
  • model_quant.py: print_quant_summary now calls os.makedirs(output_dir, exist_ok=True) before writing .quant_summary.txt, preventing a FileNotFoundError when the export directory doesn't exist yet.
  • tensor_quantizer.py: change default format in _short_amax / _short_tensor from ".4f" to ".2e" so small amax values (e.g. 3.5e-7) display as 3.50e-07 instead of 0.0000.
  • hf_ptq.py: strip leading pad tokens from the preview input and add skip_special_tokens=True to input_decode, fixing degenerate pre/post-PTQ output on models that use EOS as the pad token (e.g. Qwen3); see the sketch after this list.
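
A rough sketch of the hf_ptq.py preview-input change, assuming a single-sequence preview batch; the function and variable names are illustrative:

    def strip_leading_pads(preview_input_ids, tokenizer):
        """Drop leading pad tokens (Qwen3 reuses EOS as pad) from a 1-sequence preview batch."""
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is None:
            return preview_input_ids
        ids = preview_input_ids[0].tolist()
        first_real = next((i for i, t in enumerate(ids) if t != pad_id), len(ids))
        return preview_input_ids[:, first_real:]

    # Decoding the preview with skip_special_tokens=True keeps EOS/pad out of the printed text:
    #   input_text = tokenizer.batch_decode(preview_input_ids, skip_special_tokens=True)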

Usage

  # Quantize Qwen3.6-35B-A3B (or any compatible fused-expert MoE) with the new recipe:
  python examples/llm_ptq/hf_ptq.py \
      --pyt_ckpt_path /path/to/Qwen3.6-35B-A3B \
      --recipe modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml \
      --export_path /path/to/output \
      --calib_size 512 --calib_seq 2048

Testing

Validated on Qwen3.6-35B-A3B (8× B200):

  • 21,740 quantizers inserted; 20,480/20,480 MSE weight calibrations completed (~11 min)
  • 0 / 2,013,265,920 zero weight_scale entries in the exported checkpoint (3 shards)
  • Pre- and post-PTQ generation produce coherent, semantically consistent output

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added a PTQ recipe for NVFP4 W4A4 quantization targeting MoE routed experts with MSE calibration and FP8 scale sweep.
  • Bug Fixes

    • Prevented FP8 scale underflow to zero by clamping tiny scales.
    • Validated and patched per-block quantization scales and ensured proper device placement during export.
    • Fixed preview decoding/input prep to avoid padded-token artifacts.
    • Auto-creates quantization summary output directories.
  • Improvements

    • Calibration now supports per-expert/fused-expert quantizer layouts.
  • Tests

    • Added tests for NVFP4 scale clamping and fused-experts export/calibration.

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv requested review from a team as code owners May 2, 2026 00:14
@Fridah-nv Fridah-nv requested review from Edwardf0t1 and sychen52 May 2, 2026 00:14
coderabbitai (Bot) commented May 2, 2026

No actionable comments were generated in the recent review. 🎉


📝 Walkthrough

Adds per-expert handling and safety fixes to MoE quantization/export, prevents FP8 scale underflow by clamping per-block scales, extends MSE calibration to discover per-expert quantizers, adds an NVFP4 experts-only PTQ recipe and tests, and tweaks LLM PTQ preview input decoding behavior.

Changes

MoE Quantization Infrastructure & Tests

  • Quantizer discovery / typing (modelopt/torch/quantization/model_calib.py): mse_calibrate can collect per-expert quantizers from {param_name}_weight_quantizers ModuleLists and uses tuple identifiers (param_name, expert_idx); typing adjusted for the new identifier shape; weight lookup updated to handle tuple identifiers.
  • Per-expert quantizer export wiring (modelopt/torch/export/moe_utils.py): Deep-copy gate/up weight quantizers per expert (down quantizers reused); slice per-channel fused _amax via direct _amax assignment when evenly divisible.
  • Per-block amax validation & fallback (modelopt/torch/export/moe_utils.py): Validate multi-element per-block _amax entries; replace non-finite/out-of-range entries with weight-derived per-block fallback values; strengthen the uncalibrated-expert fallback to require enabled per-block quantizers.
  • Move/clamp amax before export (modelopt/torch/export/moe_utils.py): Move patched _amax to the weight device and recompute global_amax from the clamped per-block _amax prior to _export_quantized_weight.
  • NVFP4 scale clamping (modelopt/torch/quantization/qtensor/nvfp4_tensor.py): Clamp the computed per_block_scale to the smallest positive FP8 E4M3FN subnormal (2**-9) before casting to torch.float8_e4m3fn in both the static and dynamic quantizer paths.
  • Tests: fused-experts export & calibration (tests/unit/torch/quantization/plugins/test_fused_experts.py): Refactor the export test to quantize a tiny MoE model end-to-end and export per-expert submodules; add an MSE calibration test asserting per-expert amax is populated; add a parameterized fallback-warning test verifying per-block _amax repair and a positive global_amax.
  • Tests: NVFP4 scale clamping (tests/unit/torch/quantization/test_nvfp4_tensor.py): Add tests ensuring tiny per-block weights produce per-block scales >= the FP8 minimum and that normal/mixed blocks yield strictly positive scales.
  • Minor formatting / summary (modelopt/torch/quantization/nn/modules/tensor_quantizer.py, modelopt/torch/quantization/model_quant.py): Adjust the default numeric formatting in extra_repr() to scientific .2e; ensure print_quant_summary creates output_dir before writing .quant_summary.txt.

LLM PTQ Preview Input Handling

  • Preview input trimming (examples/llm_ptq/hf_ptq.py): In pre_quantize, strip leading pad_token_id tokens from preview_input_ids for non-Whisper models when tokenizer.pad_token_id exists.
  • Preview decode behavior (examples/llm_ptq/hf_ptq.py): In post_quantize input_decode, use tokenizer.batch_decode(..., skip_special_tokens=True) to omit special tokens from the decoded preview text.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • NVIDIA/Model-Optimizer#1340: Related work addressing fused-experts PTQ calibration/export handling and per-expert quantizer splitting.

Suggested labels

bug, cherry-pick-0.44.0

Suggested reviewers

  • realAsma
  • sychen52
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 65.38%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (5 passed)

  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title accurately summarizes the main changes: fixes for fused MoE quantization in Qwen3.6 and GLM5.1 with MSE calibration improvements.
  • Linked Issues Check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns (✅ Passed): All modified files reviewed against SECURITY.md. No violations found: no unsafe torch.load, numpy.load, hardcoded trust_remote_code, eval/exec, nosec, yaml.load, or subprocess shell=True.



github-actions (Bot) commented May 2, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1382/

Built to branch gh-pages at 2026-05-05 00:04 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

coderabbitai (Bot) left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/moe_utils.py`:
- Around line 98-103: The temporary mutation of w_quantizer_src._amax before
calling copy.deepcopy may leave the source quantizer with _amax == None if
deepcopy raises; change the code around copy.deepcopy(w_quantizer_src) to save
_saved_amax, set w_quantizer_src._amax = None, then perform deepcopy inside a
try block and restore w_quantizer_src._amax = _saved_amax in a finally block;
after deepcopy set w_quantizer._amax = gu_amax_cpu as before so the source state
is always restored even on exceptions.
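
A sketch of the suggested try/finally pattern; the wrapper function name is hypothetical, and w_quantizer_src / gu_amax_cpu stand for the objects named in the comment above:

    import copy

    def clone_quantizer_with_amax(w_quantizer_src, gu_amax_cpu):
        """Deep-copy a quantizer while guaranteeing the source _amax is restored on failure."""
        _saved_amax = w_quantizer_src._amax
        w_quantizer_src._amax = None  # avoid deep-copying the (possibly corrupt) CUDA amax
        try:
            w_quantizer = copy.deepcopy(w_quantizer_src)
        finally:
            w_quantizer_src._amax = _saved_amax  # source state restored even if deepcopy raises
        w_quantizer._amax = gu_amax_cpu  # per-expert amax already extracted on CPU
        return w_quantizer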

@Fridah-nv Fridah-nv requested a review from cjluo-nv May 2, 2026 00:21
cjluo-nv (Collaborator) left a comment

Bot review — DM the bot to share feedback.

This PR fixes several real bugs in the fused MoE quantization pipeline (MSE calibration discovery, FP8 scale underflow, uncalibrated expert export). The fixes are well-described in the PR body and address genuine correctness issues. However, there are several concerns:

  1. Missing unit tests (critical): No tests are added for any of the bug fixes. The existing test_fused_experts.py covers registration/conversion/basic export but doesn't exercise MSE calibration for fused experts, FP8 scale clamping, or the invalid-amax patching logic. Given the complexity of the moe_utils.py changes and the project's known pattern of missing tests, this is a blocking concern.

  2. Threshold inconsistency: _MIN_VALID_AMAX = 1e-4 is below FP8 E4M3FN minimum (2^-9 ≈ 0.00195), meaning values between 1e-4 and 2e-3 pass the validity check but could still underflow.

  3. Hardcoded block_size=16: The fallback per-block amax computation in moe_utils.py hardcodes 16. If the actual block size differs, the shape will be wrong.

  4. Copyright year: New YAML file has Copyright (c) 2024 but LICENSE_HEADER says 2026.

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
coderabbitai (Bot) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
tests/unit/torch/quantization/plugins/test_fused_experts.py (1)

728-737: ⚡ Quick win

Assert the repaired quantizer state, not just the warning.

This still passes if the warning is emitted but the fallback leaves _amax with the wrong per-block shape or global_amax stale/zero. Those are the export failures this change is fixing, so it would be worth capturing the mocked wrapper objects and asserting the repaired quantizer state directly.

Possible tightening
+        captured = []
+
+        def _capture_export(wrapper, dtype):
+            captured.append((tuple(wrapper.weight.shape), wrapper.weight_quantizer))
+
         with (
-            patch("modelopt.torch.export.unified_export_hf._export_quantized_weight"),
+            patch(
+                "modelopt.torch.export.unified_export_hf._export_quantized_weight",
+                side_effect=_capture_export,
+            ),
             warnings.catch_warnings(record=True) as caught,
         ):
             warnings.simplefilter("always")
             _export_fused_experts(converted, torch.float16)
 
         assert any("weight-derived per-block amax" in str(w.message) for w in caught), (
             f"No fallback warning emitted for {'zero' if zero_amax else 'None'} amax — Bug 3 regression"
         )
+        for weight_shape, weight_quantizer in captured:
+            assert weight_quantizer._amax is not None
+            assert weight_quantizer._amax.numel() == (weight_shape[0] * weight_shape[1]) // 16
+            assert weight_quantizer.global_amax is not None
+            assert weight_quantizer.global_amax.item() > 0
         self._cleanup_registry(expert_type)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py` around lines 728
- 737, The test currently only checks for a fallback warning; update it to also
capture the mocked export wrapper(s) and assert the repaired quantizer state
after calling _export_fused_experts(converted, torch.float16): specifically,
patch and capture the wrapper returned by
modelopt.torch.export.unified_export_hf._export_quantized_weight (or the outer
wrapper used in the test) and then assert that each quantizer's internal _amax
has the expected per-block shape (not a scalar) and that global_amax is
updated/non-zero (or not stale) for the converted object’s experts; keep the
existing warning assertion but add these direct state assertions on converted
(and its quantizer instances) to ensure the fallback actually fixes the
quantizer internals.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/export/moe_utils.py`:
- Around line 109-137: The invalid-_amax repair should only run for per-block
quantizers—before computing _block_size or reshaping weight_slice, check that
(getattr(w_quantizer, "block_sizes", None) or {}).get(-1) is not None and bail
out of this repair branch when it is None; do not fall back to a default block
size (remove the hardcoded 16 default), obtain _block_size from that block_sizes
entry, and only then compute per_block_fallback and assign into
w_quantizer._amax using weight_slice, per_block_fallback, invalid_mask as
currently done.

In `@tests/unit/torch/quantization/test_nvfp4_tensor.py`:
- Around line 32-42: The test's wsf2 is too small and makes per_block_scale
large instead of exercising the FP8-min clamp path; update the wsf2 fixture used
before calling NVFP4QTensor.get_weights_scaling_factor so that per_block_scale
becomes very small (below _FP8_E4M3FN_MIN) for the tiny_weight case — e.g.,
increase wsf2 by many orders of magnitude (replace the current wsf2 value with a
larger magnitude such as using 1e-2/(6.0 * 448.0) or similar) so per_block_scale
= per_block_amax / (6 * wsf2) triggers the FP8 underflow/clamp behavior checked
against _FP8_E4M3FN_MIN.


codecov (Bot) commented May 4, 2026

Codecov Report

❌ Patch coverage is 79.68750% with 13 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.85%. Comparing base (50706d1) to head (cfe4a4a).
⚠️ Report is 9 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/export/moe_utils.py 66.66% 7 Missing ⚠️
modelopt/torch/quantization/model_calib.py 83.33% 6 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1382       +/-   ##
===========================================
+ Coverage   66.36%   76.85%   +10.49%     
===========================================
  Files         471      471               
  Lines       50510    50768      +258     
===========================================
+ Hits        33522    39019     +5497     
+ Misses      16988    11749     -5239     
Flag Coverage Δ
examples 41.56% <10.93%> (+0.92%) ⬆️
gpu 59.73% <51.56%> (+32.80%) ⬆️
regression 14.89% <4.68%> (+0.05%) ⬆️
unit 52.85% <75.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
coderabbitai (Bot) left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 429-436: The loop assumes each f"{param_name}_weight_quantizers"
ModuleList length equals the leading expert dimension of the parameter, which
can drift and cause indexing errors; modify the branch that processes
parent_module.named_parameters(recurse=False) to retrieve the parameter tensor
(via parent_module.get_parameter(param_name) or getattr), read its leading
dimension, and assert it equals len(qlist) (or raise a clear ValueError) before
iterating experts; reference the symbols parent_module, param_name, qlist
(f"{param_name}_weight_quantizers"), TensorQuantizer, expert_idx, and
weight_quantizers so the check fails fast with a descriptive message when sizes
mismatch.

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Comment thread (code context):

    torch.float8_e4m3fn
    fp8_e4m3fn_min = 2**-9 # 0.001953125 — smallest positive subnormal
    per_block_scale = (
        (per_block_scale * 448.0 / per_block_scale_max)

Contributor:

Suggested change:
-   (per_block_scale * 448.0 / per_block_scale_max)
+   (per_block_scale.float() * 448.0 / per_block_scale_max)

Comment on lines +173 to 179
# Clamp to the minimum positive FP8 E4M3FN subnormal (~0.00195 = 2^-9) before
# casting. Without this, blocks whose scale falls below the FP8 representable
# range silently underflow to 0, causing those blocks to produce zero output at
# inference even when the weights are non-trivial.
fp8_e4m3fn_min = 2**-9 # 0.001953125 — smallest positive subnormal
per_block_scale = per_block_scale.clamp(min=fp8_e4m3fn_min)
per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
realAsma (Contributor) commented May 5, 2026

Can we create a helper method which does the FP8 quantization of per_tensor scale and use that here and here https://github.com/NVIDIA/Model-Optimizer/pull/1382/changes#r3191334011
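
One possible shape of such a helper; the name, signature, and constants mirror the snippet above but are a hypothetical sketch, not part of the PR:

    import torch

    FP8_E4M3FN_MIN = 2 ** -9  # smallest positive FP8 E4M3FN subnormal

    def quantize_scale_to_fp8(per_block_scale, per_block_scale_max):
        """Normalize a per-block scale into the FP8 E4M3FN range, clamp underflow, and cast."""
        scaled = per_block_scale.float() * 448.0 / per_block_scale_max
        return scaled.clamp(min=FP8_E4M3FN_MIN).to(torch.float8_e4m3fn)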

Comment on lines +429 to +447
    # Enqueue per-expert quantizers from {param}_weight_quantizers ModuleLists.
    if _qfe_cls is not None and isinstance(parent_module, _qfe_cls):
        for param_name, param in parent_module.named_parameters(recurse=False):
            qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
            if not isinstance(qlist, nn.ModuleList):
                continue
            if len(qlist) != param.shape[0]:
                warnings.warn(
                    f"Skipping {param_name}_weight_quantizers: list length {len(qlist)} "
                    f"does not match parameter leading dimension {param.shape[0]}. "
                    "This may indicate a misconfigured fused-experts module.",
                    stacklevel=2,
                )
                continue
            for expert_idx, wq in enumerate(qlist):
                if isinstance(wq, TensorQuantizer) and wq.is_enabled:
                    if getattr(wq, "_calibrator", None) is not None:
                        weight_quantizers.append((parent_module, (param_name, expert_idx), wq))

Contributor:

can we have a helper method get_weight_quantizers(module) which can support both MoE and regular weight quantizers? This will help avoid the code branching here
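
A hypothetical sketch of that helper, assuming the fused-experts attribute naming shown above; the regular-module branch and exact attribute names are illustrative:

    import torch.nn as nn

    def get_weight_quantizers(module):
        """Yield (module, identifier, quantizer) triples for fused-expert and regular modules."""
        entries = []
        for param_name, param in module.named_parameters(recurse=False):
            qlist = getattr(module, f"{param_name}_weight_quantizers", None)
            if isinstance(qlist, nn.ModuleList) and len(qlist) == param.shape[0]:
                # Fused-experts layout: one quantizer per expert slice of the 3-D parameter.
                entries.extend((module, (param_name, idx), wq) for idx, wq in enumerate(qlist))
            else:
                # Regular layout, e.g. "weight" -> "weight_quantizer" (naming here is illustrative).
                wq = getattr(module, f"{param_name}_quantizer", None)
                if wq is not None:
                    entries.append((module, param_name, wq))
        return entries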

Comment thread (code context):

    cal = getattr(module, "_calibrator", None)
    if cal and not getattr(module, "_dynamic", False):
        if method in {"entropy"}:
        if method == "entropy":

Contributor:

why is this needed?

Contributor:

can we rebase this PR to the composable recipe PR - #1253
