[OMNIML-3495] Add TEGroupedMLP export support for NemotronH models#967

Open
yueshen2016 wants to merge 1 commit into main from yueshen/Support-Nemotron-Export

Conversation

@yueshen2016 yueshen2016 commented Mar 4, 2026

What does this PR do?

Type of change: New feature

Add export support for TEGroupedMLP (fused grouped GEMM experts) in the MCore-to-HuggingFace checkpoint exporter. Previously, the exporter only supported SequentialMLP (which has local_experts as a ModuleList). TEGroupedMLP stores per-expert weights as weight0, weight1, ..., weight{N-1} in a single TEGroupedLinear module instead. This caused an AttributeError: 'QuantTEGroupedMLP' object has no attribute 'local_experts' when exporting NemotronH models.

Changes:

  • Add GroupedMLPSlicing class in mcore_custom.py — the export counterpart of GroupedMLPMerging
  • Add _grouped_mlp_slicing method in GPTModelExporter that iterates TEGroupedLinear's per-expert weights and exports them as individual HF-format weights with proper quantization scale handling
  • Add "experts.linear_fc1" and "experts.linear_fc2" rules using GroupedMLPSlicing to nemotron_h_causal_lm_export
  • Route TEGroupedMLP (detected by absence of local_experts attribute) to the new "experts.linear_fc1" rule in _get_transformer_layer_state_dict
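The routing and slicing described in the list above can be sketched in a few lines. This is a toy illustration, not the PR's code: plain lists stand in for tensors, `SimpleNamespace` stands in for the TEGroupedLinear module, and the helper names and the HF-style key template are hypothetical.

```python
from types import SimpleNamespace


def uses_grouped_experts(experts):
    """Return True when experts has no local_experts list (the TEGroupedMLP case)."""
    return not hasattr(experts, "local_experts")


def slice_grouped_weights(linear, prefix_template, num_experts):
    """Map weight0..weight{N-1} on a fused module to per-expert export keys."""
    exported = {}
    for expert_id in range(num_experts):
        weight = getattr(linear, f"weight{expert_id}")
        exported[prefix_template.format(expert_id) + ".weight"] = weight
    return exported


# Toy stand-in for a TEGroupedLinear with two experts; lists replace tensors.
fc1 = SimpleNamespace(weight0=[[1.0, 2.0]], weight1=[[3.0, 4.0]])
experts = SimpleNamespace(linear_fc1=fc1)  # no local_experts attribute

# Hypothetical HF-style key template, chosen only for this example.
template = "model.layers.0.mixer.experts.{}.up_proj"

if uses_grouped_experts(experts):
    state_dict = slice_grouped_weights(experts.linear_fc1, template, num_experts=2)
    print(sorted(state_dict))
```

The `{}` placeholder mirrors the `prefix.format(expert_id)` call quoted later in the review; everything else is an assumption made for illustration.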

Usage

No API change. NemotronH models using TEGroupedMLP can now be exported:

import modelopt.torch.export as mtex

mtex.export_mcore_gpt_to_hf(
    model=megatron_model,
    export_dir="/path/to/hf_export",
    pretrained_model_name_or_path="/path/to/hf_model",
)

Testing

Inside Model-Bridge

torchrun --nproc_per_node 4 examples/quantization/export.py \
    --hf-model-id /models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ \
    --megatron-load-path /models/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4-MLM \
    --export-dir /models/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4-MLM_hf \
    --pp 4 \
    --dtype bfloat16 \
    --trust-remote-code

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, using torch.load(..., weights_only=True), avoiding pickle, etc.).

  • Is this change backward compatible?: ✅ The existing SequentialMLP (local_experts) path is guarded by hasattr(layer.mlp.experts, "local_experts") and remains unchanged. The new TEGroupedMLP path only activates when local_experts is absent and "experts.linear_fc1" is defined in the architecture's rules.
  • If you copied code from any other source, did you follow IP policy in CONTRIBUTING.md?: N/A
  • Did you write any new necessary tests?: ❌ Tested manually with Nemotron-3-Nano-30B-A3B. Unit test coverage should be added for _grouped_mlp_slicing.
  • Did you update Changelog?: ❌ New feature for a specific model architecture.

Additional Information

  • The import counterpart (GroupedMLPMerging / _grouped_mlp_merging) was added by @jennifchen in PR #830 ("Latent MOE & Repeated MTP support for NemotronH; fix KV cache quant export"). This PR completes the round-trip by adding the export side.
  • _grouped_mlp_slicing temporarily assigns module.weight = module.weight0 so that _get_quantized_state can extract qformat/scales from the module's quantizers, then removes it afterward. This follows the same pattern used by _QuantTEGroupedLinear._setup() in the quantization plugin.
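One possible way to express the temporary aliasing is a small context manager, so that the alias is removed even if the code inside the block raises. This is a sketch under assumptions, not the PR's implementation: `temporary_weight_alias` is a hypothetical helper, and `SimpleNamespace` with list values stands in for a real module with tensors and quantizers.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporary_weight_alias(module, source_name="weight0"):
    """Expose module.<source_name> as module.weight for the duration of the block."""
    had_weight = hasattr(module, "weight")
    if not had_weight:
        module.weight = getattr(module, source_name)
    try:
        yield module
    finally:
        # Remove only an alias this helper created; keep a pre-existing weight.
        if not had_weight and hasattr(module, "weight"):
            del module.weight


# Toy module: a real TEGroupedLinear would hold tensors and quantizers.
module = SimpleNamespace(weight0=[1.0, 2.0], weight1=[3.0, 4.0])

with temporary_weight_alias(module):
    # Code that expects module.weight (e.g. quantizer-state extraction) runs here.
    assert module.weight is module.weight0

print(hasattr(module, "weight"))  # False: the alias is gone after the block
```

Keeping the cleanup in `finally` means a failure during state extraction cannot leave a stray `weight` attribute on the module.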

Summary by CodeRabbit

  • New Features
    • Export now supports grouped-expert weight slicing: fused expert weights can be split into per-expert tensors for downstream formats.
    • Per-expert export handling added with fallbacks between packed and per-expert layouts.
    • Nemotron H causal LM import/export improved to align grouped local-expert mappings.
    • Added export support for fused normalization metadata and safer remote-code model loading.

@yueshen2016 yueshen2016 requested a review from a team as a code owner March 4, 2026 09:01
coderabbitai bot commented Mar 4, 2026

📝 Walkthrough

Walkthrough

Adds a new GroupedMLPSlicing custom mapping and integrates grouped MLP export handling into Nemotron mappings and the unified Megatron exporter, implementing grouped/per-expert TEGroupedMLP weight slicing, quantization-aware state-dict emission, and fused-norm export paths.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Grouped MLP Mapping: modelopt/torch/export/plugins/mcore_custom.py | Adds GroupedMLPSlicing(CustomModuleMapping) with __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}) registering func_name="grouped_mlp_slicing". |
| Nemotron mappings: modelopt/torch/export/plugins/mcore_nemotron.py | Imports GroupedMLPSlicing and adds "fused_norm" plus grouped local-expert mappings for "experts.linear_fc1" and "experts.linear_fc2" in both import and export mapping dictionaries. |
| Unified Megatron exporter: modelopt/torch/export/unified_export_megatron.py | Adds GPTModelExporter._grouped_mlp_slicing(...), registers "grouped_mlp_slicing" in custom mappings, implements grouped vs per-expert export flows (including TEGroupedMLP fused-weight slicing), handles quantization scales, and emits per-expert state-dict entries and fused-norm paths. |

Sequence Diagram(s)

sequenceDiagram
    participant Pipeline as Export Pipeline
    participant Mapping as Rule Mapping
    participant Handler as grouped_mlp_slicing
    participant Module as TEGroupedMLP
    participant State as State Dict

    Pipeline->>Mapping: lookup "grouped_mlp_slicing"
    Mapping->>Handler: invoke(module, prefix, parallel_config)
    Handler->>Module: inspect local_experts / fused weights / quant state
    alt per-expert path
        Handler->>Module: iterate experts
        loop for each expert
            Handler->>State: emit expert weight, scales, metadata
        end
    else fused/grouped path
        Handler->>Module: read grouped linear_fc1/linear_fc2
        Handler->>State: slice grouped weights into per-expert tensors
    end
    Handler->>State: emit fused_norm if present and propagate non-weight metadata
    State-->>Pipeline: populated state dict
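
The lookup-then-invoke flow in the diagram can be illustrated with a toy exporter in which rules name a handler and the exporter resolves that name to a method. All class, attribute, and key names here are hypothetical stand-ins; only the rule names "experts.linear_fc1" / "experts.linear_fc2" and the handler name "grouped_mlp_slicing" come from the PR.

```python
class ToyExporter:
    """Toy exporter: rules name a handler, and the exporter looks it up by name."""

    def __init__(self):
        self.state_dict = {}
        # func_name -> bound method, analogous to registering "grouped_mlp_slicing".
        self.custom_mappings = {"grouped_mlp_slicing": self._grouped_mlp_slicing}
        # Rule name -> (func_name, target key template); templates are hypothetical.
        self.rules = {
            "experts.linear_fc1": ("grouped_mlp_slicing", "experts.{}.up_proj"),
            "experts.linear_fc2": ("grouped_mlp_slicing", "experts.{}.down_proj"),
        }

    def _grouped_mlp_slicing(self, weights, prefix):
        # Emit one state-dict entry per expert.
        for expert_id, weight in enumerate(weights):
            self.state_dict[prefix.format(expert_id) + ".weight"] = weight

    def export(self, rule_name, weights):
        func_name, prefix = self.rules[rule_name]
        self.custom_mappings[func_name](weights, prefix)


exporter = ToyExporter()
exporter.export("experts.linear_fc1", weights=["w0", "w1"])
exporter.export("experts.linear_fc2", weights=["v0", "v1"])
print(sorted(exporter.state_dict))
```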

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 40.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main change: adding TEGroupedMLP export support for NemotronH models, which is the core functionality introduced across all three modified files. |
| Security Anti-Patterns | ✅ Passed | Pull request adheres to all SECURITY.md practices: trust_remote_code properly exposed with False default, no unsafe deserialization patterns found, no eval/exec or bypass comments used. |

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
modelopt/torch/export/unified_export_megatron.py (1)

871-931: Please add focused tests for the new grouped slicing path.

Recommended coverage: non-quantized export, quantized export, missing expert-weight key behavior, and cleanup of temporary module.weight on exceptions.

If you want, I can draft a pytest matrix for these cases in a follow-up.
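
A minimal sketch of what two of these tests might look like, written as pytest-style functions against a toy slicing helper. It assumes the fail-fast behavior proposed elsewhere in this review (raise on a missing expert weight rather than skip); `slice_experts` and the key template are hypothetical, and the quantized and cleanup cases are not covered here.

```python
from types import SimpleNamespace


def slice_experts(module, prefix, num_experts):
    """Toy slicing: export weight{i} per expert and fail fast when one is missing."""
    exported = {}
    for expert_id in range(num_experts):
        weight_key = f"weight{expert_id}"
        if not hasattr(module, weight_key):
            raise ValueError(
                f"Missing expert weight {weight_key!r} for {prefix.format(expert_id)}"
            )
        exported[prefix.format(expert_id) + ".weight"] = getattr(module, weight_key)
    return exported


def test_non_quantized_export():
    module = SimpleNamespace(weight0=[1.0], weight1=[2.0])
    exported = slice_experts(module, "experts.{}.up_proj", num_experts=2)
    assert exported == {
        "experts.0.up_proj.weight": [1.0],
        "experts.1.up_proj.weight": [2.0],
    }


def test_missing_expert_weight_raises():
    module = SimpleNamespace(weight0=[1.0])  # weight1 is absent
    try:
        slice_experts(module, "experts.{}.up_proj", num_experts=2)
    except ValueError as err:
        assert "weight1" in str(err)
    else:
        raise AssertionError("expected ValueError for the missing expert weight")


# Run directly so the sketch works without a test runner installed.
test_non_quantized_export()
test_missing_expert_weight_raises()
```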

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_megatron.py` around lines 871 - 931, Add
focused pytest unit tests for _grouped_mlp_slicing covering: (1) non-quantized
export where module has weight0..weightN and exported per-expert "weight"
entries are correct; (2) quantized export where _get_quantized_state returns
qformat/weight_scale(s) and exported per-expert "weight", "weight_scale", and
"weight_scale_2" use to_quantized_weight and cloned scales; (3) behavior when an
expert weight key (e.g., "weight2") is missing from module.state_dict — ensure
slicing skips that expert and others still export; and (4) cleanup when
_get_quantized_state or to_quantized_weight raises: ensure temporary assignment
of module.weight (done in _grouped_mlp_slicing) is removed after the call even
on exception. In tests, instantiate a minimal TEGroupedMLP-like object with
num_gemms, weight0..weightN in state_dict, control _get_quantized_state via
monkeypatch or fixture to simulate quantized/non-quantized returns, call
_grouped_mlp_slicing(prefix=...) and assert resulting self._state_dict
keys/values and that module has no lingering "weight" attribute after success or
exception.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 494-497: Update the inline comment describing the TEGroupedMLP
export path to reflect the correct mapping type: change the reference from
"GroupedMLPMerging" to "GroupedMLPSlicing" and clarify that the export uses the
"experts.linear_fc1" rule with GroupedMLPSlicing (not
"local_experts.linear_fc1"); modify the comment around TEGroupedMLP /
experts.linear_fc1 to name GroupedMLPSlicing so it matches the actual
implementation.
- Around line 883-907: The temporary assignment of module.weight =
module.weight0 before calling _get_quantized_state is not exception-safe and may
leave the alias in place if an error occurs; wrap the assignment and subsequent
calls to _get_quantized_state and _get_weight_scales in a try/finally so you
always call delattr(module, "weight") when it was not originally present. Also
change the expert loop (using expert_prefix and weight_key =
f"weight{expert_id}") to fail fast instead of continuing silently when a
weight_key is missing from module.state_dict()—raise a clear exception
(including the expert_id and expert_prefix) so incomplete checkpoints are not
exported unnoticed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f82ada24-3356-48b5-b6c1-453fc3152768

📥 Commits

Reviewing files that changed from the base of the PR and between a34d613 and d076a18.

📒 Files selected for processing (3)
  • modelopt/torch/export/plugins/mcore_custom.py
  • modelopt/torch/export/plugins/mcore_nemotron.py
  • modelopt/torch/export/unified_export_megatron.py

Comment on lines +883 to +907
# TEGroupedLinear doesn't have module.weight (it has weight0, weight1, ...).
# Temporarily assign weight = weight0 so _get_quantized_state can extract
# qformat, scales, and input_scale from the module's quantizers.
has_weight = hasattr(module, "weight")
if not has_weight:
    module.weight = module.weight0

name_to_value, qformat, block_size = self._get_quantized_state(
    module, self.dtype, prefix=prefix
)
weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
name_to_value.pop("weight", None)

if not has_weight:
    delattr(module, "weight")

state_dict = module.state_dict()

for expert_id in range(num_experts):
    expert_prefix = prefix.format(expert_id) + "."
    weight_key = f"weight{expert_id}"

    if weight_key not in state_dict:
        continue

⚠️ Potential issue | 🟠 Major

Make temporary module.weight patch exception-safe and fail fast on missing expert weights.

If _get_quantized_state(...) raises, the temporary module.weight alias can be left behind. Also, silently continue on missing weight{i} can export incomplete checkpoints without surfacing an error.

Proposed fix
-        has_weight = hasattr(module, "weight")
-        if not has_weight:
-            module.weight = module.weight0
-
-        name_to_value, qformat, block_size = self._get_quantized_state(
-            module, self.dtype, prefix=prefix
-        )
-        weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
-        name_to_value.pop("weight", None)
-
-        if not has_weight:
-            delattr(module, "weight")
+        has_weight = hasattr(module, "weight")
+        if not has_weight:
+            module.weight = module.weight0
+        try:
+            name_to_value, qformat, block_size = self._get_quantized_state(
+                module, self.dtype, prefix=prefix
+            )
+            weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
+            name_to_value.pop("weight", None)
+        finally:
+            if not has_weight and hasattr(module, "weight"):
+                delattr(module, "weight")
@@
-            if weight_key not in state_dict:
-                continue
+            if weight_key not in state_dict:
+                raise ValueError(f"Missing expected TEGroupedMLP expert weight: {weight_key}")

codecov bot commented Mar 4, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.14%. Comparing base (a076e6c) to head (d879084).
⚠️ Report is 13 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #967      +/-   ##
==========================================
+ Coverage   72.12%   72.14%   +0.02%     
==========================================
  Files         209      209              
  Lines       23628    23667      +39     
==========================================
+ Hits        17042    17075      +33     
- Misses       6586     6592       +6     
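
The figures in the table can be cross-checked with a few lines of arithmetic. The reported percentages appear consistent with truncating (rather than rounding) to two decimals; that truncation is an inference from the numbers above, not something the report states.

```python
import math


def coverage_percent(hits, lines):
    """Coverage as a percentage, truncated (not rounded) to two decimals."""
    return math.floor(hits / lines * 10000) / 100


base = coverage_percent(17042, 23628)
head = coverage_percent(17075, 23667)
print(base, head)

# The 39 added lines split into 33 newly hit and 6 newly missed lines.
assert (23667 - 23628) == (17075 - 17042) + (6592 - 6586)
```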


@yueshen2016 yueshen2016 force-pushed the yueshen/Support-Nemotron-Export branch from d076a18 to c3e9f46 Compare March 4, 2026 09:17
Signed-off-by: James Shen <yueshen@nvidia.com>
@yueshen2016 yueshen2016 force-pushed the yueshen/Support-Nemotron-Export branch from c3e9f46 to d879084 Compare March 7, 2026 00:24
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/export/unified_export_megatron.py (1)

892-903: ⚠️ Potential issue | 🟠 Major

Make the temporary module.weight alias exception-safe and fail fast on missing expert weights.

If _get_quantized_state(...) or _get_weight_scales(...) throws, Line 894 can leave module.weight behind. And Lines 911-912 silently drop experts, which can produce an incomplete checkpoint with no error.

Suggested fix
         has_weight = hasattr(module, "weight")
         if not has_weight:
             module.weight = module.weight0
-
-        name_to_value, qformat, block_size = self._get_quantized_state(
-            module, self.dtype, prefix=prefix
-        )
-        weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
-        name_to_value.pop("weight", None)
-
-        if not has_weight:
-            delattr(module, "weight")
+        try:
+            name_to_value, qformat, block_size = self._get_quantized_state(
+                module, self.dtype, prefix=prefix
+            )
+            weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
+            name_to_value.pop("weight", None)
+        finally:
+            if not has_weight and hasattr(module, "weight"):
+                delattr(module, "weight")
@@
-            if weight_key not in state_dict:
-                continue
+            if weight_key not in state_dict:
+                raise ValueError(
+                    f"Missing expected TEGroupedMLP expert weight {weight_key!r} for {expert_prefix}"
+                )

Also applies to: 911-912

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_megatron.py` around lines 892 - 903,
Wrap the temporary aliasing of module.weight (where code sets module.weight =
module.weight0 when has_weight is False) in a try/finally so that module.weight
is always removed in the finally block even if _get_quantized_state or
_get_weight_scales throws; locate the aliasing and restoration around the calls
to _get_quantized_state and _get_weight_scales and move name_to_value, qformat,
block_size assignment into the try. Also replace the silent pop of "weight" from
name_to_value (name_to_value.pop("weight", None)) with an explicit existence
check and raise a clear exception if expected expert weights are missing so the
export fails fast rather than producing an incomplete checkpoint.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 919-929: The export is quantizing each 2D expert weight using the
full grouped-module scale tensors, which breaks when TEGroupedLinear provides
per-expert (batched) scales; modify the export to slice the per-expert portions
of weight_scale (and weight_scale_2) before calling to_quantized_weight so the
scale shape matches the 2D weight being quantized. Locate the block that assigns
self._state_dict[expert_prefix + "weight"] and self._state_dict[expert_prefix +
"weight_scale(_2)"], slice weight_scale (and weight_scale_2 when not None) to
the single-expert index corresponding to the current expert_prefix, then pass
those sliced tensors to to_quantized_weight and store the sliced clones in the
state dict to ensure correct broadcasting and export for per-expert scales.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 4fc251f3-1c3c-4d69-bc1a-e267f4f500cd

📥 Commits

Reviewing files that changed from the base of the PR and between c3e9f46 and d879084.

📒 Files selected for processing (3)
  • modelopt/torch/export/plugins/mcore_custom.py
  • modelopt/torch/export/plugins/mcore_nemotron.py
  • modelopt/torch/export/unified_export_megatron.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/plugins/mcore_nemotron.py

Comment on lines +919 to +929
    self._state_dict[expert_prefix + "weight"] = to_quantized_weight(
        weight,
        weight_scale,
        qformat,
        weight_scale_2,
        block_size,
    )
    self._state_dict[expert_prefix + "weight_scale"] = weight_scale.detach().clone()

if weight_scale_2 is not None:
    self._state_dict[expert_prefix + "weight_scale_2"] = weight_scale_2.detach().clone()

⚠️ Potential issue | 🟠 Major

Slice expert-local scales before quantizing each expert.

On Lines 919-924, each 2D expert weight is quantized with the full weight_scale tensor from the grouped module. That only works if the scale is truly shared. If TEGroupedLinear exposes batched per-expert scales, to_quantized_weight(...) only handles those batched scales when the weight itself is 3D; reusing the full scale tensor here will mis-broadcast or fail for quantized exports.

Suggested fix
         for expert_id in range(num_experts):
             expert_prefix = prefix.format(expert_id) + "."
             weight_key = f"weight{expert_id}"
@@
             weight = state_dict[weight_key].to(self.dtype).cpu()
+            expert_weight_scale = weight_scale
+            if (
+                weight_scale is not None
+                and weight_scale.dim() > 0
+                and weight_scale.shape[0] == num_experts
+            ):
+                expert_weight_scale = weight_scale[expert_id]
+
+            expert_weight_scale_2 = weight_scale_2
+            if (
+                weight_scale_2 is not None
+                and weight_scale_2.dim() > 0
+                and weight_scale_2.shape[0] == num_experts
+            ):
+                expert_weight_scale_2 = weight_scale_2[expert_id]
 
             if weight_scale is None:
                 self._state_dict[expert_prefix + "weight"] = weight
             else:
                 self._state_dict[expert_prefix + "weight"] = to_quantized_weight(
                     weight,
-                    weight_scale,
+                    expert_weight_scale,
                     qformat,
-                    weight_scale_2,
+                    expert_weight_scale_2,
                     block_size,
                 )
-                self._state_dict[expert_prefix + "weight_scale"] = weight_scale.detach().clone()
+                self._state_dict[expert_prefix + "weight_scale"] = (
+                    expert_weight_scale.detach().clone()
+                )
 
-            if weight_scale_2 is not None:
-                self._state_dict[expert_prefix + "weight_scale_2"] = weight_scale_2.detach().clone()
+            if expert_weight_scale_2 is not None:
+                self._state_dict[expert_prefix + "weight_scale_2"] = (
+                    expert_weight_scale_2.detach().clone()
+                )
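
The selection logic in the suggested fix can be isolated into a small helper. The sketch below uses nested lists in place of tensors, so "batched" simply means a list whose leading dimension equals the number of experts; `select_expert_scale` is a hypothetical name, and whether TEGroupedLinear actually exposes batched scales is not confirmed by this thread.

```python
def select_expert_scale(scale, expert_id, num_experts):
    """Return the scale for one expert, or the shared scale if it is not batched.

    Nested lists stand in for tensors here: a "batched" scale is a list whose
    leading dimension equals num_experts.
    """
    if scale is None:
        return None
    if isinstance(scale, list) and len(scale) == num_experts:
        return scale[expert_id]
    return scale


num_experts = 3
shared_scale = 0.5                     # one scalar shared by all experts
batched_scale = [[0.1], [0.2], [0.3]]  # one row per expert

for expert_id in range(num_experts):
    print(
        expert_id,
        select_expert_scale(shared_scale, expert_id, num_experts),
        select_expert_scale(batched_scale, expert_id, num_experts),
    )
```

Note that matching on the leading dimension is a heuristic: a shared scale whose first dimension happens to equal the expert count would be sliced as if it were batched, so real code may want a more explicit signal.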
