[Feature]【Hackathon 10th Spring No.47】Add MiniMax-M1 model support#6994
Open
cloudforge1 wants to merge 14 commits intoPaddlePaddle:developfrom
Open
[Feature]【Hackathon 10th Spring No.47】Add MiniMax-M1 model support#6994cloudforge1 wants to merge 14 commits intoPaddlePaddle:developfrom
cloudforge1 wants to merge 14 commits intoPaddlePaddle:developfrom
Conversation
- Model scaffold: minimax_m1.py with hybrid attention (70 linear + 10 full GQA), MoE (32 experts top-2), DeepNorm scaling, weight loading - Lightning Attention: 5 Triton JIT kernels + 3 Python wrappers - Tests: 27 pytest cases covering attn dispatch, slope construction, registration, layer construction, and forward-pass smoke tests - Docs: EN/CN best practices + supported models list updates Architecture: MiniMaxText01ForCausalLM (456B MoE, 80 layers)
…ment load_weights - LinearAttention: add output_gate (sigmoid gating), norm (RMSNorm), rename o_proj → out_proj. Forward: SiLU on QKV → lightning_attn → norm → gate → out_proj - DecoderLayer: rename self.mlp → self.block_sparse_moe to match HF config - DeepNorm: branch alpha/beta on attention_type (linear vs full) - Postnorm: add two code paths following vLLM reference - KV state: persist _kv_history across forward calls - Dual registration: MiniMaxM1ForCausalLM + MiniMaxText01ForCausalLM - set_state_dict: preprocess HF keys (w1→gate_proj, w3→up_proj, w2→down_proj, q/k/v→qkv_proj concatenation) - load_weights: v1 loader with stacked_params_mapping + expert_params_mapping - Tests: 29/29 passing
|
Thanks for your contribution! |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #6994 +/- ##
==========================================
Coverage ? 73.50%
==========================================
Files ? 401
Lines ? 56603
Branches ? 8890
==========================================
Hits ? 41607
Misses ? 12064
Partials ? 2932
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
- Quantization-aware weight_key_map in MiniMaxM1MoE (w4a8, w4afp8 static/dynamic, tensor_wise_fp8, block_wise_fp8) mirroring Ernie4_5_MoE - Gate layer uses skip_quant=True, weight_dtype='float32' - set_state_dict v0 loader: quant-aware regex for expert weights (.quant_weight, .weight_scale, .activation_scale) - set_state_dict v0 loader: quant-aware qkv merge (suffix-keyed buffers) - 3 new tests: default/w4a8/w4afp8-dynamic weight_key_map branches
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
为 FastDeploy 增加部署 MiniMaxAI/MiniMax-M1-40k 系列模型的能力。
This PR adds support for deploying the MiniMax-M1 (456B MoE, 45.9B active) model family in FastDeploy, as required by Hackathon 10th Spring No.47.
MiniMax-M1 is a hybrid-attention Mixture-of-Experts LLM with:
MiniMaxM1ForCausalLMandMiniMaxText01ForCausalLMDesign document: community#1252
Reference approved RFC: community#1156 (@NKNaN)
Modifications
Model Code (
fastdeploy/model_executor/models/minimax_m1.py, ~800 lines)9 classes implementing the full model:
MiniMaxM1MLP: Gate/up merged projection with SiLU activationMiniMaxM1MoE: FusedMoE with 32 experts, top-2 routing, renormalize=True, quantization-awareweight_key_map(w4a8, w4afp8 static/dynamic, tensor_wise_fp8, block_wise_fp8)MiniMaxM1FullAttention: Standard GQA with RoPE, used in 10 out of 80 layersMiniMaxM1LinearAttention: Lightning attention with SiLU-gated QKV, output_gate (sigmoid), RMSNorm, persistent KV state history. Forward: SiLU(QKV) → lightning_attn → RMSNorm → sigmoid(gate) × hidden → out_projMiniMaxM1DecoderLayer: Dispatches to linear/full attention based onattn_type_list, DeepNorm scaling with separate alpha/beta per attention type, postnorm supportMiniMaxM1Model: Full transformer with embedding and final RMSNormMiniMaxM1ForCausalLM: Causal LM wrapper with dual weight loading:set_state_dict(v0 loader): HF key preprocessing (w1→gate_proj, w3→up_proj, w2→down_proj, q/k/v→qkv_proj concatenation)load_weights(v1 loader): stacked_params_mapping + FusedMoE.make_expert_params_mappingMiniMaxM1PretrainedModel: Tensor parallel column/row split mappingsLightning Attention Kernels (
fastdeploy/model_executor/ops/triton_ops/lightning_attn.py, 711 lines)Triton kernels for O(n) linear attention with exponential decay:
_fwd_kernel: Intra-block attention with causal masking and decay factors_fwd_kv_kernel: Inter-block KV state accumulation with block-level decaylightning_attention(): Python wrapper dispatching to Triton with automatic block size, dtype management, and KV history persistenceDocumentation
docs/best_practices/MiniMax-M1.md+docs/zh/best_practices/MiniMax-M1.md: Bilingual usage guide with deployment examplesdocs/supported_models.md+docs/zh/supported_models.md: Added MiniMax-M1 to LLM model tableDesign Decisions
MiniMaxText01LinearAttentionreference exactlyblock_sparse_moeattribute name matches HF config convention (notmlp)Usage or Command
See docs/best_practices/MiniMax-M1.md for full deployment guide.
Accuracy Tests
Unit Tests (32/32 passed — CI verified on H20 GPU)
tests/model_executor/test_minimax_m1.py(390 lines, 8 classes, 32 tests)TestLightningAttentionPurePython(4 tests): Reference NumPy implementation, block-size sweep, multi-head, KV history persistenceTestMoEConstruction(2 tests): Expert count, gate+experts constructionTestBuildSlopeTensor(3 tests): Exponential decay slopes for power-of-2 and non-power-of-2 head countsTestModelRegistration(4 tests): Dual architecture registration (MiniMaxM1ForCausalLM+MiniMaxText01ForCausalLM)TestDecoderLayerConstruction(9 tests): Linear/full attention dispatch, MoE vs dense MLP, postnorm config, fallback attention type, quantization weight_key_map (default/w4a8/w4afp8-dynamic)TestDecoderLayerForward(5 tests): Forward shape validation, DeepNorm scaling, postnorm code pathTestFullModelConstruction(3 tests): Full model assembly, layer count, embedding dimensionsTestPretrainedModelMappings(2 tests): Tensor parallel split mappingsCI Results (commit e068f01)
run_tests_with_coverage: 32/32 tests passed in 5.37s (with coverage) + 1.76s (standalone) on H20 GPUtests/distributed/test_hopper_ll_precision_entry.py— not caused by this PRPre-commit Validation
All hooks passing: black, isort, flake8, ruff, clang-format, merge conflict check, trailing whitespace, large file check.
Checklist
minimax_m1.py, ~800 lines) — 9 classes with full weight loading + quantization supportlightning_attn.py, 711 lines) — O(n) linear attentionset_state_dict) and v1 (load_weights) loader paths implementedMiniMaxM1ForCausalLM+MiniMaxText01ForCausalLM