[OMNIML-3495] Add TEGroupedMLP export support for NemotronH models #967
yueshen2016 wants to merge 1 commit into main
Conversation
📝 Walkthrough

Adds a new GroupedMLPSlicing custom mapping and integrates grouped MLP export handling into the Nemotron mappings and the unified Megatron exporter, implementing grouped/per-expert TEGroupedMLP weight slicing, quantization-aware state-dict emission, and fused-norm export paths.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Pipeline as Export Pipeline
    participant Mapping as Rule Mapping
    participant Handler as grouped_mlp_slicing
    participant Module as TEGroupedMLP
    participant State as State Dict
    Pipeline->>Mapping: lookup "grouped_mlp_slicing"
    Mapping->>Handler: invoke(module, prefix, parallel_config)
    Handler->>Module: inspect local_experts / fused weights / quant state
    alt per-expert path
        Handler->>Module: iterate experts
        loop for each expert
            Handler->>State: emit expert weight, scales, metadata
        end
    else fused/grouped path
        Handler->>Module: read grouped linear_fc1/linear_fc2
        Handler->>State: slice grouped weights into per-expert tensors
    end
    Handler->>State: emit fused_norm if present and propagate non-weight metadata
    State-->>Pipeline: populated state dict
```
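The fused/grouped path above can be sketched in isolation. This is a hypothetical helper invented here for illustration, not the exporter's actual code: the real `_grouped_mlp_slicing` reads `weight0..weightN` attributes from `TEGroupedLinear` rather than splitting one fused tensor.

```python
import torch

def slice_grouped_weight(grouped_weight: torch.Tensor, num_experts: int) -> list[torch.Tensor]:
    """Split a fused [num_experts * out_features, in_features] grouped-GEMM
    weight into per-expert 2D tensors along dim 0."""
    out_per_expert = grouped_weight.shape[0] // num_experts
    return [
        grouped_weight[i * out_per_expert : (i + 1) * out_per_expert].clone()
        for i in range(num_experts)
    ]

# 4 experts, each an 8x16 linear weight, fused along the output dimension.
fused = torch.arange(4 * 8 * 16, dtype=torch.float32).reshape(4 * 8, 16)
experts = slice_grouped_weight(fused, num_experts=4)
```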
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ❌ Failed checks (1 warning) | ✅ Passed checks (3 passed)
Actionable comments posted: 2
🧹 Nitpick comments (1)
modelopt/torch/export/unified_export_megatron.py (1)
871-931: Please add focused tests for the new grouped slicing path. Recommended coverage: non-quantized export, quantized export, missing expert-weight key behavior, and cleanup of temporary `module.weight` on exceptions. If you want, I can draft a pytest matrix for these cases in a follow-up.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/unified_export_megatron.py` around lines 871 - 931, Add focused pytest unit tests for _grouped_mlp_slicing covering: (1) non-quantized export where module has weight0..weightN and exported per-expert "weight" entries are correct; (2) quantized export where _get_quantized_state returns qformat/weight_scale(s) and exported per-expert "weight", "weight_scale", and "weight_scale_2" use to_quantized_weight and cloned scales; (3) behavior when an expert weight key (e.g., "weight2") is missing from module.state_dict — ensure slicing skips that expert and others still export; and (4) cleanup when _get_quantized_state or to_quantized_weight raises: ensure temporary assignment of module.weight (done in _grouped_mlp_slicing) is removed after the call even on exception. In tests, instantiate a minimal TEGroupedMLP-like object with num_gemms, weight0..weightN in state_dict, control _get_quantized_state via monkeypatch or fixture to simulate quantized/non-quantized returns, call _grouped_mlp_slicing(prefix=...) and assert resulting self._state_dict keys/values and that module has no lingering "weight" attribute after success or exception.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 494-497: Update the inline comment describing the TEGroupedMLP
export path to reflect the correct mapping type: change the reference from
"GroupedMLPMerging" to "GroupedMLPSlicing" and clarify that the export uses the
"experts.linear_fc1" rule with GroupedMLPSlicing (not
"local_experts.linear_fc1"); modify the comment around TEGroupedMLP /
experts.linear_fc1 to name GroupedMLPSlicing so it matches the actual
implementation.
- Around line 883-907: The temporary assignment of module.weight =
module.weight0 before calling _get_quantized_state is not exception-safe and may
leave the alias in place if an error occurs; wrap the assignment and subsequent
calls to _get_quantized_state and _get_weight_scales in a try/finally so you
always deleteattr(module, "weight") when it was not originally present. Also
change the expert loop (using expert_prefix and weight_key =
f"weight{expert_id}") to fail fast instead of continuing silently when a
weight_key is missing from module.state_dict()—raise a clear exception
(including the expert_id and expert_prefix) so incomplete checkpoints are not
exported unnoticed.
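A minimal version of the requested test setup could look like the following sketch. `FakeGroupedLinear` and `grouped_slicing_nonquant` are stand-ins invented here for illustration; a real test would exercise `_grouped_mlp_slicing` on a `TEGroupedMLP`-like module.

```python
import torch

class FakeGroupedLinear:
    """TEGroupedLinear-like stub: per-expert weights stored as weight0..weightN."""

    def __init__(self, num_gemms: int, out_f: int, in_f: int):
        self.num_gemms = num_gemms
        for i in range(num_gemms):
            setattr(self, f"weight{i}", torch.randn(out_f, in_f))

    def state_dict(self):
        # Only report weights that actually exist, so tests can delete one.
        return {
            f"weight{i}": getattr(self, f"weight{i}")
            for i in range(self.num_gemms)
            if hasattr(self, f"weight{i}")
        }

def grouped_slicing_nonquant(module, prefix: str) -> dict:
    """Simplified non-quantized slicing path: one HF-format entry per expert,
    failing fast when an expert weight is missing."""
    out = {}
    sd = module.state_dict()
    for expert_id in range(module.num_gemms):
        key = f"weight{expert_id}"
        if key not in sd:
            raise ValueError(f"Missing expert weight: {key}")
        out[prefix.format(expert_id) + ".weight"] = sd[key].cpu()
    return out

mod = FakeGroupedLinear(num_gemms=3, out_f=4, in_f=8)
exported = grouped_slicing_nonquant(mod, "experts.{}.linear_fc1")
```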
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: f82ada24-3356-48b5-b6c1-453fc3152768
📒 Files selected for processing (3)
- modelopt/torch/export/plugins/mcore_custom.py
- modelopt/torch/export/plugins/mcore_nemotron.py
- modelopt/torch/export/unified_export_megatron.py
```python
# TEGroupedLinear doesn't have module.weight (it has weight0, weight1, ...).
# Temporarily assign weight = weight0 so _get_quantized_state can extract
# qformat, scales, and input_scale from the module's quantizers.
has_weight = hasattr(module, "weight")
if not has_weight:
    module.weight = module.weight0

name_to_value, qformat, block_size = self._get_quantized_state(
    module, self.dtype, prefix=prefix
)
weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
name_to_value.pop("weight", None)

if not has_weight:
    delattr(module, "weight")

state_dict = module.state_dict()

for expert_id in range(num_experts):
    expert_prefix = prefix.format(expert_id) + "."
    weight_key = f"weight{expert_id}"

    if weight_key not in state_dict:
        continue
```
Make temporary module.weight patch exception-safe and fail fast on missing expert weights.
If `_get_quantized_state(...)` raises, the temporary `module.weight` alias can be left behind. Also, silently continuing on a missing `weight{i}` can export incomplete checkpoints without surfacing an error.
Proposed fix

```diff
-has_weight = hasattr(module, "weight")
-if not has_weight:
-    module.weight = module.weight0
-
-name_to_value, qformat, block_size = self._get_quantized_state(
-    module, self.dtype, prefix=prefix
-)
-weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
-name_to_value.pop("weight", None)
-
-if not has_weight:
-    delattr(module, "weight")
+has_weight = hasattr(module, "weight")
+if not has_weight:
+    module.weight = module.weight0
+try:
+    name_to_value, qformat, block_size = self._get_quantized_state(
+        module, self.dtype, prefix=prefix
+    )
+    weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
+    name_to_value.pop("weight", None)
+finally:
+    if not has_weight and hasattr(module, "weight"):
+        delattr(module, "weight")
@@
-if weight_key not in state_dict:
-    continue
+if weight_key not in state_dict:
+    raise ValueError(f"Missing expected TEGroupedMLP expert weight: {weight_key}")
```
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main     #967      +/-   ##
==========================================
+ Coverage   72.12%   72.14%   +0.02%
==========================================
  Files         209      209
  Lines       23628    23667      +39
==========================================
+ Hits        17042    17075      +33
- Misses       6586     6592       +6
```

☔ View full report in Codecov by Sentry.
Force-pushed d076a18 to c3e9f46.

Signed-off-by: James Shen <yueshen@nvidia.com>

Force-pushed c3e9f46 to d879084.
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/export/unified_export_megatron.py (1)
892-903: ⚠️ Potential issue | 🟠 Major

Make the temporary `module.weight` alias exception-safe and fail fast on missing expert weights. If `_get_quantized_state(...)` or `_get_weight_scales(...)` throws, Line 894 can leave `module.weight` behind. And Lines 911-912 silently drop experts, which can produce an incomplete checkpoint with no error.

Suggested fix

```diff
 has_weight = hasattr(module, "weight")
 if not has_weight:
     module.weight = module.weight0
-
-name_to_value, qformat, block_size = self._get_quantized_state(
-    module, self.dtype, prefix=prefix
-)
-weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
-name_to_value.pop("weight", None)
-
-if not has_weight:
-    delattr(module, "weight")
+try:
+    name_to_value, qformat, block_size = self._get_quantized_state(
+        module, self.dtype, prefix=prefix
+    )
+    weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
+    name_to_value.pop("weight", None)
+finally:
+    if not has_weight and hasattr(module, "weight"):
+        delattr(module, "weight")
@@
-if weight_key not in state_dict:
-    continue
+if weight_key not in state_dict:
+    raise ValueError(
+        f"Missing expected TEGroupedMLP expert weight {weight_key!r} for {expert_prefix}"
+    )
```

Also applies to: 911-912
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/unified_export_megatron.py` around lines 892 - 903, Wrap the temporary aliasing of module.weight (where code sets module.weight = module.weight0 when has_weight is False) in a try/finally so that module.weight is always removed in the finally block even if _get_quantized_state or _get_weight_scales throws; locate the aliasing and restoration around the calls to _get_quantized_state and _get_weight_scales and move name_to_value, qformat, block_size assignment into the try. Also replace the silent pop of "weight" from name_to_value (name_to_value.pop("weight", None)) with an explicit existence check and raise a clear exception if expected expert weights are missing so the export fails fast rather than producing an incomplete checkpoint.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 919-929: The export is quantizing each 2D expert weight using the
full grouped-module scale tensors, which breaks when TEGroupedLinear provides
per-expert (batched) scales; modify the export to slice the per-expert portions
of weight_scale (and weight_scale_2) before calling to_quantized_weight so the
scale shape matches the 2D weight being quantized. Locate the block that assigns
self._state_dict[expert_prefix + "weight"] and self._state_dict[expert_prefix +
"weight_scale(_2)"], slice weight_scale (and weight_scale_2 when not None) to
the single-expert index corresponding to the current expert_prefix, then pass
those sliced tensors to to_quantized_weight and store the sliced clones in the
state dict to ensure correct broadcasting and export for per-expert scales.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 4fc251f3-1c3c-4d69-bc1a-e267f4f500cd
📒 Files selected for processing (3)
- modelopt/torch/export/plugins/mcore_custom.py
- modelopt/torch/export/plugins/mcore_nemotron.py
- modelopt/torch/export/unified_export_megatron.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/export/plugins/mcore_nemotron.py
```python
self._state_dict[expert_prefix + "weight"] = to_quantized_weight(
    weight,
    weight_scale,
    qformat,
    weight_scale_2,
    block_size,
)
self._state_dict[expert_prefix + "weight_scale"] = weight_scale.detach().clone()

if weight_scale_2 is not None:
    self._state_dict[expert_prefix + "weight_scale_2"] = weight_scale_2.detach().clone()
```
Slice expert-local scales before quantizing each expert.
On Lines 919-924, each 2D expert weight is quantized with the full weight_scale tensor from the grouped module. That only works if the scale is truly shared. If TEGroupedLinear exposes batched per-expert scales, to_quantized_weight(...) only handles those batched scales when the weight itself is 3D; reusing the full scale tensor here will mis-broadcast or fail for quantized exports.
Suggested fix

```diff
 for expert_id in range(num_experts):
     expert_prefix = prefix.format(expert_id) + "."
     weight_key = f"weight{expert_id}"
@@
     weight = state_dict[weight_key].to(self.dtype).cpu()
+    expert_weight_scale = weight_scale
+    if (
+        weight_scale is not None
+        and weight_scale.dim() > 0
+        and weight_scale.shape[0] == num_experts
+    ):
+        expert_weight_scale = weight_scale[expert_id]
+
+    expert_weight_scale_2 = weight_scale_2
+    if (
+        weight_scale_2 is not None
+        and weight_scale_2.dim() > 0
+        and weight_scale_2.shape[0] == num_experts
+    ):
+        expert_weight_scale_2 = weight_scale_2[expert_id]
     if weight_scale is None:
         self._state_dict[expert_prefix + "weight"] = weight
     else:
         self._state_dict[expert_prefix + "weight"] = to_quantized_weight(
             weight,
-            weight_scale,
+            expert_weight_scale,
             qformat,
-            weight_scale_2,
+            expert_weight_scale_2,
             block_size,
         )
-        self._state_dict[expert_prefix + "weight_scale"] = weight_scale.detach().clone()
+        self._state_dict[expert_prefix + "weight_scale"] = (
+            expert_weight_scale.detach().clone()
+        )
-        if weight_scale_2 is not None:
-            self._state_dict[expert_prefix + "weight_scale_2"] = weight_scale_2.detach().clone()
+        if expert_weight_scale_2 is not None:
+            self._state_dict[expert_prefix + "weight_scale_2"] = (
+                expert_weight_scale_2.detach().clone()
+            )
```
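The scale-selection logic in the suggested fix can be factored into a small helper for illustration (`expert_scale` is a name invented here, not part of the codebase):

```python
import torch

def expert_scale(weight_scale, expert_id: int, num_experts: int):
    """Return the expert-local slice when the scale tensor is batched per
    expert (leading dim == num_experts); otherwise return the shared scale."""
    if (
        weight_scale is not None
        and weight_scale.dim() > 0
        and weight_scale.shape[0] == num_experts
    ):
        return weight_scale[expert_id]
    return weight_scale

shared = torch.tensor(0.5)                      # 0-dim: shared across experts
batched = torch.arange(4, dtype=torch.float32)  # one scale per expert
```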
What does this PR do?
Type of change: New feature
Add export support for `TEGroupedMLP` (fused grouped GEMM experts) in the MCore-to-HuggingFace checkpoint exporter. Previously, the exporter only supported `SequentialMLP` (which has `local_experts` as a `ModuleList`). `TEGroupedMLP` stores per-expert weights as `weight0`, `weight1`, ..., `weight{N-1}` in a single `TEGroupedLinear` module instead. This caused an `AttributeError: 'QuantTEGroupedMLP' object has no attribute 'local_experts'` when exporting NemotronH models.

Changes:
- Add a `GroupedMLPSlicing` class in `mcore_custom.py` — the export counterpart of `GroupedMLPMerging`
- Add a `_grouped_mlp_slicing` method in `GPTModelExporter` that iterates `TEGroupedLinear`'s per-expert weights and exports them as individual HF-format weights with proper quantization scale handling
- Add `"experts.linear_fc1"` and `"experts.linear_fc2"` rules using `GroupedMLPSlicing` to `nemotron_h_causal_lm_export`
- Route `TEGroupedMLP` (detected by absence of a `local_experts` attribute) to the new `"experts.linear_fc1"` rule in `_get_transformer_layer_state_dict`

Usage
No API change. NemotronH models using `TEGroupedMLP` can now be exported.

Testing
Inside Model-Bridge
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, using `torch.load(..., weights_only=True)`, avoiding `pickle`, etc.).
- The `SequentialMLP` (`local_experts`) path is guarded by `hasattr(layer.mlp.experts, "local_experts")` and remains unchanged. The new `TEGroupedMLP` path only activates when `local_experts` is absent and `"experts.linear_fc1"` is defined in the architecture's rules.
- `_grouped_mlp_slicing`

Additional Information
- The import side (`GroupedMLPMerging` / `_grouped_mlp_merging`) was added by @jennifchen in PR "Latent MOE & Repeated MTP support for NemotronH; fix KV cache quant export" #830. This PR completes the round-trip by adding the export side.
- `_grouped_mlp_slicing` temporarily assigns `module.weight = module.weight0` so that `_get_quantized_state` can extract qformat/scales from the module's quantizers, then removes it afterward. This follows the same pattern used by `_QuantTEGroupedLinear._setup()` in the quantization plugin.

Summary by CodeRabbit