Support EP mcore import for TE Spec and Fix mamba moe config #1342
Conversation
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
📝 Walkthrough

Remaps safetensor global expert indices into TEGroupedMLP local weight slots per expert-parallel rank, requires the global expert count to be divisible by the local expert count, relaxes layer-index detection for PP submodule names, extends the disabled-quantizer patterns, and accepts Megatron HybridModel in exporter input validation.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Rank as ExpertParallelRank
participant Importer as MegatronImporter
participant Store as SafeTensorStore
participant MLP as TEGroupedMLP
Rank->>Importer: compute init_index = get_expert_model_parallel_rank() * num_local_experts
Importer->>Store: for local_id in 0..num_local_experts-1 load key for global_expert_id = init_index + local_id
Store-->>Importer: return expert weights (global_expert_id)
Importer->>MLP: write expert weights into `weight{local_id}` slots
Note right of MLP: requires total_global_experts % num_local_experts == 0
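
A minimal, dependency-free sketch of the remapping the walkthrough and diagram describe; the helper name `remap_experts`, the `ep_rank` argument, and the `weight{local_id}` key format are illustrative stand-ins rather than the importer's actual API:

```python
# Hypothetical illustration of the expert remapping described above; the real
# importer reads safetensors and writes into TEGroupedMLP parameters instead.
def remap_experts(ep_rank: int, num_local_experts: int, num_global_experts: int) -> dict[str, int]:
    """Map local weight slots ("weight{local_id}") to global expert ids for one EP rank."""
    if num_global_experts % num_local_experts != 0:
        raise ValueError("num_global_experts must be divisible by num_local_experts")
    # Each EP rank owns a contiguous slice of the global expert list.
    init_index = ep_rank * num_local_experts
    return {f"weight{local_id}": init_index + local_id for local_id in range(num_local_experts)}

# Example: 8 global experts, 4 per rank -> rank 1 loads global experts 4..7.
assert remap_experts(ep_rank=1, num_local_experts=4, num_global_experts=8) == {
    "weight0": 4, "weight1": 5, "weight2": 6, "weight3": 7,
}
```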
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 4 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/plugins/megatron_importer.py`:
- Around line 662-668: Replace the fragile assert-based check with an explicit
validation that raises a clear exception: verify both that num_global_experts is
divisible by num_local_experts and that num_global_experts == num_local_experts
* ep_size (where ep_size is the EP topology size used by
get_expert_model_parallel_rank()); if these conditions fail, raise a ValueError
with a descriptive message mentioning num_global_experts, num_local_experts and
ep_size so the expert slice computation (init_index =
get_expert_model_parallel_rank() * num_local_experts) cannot proceed with an
invalid topology.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: f87a3cc7-689f-4332-a931-506c6d7f0513
📒 Files selected for processing (2)
- modelopt/torch/export/plugins/megatron_importer.py
- modelopt/torch/quantization/config.py
        assert num_global_experts % num_local_experts == 0, (
            "num_global_experts must be divisible by num_local_experts "
            "during MoE import"
        )
        init_index = 0
        # Each EP rank owns a contiguous slice of global experts:
        # [ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts).
        init_index = get_expert_model_parallel_rank() * num_local_experts
🧩 Analysis chain
🏁 Script executed:
cat -n modelopt/torch/export/plugins/megatron_importer.py | sed -n '650,680p'

Repository: NVIDIA/Model-Optimizer
Length of output: 2049
🏁 Script executed:
# Check imports and function definitions
cat -n modelopt/torch/export/plugins/megatron_importer.py | head -50

Repository: NVIDIA/Model-Optimizer
Length of output: 2174
🏁 Script executed:
# Search for where these functions are defined or imported
rg "get_expert_model_parallel_rank|get_expert_model_parallel_world_size" modelopt/torch/export/plugins/megatron_importer.py -B 2 -A 2Repository: NVIDIA/Model-Optimizer
Length of output: 650
Replace assert with explicit EP-topology validation.
Line 662 uses assert for runtime validation, which Python removes when executed with optimization flags (-O). Additionally, divisibility alone (num_global_experts % num_local_experts == 0) does not guarantee correct expert mapping. The code distributes experts by EP rank using the formula [ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts), which requires num_global_experts == num_local_experts * ep_size. Without this constraint, mismatches between expert count and EP topology can silently produce incorrect indexing.
Suggested fix
- assert num_global_experts % num_local_experts == 0, (
- "num_global_experts must be divisible by num_local_experts "
- "during MoE import"
- )
- # Each EP rank owns a contiguous slice of global experts:
- # [ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts).
- init_index = get_expert_model_parallel_rank() * num_local_experts
+ ep_rank = get_expert_model_parallel_rank()
+ ep_size = get_expert_model_parallel_world_size()
+ if num_global_experts != num_local_experts * ep_size:
+ raise ValueError(
+ "Expected num_global_experts == num_local_experts * ep_size "
+ f"for TEGroupedMLP import, got {num_global_experts=}, "
+ f"{num_local_experts=}, {ep_size=}."
+ )
+ # Each EP rank owns a contiguous slice of global experts:
+ # [ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts).
+ init_index = ep_rank * num_local_experts

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/export/plugins/megatron_importer.py` around lines 662 - 668,
Replace the fragile assert-based check with an explicit validation that raises a
clear exception: verify both that num_global_experts is divisible by
num_local_experts and that num_global_experts == num_local_experts * ep_size
(where ep_size is the EP topology size used by
get_expert_model_parallel_rank()); if these conditions fail, raise a ValueError
with a descriptive message mentioning num_global_experts, num_local_experts and
ep_size so the expert slice computation (init_index =
get_expert_model_parallel_rank() * num_local_experts) cannot proceed with an
invalid topology.
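
For concreteness, a small self-contained illustration of the failure mode this comment describes; the numbers are hypothetical and only meant to show that divisibility alone passes while the stricter equality check does not:

```python
# 16 % 4 == 0 passes the current assert, but with ep_size=2 the two ranks only
# cover global experts 0..7, silently leaving experts 8..15 unimported.
num_global_experts, num_local_experts, ep_size = 16, 4, 2

assert num_global_experts % num_local_experts == 0  # current check passes
covered = {
    ep_rank * num_local_experts + local_id
    for ep_rank in range(ep_size)
    for local_id in range(num_local_experts)
}
print(sorted(covered))                                     # [0, 1, 2, 3, 4, 5, 6, 7]
print(num_global_experts == num_local_experts * ep_size)   # False -> stricter check catches it
```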
Codecov Report

❌ Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1342      +/-   ##
==========================================
- Coverage   74.46%   73.67%   -0.80%
==========================================
  Files         464      481      +17
  Lines       50089    52610    +2521
==========================================
+ Hits        37300    38759    +1459
- Misses      12789    13851    +1062
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
♻️ Duplicate comments (1)
modelopt/torch/export/plugins/megatron_importer.py (1)
661-667: ⚠️ Potential issue | 🟠 Major

Replace assert with an explicit EP-topology check against ep_size.

Two concerns on this block that are still present:

- The assert at line 661 is removed under python -O, so the only guard against a misconfigured EP topology disappears in optimized runs.
- num_global_experts % num_local_experts == 0 is necessary but not sufficient. The slicing [ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts) is only well-defined when num_global_experts == num_local_experts * ep_size. Divisibility alone (e.g. num_global=16, num_local=4, ep_size=2) will silently leave 8 global experts unimported and can also produce overlapping/duplicate reads on other shapes.

This mirrors the pattern already used elsewhere in the same codebase (see modelopt/torch/export/plugins/mcore_custom.py lines 416–445, which explicitly takes ep_size = get_expert_model_parallel_world_size() and raises ValueError on mismatch).

🔧 Suggested fix
- num_local_experts = experts.num_local_experts
- num_global_experts = experts.config.num_moe_experts
- assert num_global_experts % num_local_experts == 0, (
-     "num_global_experts must be divisible by num_local_experts "
-     "during MoE import"
- )
- # Each EP rank owns a contiguous slice of global experts:
- # [ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts).
- init_index = get_expert_model_parallel_rank() * num_local_experts
+ num_local_experts = experts.num_local_experts
+ num_global_experts = experts.config.num_moe_experts
+ ep_rank = get_expert_model_parallel_rank()
+ ep_size = get_expert_model_parallel_world_size()
+ if num_global_experts != num_local_experts * ep_size:
+     raise ValueError(
+         "TEGroupedMLP import requires "
+         "num_global_experts == num_local_experts * ep_size, got "
+         f"{num_global_experts=}, {num_local_experts=}, {ep_size=}."
+     )
+ # Each EP rank owns a contiguous slice of global experts:
+ # [ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts).
+ init_index = ep_rank * num_local_experts

This will also require importing get_expert_model_parallel_world_size alongside get_expert_model_parallel_rank at line 42.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/megatron_importer.py` around lines 661 - 667, Replace the fragile assert with an explicit EP-topology validation: import get_expert_model_parallel_world_size alongside get_expert_model_parallel_rank, compute ep_size = get_expert_model_parallel_world_size(), then check both that num_global_experts % num_local_experts == 0 and that num_global_experts == num_local_experts * ep_size; if the check fails raise a ValueError with a clear message; keep using get_expert_model_parallel_rank() to compute init_index only after the validation passes.
🧹 Nitpick comments (2)
modelopt/torch/quantization/config.py (1)
239-250: LGTM — patterns correctly target both HF and Mcore attention layers.

The added *self_attention.linear_qkv* and *self_attention.linear_proj* entries correctly complement the existing HF-style *q_proj* / *k_proj* / *v_proj* / *o_proj* patterns to cover Mcore's attention module naming, matching the structures seen in modelopt/torch/export/unified_export_megatron.py and modelopt/torch/prune/plugins/mcore_minitron.py. Since all entries only toggle enable: False, ordering among themselves is not a concern.

Minor nit (optional): the comment on line 250 reads "Skip QKV Output Projection (Mcore naming)" — self_attention.linear_proj is the attention output projection, not a QKV output projection. Consider rewording to "Skip Attention Output Projection (Mcore naming)" for clarity and consistency with line 242's HF counterpart.

✏️ Optional comment tweak
{ "quantizer_name": "*self_attention.linear_proj*", "enable": False, - }, # Skip QKV Output Projection (Mcore naming) + }, # Skip Attention Output Projection (Mcore naming) ]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/config.py` around lines 239 - 250, Update the comment for the quantizer entry with "quantizer_name": "*self_attention.linear_proj*" to accurately describe it as the attention output projection (e.g., change "Skip QKV Output Projection (Mcore naming)" to "Skip Attention Output Projection (Mcore naming)"); locate the entry by the unique symbol "*self_attention.linear_proj*" (and optionally mirror wording with the HF counterpart "*o_proj*") and replace the comment text accordingly.

modelopt/torch/export/plugins/megatron_importer.py (1)
298-307: Global→local expert remap looks correct.

The loop correctly maps each HF global expert init_expert_id + local_id into the TEGroupedMLP local slot weight{local_id}, which is what the comment describes and matches the rank-derived init_index set by the caller. The parallel_config parameter is accepted but intentionally unused here since ETP is routed through the use_packed_local_experts branch at lines 685–695.

One small follow-up to consider (optional, not blocking): the # TODO handle weight_scale means quantized-MoE import via TEGroupedMLP will silently skip scales. Please make sure this path is currently only exercised for unquantized HF checkpoints, or gate it with an explicit error when a weight_quantizer._scale is present on the module so a quantized grouped-MLP import doesn't silently produce an unscaled model. A sketch of such a guard appears after the prompt below.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/megatron_importer.py` around lines 298 - 307, The current import loop that maps HF global experts to TEGroupedMLP local slots (for local_id in range(num_local_experts) ... state_dict[f"weight{local_id}"] = tensor; module.load_state_dict(state_dict)) ignores any weight scales; add a guard before this remapping to detect quantized grouped-MLP modules (e.g., check module.weight_quantizer._scale or similar attribute) and either raise a clear error or assert if a scale is present so we don't silently import a quantized checkpoint without applying scales; ensure the check references the actual module attribute (weight_quantizer._scale) and short-circuits this code path when quantization metadata exists, leaving the existing parallel_config handling unchanged.
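
A minimal sketch of the suggested guard, assuming the attribute names the comment mentions (weight_quantizer, _scale); the helper name is hypothetical and this is not a verified ModelOpt API:

```python
# Refuse the unscaled TEGroupedMLP import path when quantization scales are
# present on the target module, instead of silently dropping them.
def check_no_weight_scale(module) -> None:
    weight_quantizer = getattr(module, "weight_quantizer", None)
    if weight_quantizer is not None and getattr(weight_quantizer, "_scale", None) is not None:
        raise NotImplementedError(
            "TEGroupedMLP import does not yet handle weight_scale; "
            "refusing to silently import a quantized grouped-MLP checkpoint."
        )
```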
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@modelopt/torch/export/plugins/megatron_importer.py`:
- Around line 661-667: Replace the fragile assert with an explicit EP-topology
validation: import get_expert_model_parallel_world_size alongside
get_expert_model_parallel_rank, compute ep_size =
get_expert_model_parallel_world_size(), then check both that num_global_experts
% num_local_experts == 0 and that num_global_experts == num_local_experts *
ep_size; if the check fails raise a ValueError with a clear message; keep using
get_expert_model_parallel_rank() to compute init_index only after the validation
passes.
---
Nitpick comments:
In `@modelopt/torch/export/plugins/megatron_importer.py`:
- Around line 298-307: The current import loop that maps HF global experts to
TEGroupedMLP local slots (for local_id in range(num_local_experts) ...
state_dict[f"weight{local_id}"] = tensor; module.load_state_dict(state_dict))
ignores any weight scales; add a guard before this remapping to detect quantized
grouped-MLP modules (e.g., check module.weight_quantizer._scale or similar
attribute) and either raise a clear error or assert if a scale is present so we
don't silently import a quantized checkpoint without applying scales; ensure the
check references the actual module attribute (weight_quantizer._scale) and
short-circuits this code path when quantization metadata exists, leaving the
existing parallel_config handling unchanged.
In `@modelopt/torch/quantization/config.py`:
- Around line 239-250: Update the comment for the quantizer entry with
"quantizer_name": "*self_attention.linear_proj*" to accurately describe it as
the attention output projection (e.g., change "Skip QKV Output Projection (Mcore
naming)" to "Skip Attention Output Projection (Mcore naming)"); locate the entry
by the unique symbol "*self_attention.linear_proj*" (and optionally mirror
wording with the HF counterpart "*o_proj*") and replace the comment text
accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: be965339-147e-4389-8ddc-526a753fbe9d
📒 Files selected for processing (2)
- modelopt/torch/export/plugins/megatron_importer.py
- modelopt/torch/quantization/config.py
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
🧹 Nitpick comments (1)
modelopt/torch/distill/plugins/megatron.py (1)
166-175: Regex expansion looks correct; consider scoping the replacement to the matched span.

Expanding the lookahead to (?=\.|$) correctly covers submodule names where the numeric layer index is the last token (e.g., decoder.layers.5). One pre-existing fragility worth noting now that more names flow through this path: Line 175 uses submodule_name.replace(match.group(0), str(new_layer_idx)), which replaces every occurrence of that digit substring, not just the matched layer index. For names like decoder.layers.5.mlp.experts.5 this would rewrite both. A targeted substitution using the match span is safer:

♻️ Proposed refactor
- new_submodule_name = submodule_name.replace(match.group(0), str(new_layer_idx))
+ start, end = match.span()
+ new_submodule_name = submodule_name[:start] + str(new_layer_idx) + submodule_name[end:]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/distill/plugins/megatron.py` around lines 166 - 175, The replacement currently uses submodule_name.replace(match.group(0), ...) which will replace every occurrence of the matched digits; update the logic to replace only the matched span: use the match span (match.start(), match.end()) to construct new_submodule_name (e.g., slice before + new index + slice after) or call re.sub with the compiled pattern and count=1 so only the found occurrence is replaced; keep use of TransformerLayer._get_layer_offset(model_cfg), match, and new_layer_idx unchanged.
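
A standalone illustration of the replace-vs-span difference described in this nitpick; the lookbehind in the pattern below is an assumption for the sake of the demo, only the (?=\.|$) lookahead comes from the review:

```python
import re

# Pattern shape is assumed; the plugin's actual regex may differ.
layer_pattern = re.compile(r"(?<=\.)\d+(?=\.|$)")
submodule_name = "decoder.layers.5.mlp.experts.5"
match = layer_pattern.search(submodule_name)
new_layer_idx = 3

# str.replace rewrites every "5", including the expert index:
print(submodule_name.replace(match.group(0), str(new_layer_idx)))
# -> decoder.layers.3.mlp.experts.3

# Slicing on the match span only rewrites the matched layer index:
start, end = match.span()
print(submodule_name[:start] + str(new_layer_idx) + submodule_name[end:])
# -> decoder.layers.3.mlp.experts.5
```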
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/distill/plugins/megatron.py`:
- Around line 166-175: The replacement currently uses
submodule_name.replace(match.group(0), ...) which will replace every occurrence
of the matched digits; update the logic to replace only the matched span: use
the match span (match.start(), match.end()) to construct new_submodule_name
(e.g., slice before + new index + slice after) or call re.sub with the compiled
pattern and count=1 so only the found occurrence is replaced; keep use of
TransformerLayer._get_layer_offset(model_cfg), match, and new_layer_idx
unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 2fde340e-abbc-441d-9228-4c410e1a2f36
📒 Files selected for processing (2)
- modelopt/torch/distill/plugins/megatron.py
- modelopt/torch/export/unified_export_megatron.py
What does this PR do?
Type of change: Bug fix
Usage
Testing
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
Bug Fixes
Improvements