
Support LoRA for MOE Megatron SequentialMLP#979

Open
jenchen13 wants to merge 5 commits into main from jennifchen/moe_lora

Conversation


@jenchen13 jenchen13 commented Mar 5, 2026

What does this PR do?

Type of change: New Feature

  • Add LoRA support for Megatron SequentialMLP in MoE local experts.
  • LoRA adapters are added at the per-expert level, with a shared lora_down across the local experts in a layer and an individual lora_up per local expert. This accommodates the SVDQuant kernel, which fuses lora_down and quantization into a single kernel.
[Screenshot, 2026-03-05: adapter structure]
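The adapter layout described above can be sketched in PyTorch. Names and shapes here are illustrative, not the actual ModelOpt classes:

```python
import torch.nn as nn

# Hypothetical layout (illustrative names/shapes, not the actual ModelOpt
# classes): one lora_down shared by all local experts in a layer, so an
# SVDQuant-style kernel can fuse the down-projection with the quantized GEMM,
# plus one lora_up per local expert.
num_local_experts, hidden, rank, ffn = 4, 1024, 32, 4096
lora_down = nn.Linear(hidden, rank, bias=False)  # shared across experts
lora_up = nn.ModuleList(
    [nn.Linear(rank, ffn, bias=False) for _ in range(num_local_experts)]
)
```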

Usage

# Add a code snippet demonstrating how to use this
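Since the snippet above was left as a TODO, here is a hypothetical sketch of the shape such a config might take, modeled on the adapter_cfg dict format and wildcard patterns (e.g. "*linear_qkv*") mentioned in the review discussion below. The real MOE_LORA_CFG fields and patterns may differ:

```python
# Hypothetical config shape only -- the actual DENSE_LORA_CFG / MOE_LORA_CFG in
# modelopt/torch/peft/lora/config.py may use different field names and patterns.
MOE_LORA_CFG_SKETCH = {
    "adapter_name": "default",
    "adapter_cfg": {
        # Wildcard pattern targeting the MoE local-expert linear layers.
        "*local_experts*linear_fc*": {"rank": 32, "enable": True},
    },
}
```

A config like this would then be passed to the PEFT entry point (update_model, per the review discussion below).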

Testing

Tested adding adapters to a MoE layer, and gradient flow through MoE.
TODO: test sharded state dict.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, using torch.load(..., weights_only=True), avoiding pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other source, did you follow IP policy in CONTRIBUTING.md?: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added LoRA support for MoE SequentialMLP with shared down-projection and per-expert up-projections.
    • Added LoRA adapters for Transformer Engine parallel linear layers and quantized TE variants; adapters are registered and discoverable.
    • Added predefined LoRA configurations (dense and MoE) with named choices.
  • Tests

    • Added tests for MoE LoRA structure, gradients, TE LoRA behavior, save/restore, and quantization interactions.
  • Behavior

    • Sequence parallelism is now enabled automatically for MoE when tensor model parallelism > 1 is used.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
@jenchen13 jenchen13 requested a review from a team as a code owner March 5, 2026 17:03

coderabbitai bot commented Mar 5, 2026

📝 Walkthrough

Walkthrough

Adds Megatron-specific LoRA adapters: a SequentialMLP adapter with shared down / per-expert up projections, TE-based Column/Row Parallel LoRA adapters (and quantized variants), sharded-state handling, new Megatron LoRA configs, and tests exercising MoE, TE, and quantization interactions.

Changes

Cohort / File(s) Summary
Megatron LoRA adapters
modelopt/torch/peft/lora/plugins/megatron.py
Adds _LoRAMegatronSequentialMLP (shared lora_down, per-expert lora_up), TE adapters _LoRATEColumnParallelLinear and _LoRATERowParallelLinear, quantized TE variants, TE/Transformer Engine detection, sharded_state_dict extensions, and registry/quant-registry registrations.
LoRA config definitions
modelopt/torch/peft/lora/config.py
Introduces DENSE_LORA_CFG, MOE_LORA_CFG, and LORA_CFG_CHOICES with pattern rules targeting dense and MoE modules; exports these via __all__.
Model builder (tests helper)
tests/_test_utils/torch/megatron/models.py
Adds num_moe_experts, computes use_sp for sequence_parallel when MoE + TP >1, threads use_te and moe_grouped_gemm into layer/model spec creation.
MoE / TE / quantization tests
tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py
Adds MoE SequentialMLP LoRA configs, selective-pattern configs, threads MoE/TE params through model provider, and new tests for SequentialMLP LoRA structure and gradient isolation, TE LoRA module types, forward/enable-disable/save-restore, and quantization interactions.

Sequence Diagram(s)

sequenceDiagram
    participant Input as Input\n(permuted_local_hidden_states,\n tokens_per_expert, permuted_probs)
    participant SharedDown as Shared Down-Projection\n(lora_down)
    participant PerExpert as Per-Expert Up-Projections\n(lora_up ModuleList)
    participant Aggregator as Aggregator
    participant Output as Output

    Input->>SharedDown: concat/extract per-expert chunks
    SharedDown->>SharedDown: apply shared down-projection
    SharedDown->>PerExpert: split into per-expert activations
    PerExpert->>PerExpert: apply each lora_up[i] to its chunk
    PerExpert->>Aggregator: emit per-expert LoRA outputs
    Aggregator->>Aggregator: weight/aggregate using tokens_per_expert\nand permuted_probs
    Aggregator->>Output: add aggregated LoRA output to base SequentialMLP output
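The flow in the diagram above can be sketched as a plain PyTorch forward pass. Variable names mirror the diagram; shapes are illustrative and this is not the actual adapter implementation:

```python
import torch
import torch.nn as nn

# Illustrative sizes: 2 local experts, 6 permuted tokens.
hidden, rank, ffn, experts = 8, 4, 16, 2
lora_down = nn.Linear(hidden, rank, bias=False)  # shared down-projection
lora_up = nn.ModuleList(nn.Linear(rank, ffn, bias=False) for _ in range(experts))

permuted_local_hidden_states = torch.randn(6, hidden)
tokens_per_expert = torch.tensor([4, 2])
permuted_probs = torch.rand(6, 1)
base_output = torch.randn(6, ffn)  # stand-in for the base SequentialMLP output

down = lora_down(permuted_local_hidden_states)          # shared down-projection
chunks = torch.split(down, tokens_per_expert.tolist())  # per-expert activations
lora_out = torch.cat([up(c) for up, c in zip(lora_up, chunks)])
output = base_output + permuted_probs * lora_out        # weight and aggregate
```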

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 46.43%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (3 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title 'Support LoRA for MOE Megatron SequentialMLP' clearly and specifically describes the main feature addition: LoRA support for the MOE (Mixture of Experts) SequentialMLP component.
  • Security Anti-Patterns — ✅ Passed: the PR adds LoRA adapter implementations without introducing security anti-patterns: no unsafe torch.load() or numpy.load() calls, no hardcoded trust_remote_code=True, no eval()/exec() on untrusted input, no nosec comments, and no restricted-license PIP dependencies.


@jenchen13 jenchen13 requested review from jingyu-ml and sychen52 March 5, 2026 17:04
  expert_model_parallel_size=expert_model_parallel_size,
  expert_tensor_parallel_size=expert_tensor_parallel_size,
- sequence_parallel=False,
+ sequence_parallel=num_moe_experts > 0 and tensor_model_parallel_size > 1,
jenchen13 (author) commented:

mcore throws an error if you use MOE + TP without SP

coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/peft/lora/plugins/megatron.py`:
- Line 336: Replace the direct Megatron-Core import of
ensure_metadata_has_dp_cp_group (currently imported from
megatron.core.transformer.utils) with the version-stable compatibility helper
used by other plugins: import ensure_metadata_has_dp_cp_group from
modelopt.torch.opt.plugins.megatron instead; update the import statement in the
file so all usages of ensure_metadata_has_dp_cp_group in this module use the
compatibility wrapper rather than the unstable megatron.core.transformer.utils
symbol.

In `@tests/_test_utils/torch/megatron/models.py`:
- Line 172: The comparison on line using "sequence_parallel=num_moe_experts > 0
and tensor_model_parallel_size > 1" can raise TypeError when num_moe_experts is
None; change the condition to guard against None by evaluating num_moe_experts
is not None and num_moe_experts > 0 (or equivalently coerce with a default like
(num_moe_experts or 0) > 0) before combining with tensor_model_parallel_size > 1
so sequence_parallel is always a bool; update the call/site where
sequence_parallel is set (the argument named sequence_parallel in the model
construction) to use this guarded condition.
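The guarded condition described above can be sketched as a standalone helper (the function name is illustrative):

```python
# Coerce None to 0 so the expression never compares None > 0 and always
# yields a bool, matching the review's suggested fix.
def compute_use_sp(num_moe_experts, tensor_model_parallel_size):
    return (num_moe_experts or 0) > 0 and tensor_model_parallel_size > 1
```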

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7cdca600-b524-488d-a5cb-940d5ede1314

📥 Commits

Reviewing files that changed from the base of the PR and between 31f0783 and 03b5140.

📒 Files selected for processing (3)
  • modelopt/torch/peft/lora/plugins/megatron.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py


codecov bot commented Mar 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.14%. Comparing base (e8f9687) to head (7583222).
⚠️ Report is 9 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #979      +/-   ##
==========================================
+ Coverage   72.12%   72.14%   +0.02%     
==========================================
  Files         209      209              
  Lines       23628    23667      +39     
==========================================
+ Hits        17042    17075      +33     
- Misses       6586     6592       +6     


Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/peft/lora/plugins/megatron.py (1)

336-336: ⚠️ Potential issue | 🟠 Major

Use the ModelOpt compatibility helper for metadata normalization.

Line 336 still imports ensure_metadata_has_dp_cp_group directly from Megatron-Core. This is version-fragile; please switch to the compatibility helper used in other ModelOpt plugins.

🔧 Proposed import change
-        from megatron.core.transformer import ensure_metadata_has_dp_cp_group
+        from modelopt.torch.opt.plugins.megatron import ensure_metadata_has_dp_cp_group
#!/bin/bash
set -euo pipefail

# Verify compatibility helper exists and inspect current usages.
rg -n "def ensure_metadata_has_dp_cp_group" modelopt/torch/opt/plugins/megatron.py
rg -n "ensure_metadata_has_dp_cp_group" modelopt/torch -g '*.py'

Expected: helper definition is present in modelopt/torch/opt/plugins/megatron.py, and this file should align with that import pattern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/peft/lora/plugins/megatron.py` at line 336, The file imports
ensure_metadata_has_dp_cp_group directly from megatron.core.transformer which is
version-fragile; replace that direct import with the ModelOpt compatibility
helper used elsewhere (the helper defined in
modelopt/torch/opt/plugins/megatron.py) so metadata normalization uses the
centralized compatibility wrapper. Locate the import line referencing
ensure_metadata_has_dp_cp_group in modelopt/torch/peft/lora/plugins/megatron.py
and change it to import the helper from the ModelOpt plugin module (matching
other plugins' import pattern), ensuring all usages of
ensure_metadata_has_dp_cp_group in this file call the compatibility helper
instead of the Megatron-Core symbol.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/peft/lora/plugins/megatron.py`:
- Around line 373-383: The lora_up_key is reused for every local expert when
singleton_local_shards is False, causing earlier shards to be overwritten in
sharded_state_dict; update the key generation so it is unique per expert (e.g.,
include expert_global_idx or another expert identifier in the f-string) when
building lora_up_key in the block that sets up_offsets/up_offsets and calls
ShardedTensor.from_rank_offsets (refer to variables lora_up_key,
singleton_local_shards, up_offsets, sharded_offsets, expert_global_idx,
adapter_name, and the call to ShardedTensor.from_rank_offsets) so each expert
writes to a distinct dict key and no shards are dropped.

---

Duplicate comments:
In `@modelopt/torch/peft/lora/plugins/megatron.py`:
- Line 336: The file imports ensure_metadata_has_dp_cp_group directly from
megatron.core.transformer which is version-fragile; replace that direct import
with the ModelOpt compatibility helper used elsewhere (the helper defined in
modelopt/torch/opt/plugins/megatron.py) so metadata normalization uses the
centralized compatibility wrapper. Locate the import line referencing
ensure_metadata_has_dp_cp_group in modelopt/torch/peft/lora/plugins/megatron.py
and change it to import the helper from the ModelOpt plugin module (matching
other plugins' import pattern), ensuring all usages of
ensure_metadata_has_dp_cp_group in this file call the compatibility helper
instead of the Megatron-Core symbol.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 5222a6e7-5d85-4fc3-b56d-4af312e8fe36

📥 Commits

Reviewing files that changed from the base of the PR and between 03b5140 and d3827e0.

📒 Files selected for processing (2)
  • modelopt/torch/peft/lora/plugins/megatron.py
  • tests/_test_utils/torch/megatron/models.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/_test_utils/torch/megatron/models.py

Comment on lines +373 to +383
if singleton_local_shards:
lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight"
up_offsets = sharded_offsets
else:
lora_up_key = f"{prefix}lora_b_{adapter_name}.weight"
up_offsets = (
*sharded_offsets,
(len(sharded_offsets), expert_global_idx, num_global_experts),
)

sharded_state_dict[lora_up_key] = ShardedTensor.from_rank_offsets(

⚠️ Potential issue | 🔴 Critical

Prevent silent shard loss in lora_up checkpoint mapping.

At Line 377, lora_up_key is identical for every local expert when singleton_local_shards is False, so Line 383 overwrites earlier experts in sharded_state_dict. This drops shards for all but the last expert.

💡 Minimal fix to avoid key overwrite
                 if singleton_local_shards:
                     lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight"
                     up_offsets = sharded_offsets
                 else:
-                    lora_up_key = f"{prefix}lora_b_{adapter_name}.weight"
+                    lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight"
                     up_offsets = (
                         *sharded_offsets,
                         (len(sharded_offsets), expert_global_idx, num_global_experts),
                     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/peft/lora/plugins/megatron.py` around lines 373 - 383, The
lora_up_key is reused for every local expert when singleton_local_shards is
False, causing earlier shards to be overwritten in sharded_state_dict; update
the key generation so it is unique per expert (e.g., include expert_global_idx
or another expert identifier in the f-string) when building lora_up_key in the
block that sets up_offsets/up_offsets and calls ShardedTensor.from_rank_offsets
(refer to variables lora_up_key, singleton_local_shards, up_offsets,
sharded_offsets, expert_global_idx, adapter_name, and the call to
ShardedTensor.from_rank_offsets) so each expert writes to a distinct dict key
and no shards are dropped.
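The failure mode described above is easy to reproduce with a plain dict (identifiers here are illustrative, not the actual checkpoint code):

```python
# Reusing one key keeps only the last expert's shard; a per-expert key
# preserves all of them.
def build_keys(prefix, adapter, expert_ids, unique):
    shards = {}
    for idx in expert_ids:
        key = (f"{prefix}lora_b_{adapter}.{idx}.weight" if unique
               else f"{prefix}lora_b_{adapter}.weight")
        shards[key] = f"shard_{idx}"  # later writes overwrite colliding keys
    return shards
```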

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py (1)

179-218: ⚠️ Potential issue | 🟡 Minor

Inconsistent parameter passing between meta_device and CUDA branches.

The meta_device=True branch (lines 190-204) does not pass moe_grouped_gemm to get_mcore_gpt_model, while the CUDA branch (lines 206-218) does pass it. This could lead to inconsistent behavior when testing with meta device initialization.

🔧 Suggested fix
     if meta_device:
         with torch.device("meta"):
             gpt_model = get_mcore_gpt_model(
                 tensor_model_parallel_size=tp_size,
                 num_layers=2,
                 ffn_hidden_size=None,
                 num_attention_heads=4,
                 activation_func="squared_relu",
                 transformer_impl="local",
                 use_te=use_te,
                 hidden_size=hidden_size,
                 vocab_size=vocab_size,
                 use_cpu_initialization=meta_device,
                 num_moe_experts=num_moe_experts,
+                moe_grouped_gemm=moe_grouped_gemm,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py` around lines 179
- 218, _gpt_model_provider creates models inconsistently: when meta_device=True
it calls get_mcore_gpt_model without passing the moe_grouped_gemm argument
whereas the non-meta (CUDA) branch passes moe_grouped_gemm; update the
meta_device branch call to include moe_grouped_gemm=moe_grouped_gemm so both
branches call get_mcore_gpt_model with the same set of parameters (reference
get_mcore_gpt_model, meta_device, moe_grouped_gemm, and _gpt_model_provider).
♻️ Duplicate comments (1)
modelopt/torch/peft/lora/plugins/megatron.py (1)

388-404: ⚠️ Potential issue | 🔴 Critical

Dict key collision causes silent shard loss for lora_up weights.

When singleton_local_shards is False, line 392 generates the same lora_up_key for every local expert (f"{prefix}lora_b_{adapter_name}.weight"). Since sharded_state_dict is a Python dict, line 398 overwrites entries from previous experts, causing all but the last expert's lora_up weights to be lost during checkpointing.

🐛 Suggested fix
                 if singleton_local_shards:
                     lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight"
                     up_offsets = sharded_offsets
                 else:
-                    lora_up_key = f"{prefix}lora_b_{adapter_name}.weight"
+                    lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_local_idx}.weight"
                     up_offsets = (
                         *sharded_offsets,
                         (len(sharded_offsets), expert_global_idx, num_global_experts),
                     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/peft/lora/plugins/megatron.py` around lines 388 - 404, The
dict key for lora_up is colliding when singleton_local_shards is False because
lora_up_key is built as f"{prefix}lora_b_{adapter_name}.weight" for every local
expert; change that branch to include the expert identifier (e.g., use
f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight" like the singleton
branch) so each expert gets a unique key before calling
ShardedTensor.from_rank_offsets (ensure this adjustment is made alongside the
existing up_offsets logic involving expert_global_idx and num_global_experts).
🧹 Nitpick comments (1)
modelopt/torch/peft/lora/config.py (1)

39-48: Consider using dataclasses or Pydantic for configuration objects.

The configs are defined as plain dictionaries. As per coding guidelines for modelopt/torch/**/config.py: "Use dataclasses or Pydantic for mode configuration objects." While dict-based configs work and align with existing usage patterns in the codebase (e.g., adapter_cfg dict format), consider whether a typed config class would improve validation and IDE support.

Given that these configs are passed to update_model which already expects dict-like structures, the current approach may be intentional for compatibility.

Also applies to: 67-77

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/peft/lora/config.py` around lines 39 - 48, The DENSE_LORA_CFG
(and the similar config at lines 67-77) are plain dicts; convert them into a
typed configuration class (use a dataclass or Pydantic model named e.g.,
LoraConfig with an adapter_cfg field) to provide validation and IDE type hints
while preserving the existing dict shape passed to update_model; implement a
to_dict/asdict/.dict() method and update call sites that pass DENSE_LORA_CFG to
call that conversion (or keep a module-level constant DENSE_LORA_CFG =
LoraConfig(...).to_dict() if you need immediate backward compatibility), and
ensure fields like adapter_cfg and nested patterns (e.g., "*linear_qkv*") remain
unchanged in the resulting dict.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py`:
- Around line 179-218: _gpt_model_provider creates models inconsistently: when
meta_device=True it calls get_mcore_gpt_model without passing the
moe_grouped_gemm argument whereas the non-meta (CUDA) branch passes
moe_grouped_gemm; update the meta_device branch call to include
moe_grouped_gemm=moe_grouped_gemm so both branches call get_mcore_gpt_model with
the same set of parameters (reference get_mcore_gpt_model, meta_device,
moe_grouped_gemm, and _gpt_model_provider).

---

Duplicate comments:
In `@modelopt/torch/peft/lora/plugins/megatron.py`:
- Around line 388-404: The dict key for lora_up is colliding when
singleton_local_shards is False because lora_up_key is built as
f"{prefix}lora_b_{adapter_name}.weight" for every local expert; change that
branch to include the expert identifier (e.g., use
f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight" like the singleton
branch) so each expert gets a unique key before calling
ShardedTensor.from_rank_offsets (ensure this adjustment is made alongside the
existing up_offsets logic involving expert_global_idx and num_global_experts).

---

Nitpick comments:
In `@modelopt/torch/peft/lora/config.py`:
- Around line 39-48: The DENSE_LORA_CFG (and the similar config at lines 67-77)
are plain dicts; convert them into a typed configuration class (use a dataclass
or Pydantic model named e.g., LoraConfig with an adapter_cfg field) to provide
validation and IDE type hints while preserving the existing dict shape passed to
update_model; implement a to_dict/asdict/.dict() method and update call sites
that pass DENSE_LORA_CFG to call that conversion (or keep a module-level
constant DENSE_LORA_CFG = LoraConfig(...).to_dict() if you need immediate
backward compatibility), and ensure fields like adapter_cfg and nested patterns
(e.g., "*linear_qkv*") remain unchanged in the resulting dict.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b3f605d6-8b34-49b5-b97a-6e7a981044d7

📥 Commits

Reviewing files that changed from the base of the PR and between d3827e0 and 7583222.

📒 Files selected for processing (4)
  • modelopt/torch/peft/lora/config.py
  • modelopt/torch/peft/lora/plugins/megatron.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/_test_utils/torch/megatron/models.py
