From f9747ab86d716fb31301e62abff24bc0124bbc37 Mon Sep 17 00:00:00 2001 From: cloudforge1 Date: Tue, 24 Mar 2026 16:34:39 +0800 Subject: [PATCH 1/3] feat: add MiniMax-M1 (MiniMaxText01) model support - Model scaffold: minimax_m1.py with hybrid attention (70 linear + 10 full GQA), MoE (32 experts top-2), DeepNorm scaling, weight loading - Lightning Attention: 5 Triton JIT kernels + 3 Python wrappers - Tests: 27 pytest cases covering attn dispatch, slope construction, registration, layer construction, and forward-pass smoke tests - Docs: EN/CN best practices + supported models list updates Architecture: MiniMaxText01ForCausalLM (456B MoE, 80 layers) --- docs/best_practices/MiniMax-M1.md | 45 ++ docs/supported_models.md | 1 + docs/zh/best_practices/MiniMax-M1.md | 45 ++ docs/zh/supported_models.md | 1 + .../model_executor/models/minimax_m1.py | 595 +++++++++++++++ .../ops/triton_ops/lightning_attn.py | 711 ++++++++++++++++++ tests/model_executor/test_minimax_m1.py | 349 +++++++++ 7 files changed, 1747 insertions(+) create mode 100644 docs/best_practices/MiniMax-M1.md create mode 100644 docs/zh/best_practices/MiniMax-M1.md create mode 100644 fastdeploy/model_executor/models/minimax_m1.py create mode 100644 fastdeploy/model_executor/ops/triton_ops/lightning_attn.py create mode 100644 tests/model_executor/test_minimax_m1.py diff --git a/docs/best_practices/MiniMax-M1.md b/docs/best_practices/MiniMax-M1.md new file mode 100644 index 00000000000..0a1c334cf83 --- /dev/null +++ b/docs/best_practices/MiniMax-M1.md @@ -0,0 +1,45 @@ +[简体中文](../zh/best_practices/MiniMax-M1.md) + +# MiniMax-M1 Model + +## I. Environment Preparation + +### 1.1 Support Requirements + +MiniMax-M1 support in FastDeploy uses a hybrid decoder stack: + +- Standard full-attention layers run through the existing FastDeploy attention backend. +- Linear-attention layers use the Lightning Attention Triton kernels in `fastdeploy/model_executor/ops/triton_ops/lightning_attn.py`. 
+- Current first-pass support targets BF16 inference. + +### 1.2 Installing FastDeploy + +For installation instructions, see [FastDeploy GPU Installation](../get_started/installation/nvidia_gpu.md). + +## II. How to Use + +### 2.1 Basics: Starting the Service + +```shell +MODEL_PATH=/models/MiniMax-Text-01 + +python -m fastdeploy.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --port 8180 \ + --metrics-port 8181 \ + --engine-worker-queue-port 8182 \ + --max-model-len 32768 \ + --max-num-seqs 32 +``` + +### 2.2 Model Notes + +- HuggingFace architecture: `MiniMaxText01ForCausalLM` +- Hybrid layer layout: 70 linear-attention layers and 10 full-attention layers +- MoE routing: 32 experts, top-2 experts per token + +## III. Known Limitations + +- This initial integration is focused on model structure and backend wiring. +- Low-bit quantization support still requires follow-up validation against MiniMax-M1 weights. +- Production validation should include GPU runtime checks for Lightning Attention decode/prefill paths. diff --git a/docs/supported_models.md b/docs/supported_models.md index b0684affc11..fcf4b651f8c 100644 --- a/docs/supported_models.md +++ b/docs/supported_models.md @@ -38,6 +38,7 @@ These models accept text input. |⭐QWEN2.5|BF16/WINT8/FP8|Qwen/qwen2.5-72B;
Qwen/qwen2.5-32B;
Qwen/qwen2.5-14B;
Qwen/qwen2.5-7B;
Qwen/qwen2.5-3B;
Qwen/qwen2.5-1.5B;
Qwen/qwen2.5-0.5B, etc.| |⭐QWEN2|BF16/WINT8/FP8|Qwen/Qwen/qwen2-72B;
Qwen/Qwen/qwen2-7B;
Qwen/qwen2-1.5B;
Qwen/qwen2-0.5B;
Qwen/QwQ-32, etc.| |⭐DEEPSEEK|BF16/WINT4|unsloth/DeepSeek-V3.1-BF16;
unsloth/DeepSeek-V3-0324-BF16;
unsloth/DeepSeek-R1-BF16, etc.| +|MINIMAX-M1|BF16|[MiniMaxAI/MiniMax-Text-01](./best_practices/MiniMax-M1.md);
MiniMaxAI/MiniMax-Text-01-Large, etc.| |⭐GPT-OSS|BF16/WINT8|unsloth/gpt-oss-20b-BF16, etc.| |⭐GLM-4.5/4.6|BF16/wfp8afp8|zai-org/GLM-4.5-Air;
zai-org/GLM-4.6
 [最佳实践](./best_practices/GLM-4-MoE-Text.md) etc.| diff --git a/docs/zh/best_practices/MiniMax-M1.md b/docs/zh/best_practices/MiniMax-M1.md new file mode 100644 index 00000000000..4ba5f695a83 --- /dev/null +++ b/docs/zh/best_practices/MiniMax-M1.md @@ -0,0 +1,45 @@ +[English](../../best_practices/MiniMax-M1.md) + +# MiniMax-M1 模型 + +## 一、环境准备 + +### 1.1 支持说明 + +FastDeploy 中的 MiniMax-M1 采用混合解码器结构: + +- 全注意力层复用 FastDeploy 现有 Attention 后端。 +- 线性注意力层使用 `fastdeploy/model_executor/ops/triton_ops/lightning_attn.py` 中的 Lightning Attention Triton kernel。 +- 当前首版支持以 BF16 推理为主。 + +### 1.2 安装 FastDeploy + +安装流程可参考 [FastDeploy GPU 安装文档](../get_started/installation/nvidia_gpu.md) + +## 二、使用方式 + +### 2.1 基础启动命令 + +```shell +MODEL_PATH=/models/MiniMax-Text-01 + +python -m fastdeploy.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --port 8180 \ + --metrics-port 8181 \ + --engine-worker-queue-port 8182 \ + --max-model-len 32768 \ + --max-num-seqs 32 +``` + +### 2.2 模型特性 + +- HuggingFace 架构名:`MiniMaxText01ForCausalLM` +- 层类型分布:70 层线性注意力 + 10 层全注意力 +- MoE 路由:32 个专家,每个 token 选择 top-2 专家 + +## 三、当前限制 + +- 当前版本优先完成模型组网与后端接线。 +- 各类低比特量化推理能力还需要结合真实权重进一步验证。 +- Lightning Attention 的 prefill/decode 路径仍需在 GPU 环境完成端到端验证。 diff --git a/docs/zh/supported_models.md b/docs/zh/supported_models.md index 1424d2320fb..71cbfa09150 100644 --- a/docs/zh/supported_models.md +++ b/docs/zh/supported_models.md @@ -36,6 +36,7 @@ python -m fastdeploy.entrypoints.openai.api_server \ |⭐QWEN2.5|BF16/WINT8/FP8|Qwen/qwen2.5-72B;
Qwen/qwen2.5-32B;
Qwen/qwen2.5-14B;
Qwen/qwen2.5-7B;
Qwen/qwen2.5-3B;
Qwen/qwen2.5-1.5B;
Qwen/qwen2.5-0.5B, etc.| |⭐QWEN2|BF16/WINT8/FP8|Qwen/Qwen/qwen2-72B;
Qwen/Qwen/qwen2-7B;
Qwen/qwen2-1.5B;
Qwen/qwen2-0.5B;
Qwen/QwQ-32, etc.| |⭐DEEPSEEK|BF16/WINT4|unsloth/DeepSeek-V3.1-BF16;
unsloth/DeepSeek-V3-0324-BF16;
unsloth/DeepSeek-R1-BF16, etc.| +|MINIMAX-M1|BF16|[MiniMaxAI/MiniMax-Text-01](./best_practices/MiniMax-M1.md);
MiniMaxAI/MiniMax-Text-01-Large, etc.| |⭐GPT-OSS|BF16/WINT8|unsloth/gpt-oss-20b-BF16, etc.| |⭐GLM-4.5/4.6|BF16/wfp8afp8|zai-org/GLM-4.5-Air;
zai-org/GLM-4.6
 [最佳实践](./best_practices/GLM-4-MoE-Text.md) etc.| diff --git a/fastdeploy/model_executor/models/minimax_m1.py b/fastdeploy/model_executor/models/minimax_m1.py new file mode 100644 index 00000000000..8b558d56c52 --- /dev/null +++ b/fastdeploy/model_executor/models/minimax_m1.py @@ -0,0 +1,595 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +MiniMax-M1 Model for FastDeploy +Hybrid architecture: 70 linear attention layers + 10 full attention layers +MoE: 32 experts, top-2 routing per token +""" + +from __future__ import annotations + +import math +from typing import Dict + +import paddle +from paddle import nn +from paddleformers.transformers import PretrainedModel +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce +from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.graph_optimization.decorator import ( + support_graph_optimization, +) +from fastdeploy.model_executor.layers.activation import SiluAndMul +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding +from fastdeploy.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from 
fastdeploy.model_executor.layers.lm_head import ParallelLMHead +from fastdeploy.model_executor.layers.moe.moe import FusedMoE +from fastdeploy.model_executor.layers.normalization import RMSNorm +from fastdeploy.model_executor.models.model_base import ( + ModelCategory, + ModelForCasualLM, + ModelRegistry, +) +from fastdeploy.model_executor.ops.triton_ops.lightning_attn import lightning_attention + + +class MiniMaxM1MLP(nn.Layer): + """MiniMax-M1 MLP Layer (Dense FFN)""" + + def __init__( + self, + fd_config: FDConfig, + intermediate_size: int, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + + self.gate_up_proj = MergedColumnParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.gate_up_proj", + input_size=fd_config.model_config.hidden_size, + output_size=intermediate_size * 2, + with_bias=False, + activation=fd_config.model_config.hidden_act, + ) + + self.down_proj = RowParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.down_proj", + input_size=intermediate_size, + output_size=fd_config.model_config.hidden_size, + with_bias=False, + reduce_results=reduce_results, + ) + + self.act_fn = SiluAndMul( + fd_config=fd_config, + bias=getattr(self.gate_up_proj, "bias", None), + act_method=fd_config.model_config.hidden_act, + ) + + def load_state_dict(self, state_dict): + self.gate_up_proj.load_state_dict(state_dict) + self.down_proj.load_state_dict(state_dict) + + def forward(self, x, forward_meta=None): + gate_up_out = self.gate_up_proj(x) + act_out = self.act_fn(gate_up_out) + down_out = self.down_proj(act_out) + return down_out + + +class MiniMaxM1MoE(nn.Layer): + """MiniMax-M1 MoE Layer""" + + def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None: + super().__init__() + + self.tp_size = fd_config.parallel_config.tensor_parallel_size + self.norm_topk_prob = getattr(fd_config.model_config, "norm_topk_prob", False) + + weight_key_map = { + "up_gate_proj_expert_weight_key": 
f"{prefix}.experts.{{}}.up_gate_proj.weight", + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", + } + + self.gate = ReplicatedLinear( + fd_config=fd_config, + prefix=f"{prefix}.gate", + input_size=fd_config.model_config.hidden_size, + output_size=fd_config.model_config.num_local_experts, + with_bias=False, + skip_quant=True, + weight_dtype="float32", + ) + + self.experts = FusedMoE( + fd_config=fd_config, + reduce_results=False, # forward() below performs the single TP all-reduce; True here would reduce twice + renormalize=self.norm_topk_prob, + moe_intermediate_size=fd_config.model_config.intermediate_size, + num_experts=fd_config.model_config.num_local_experts, + top_k=fd_config.model_config.num_experts_per_tok, + layer_idx=layer_id, + weight_key_map=weight_key_map, + ) + + def load_state_dict(self, state_dict): + self.gate.load_state_dict(state_dict) + self.experts.load_state_dict(state_dict) + + def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta): + """Forward pass with router gating.""" + moe_out = self.experts(hidden_states, self.gate, forward_meta) + if self.tp_size > 1: + moe_out = tensor_model_parallel_all_reduce(moe_out) + return moe_out + + +class MiniMaxM1Attention(nn.Layer): + """MiniMax-M1 Full Attention (standard GQA)""" + + def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None: + super().__init__() + + self.hidden_size = fd_config.model_config.hidden_size + self.num_attention_heads = fd_config.model_config.num_attention_heads + self.head_dim = fd_config.model_config.head_dim + self.num_key_value_heads = fd_config.model_config.num_key_value_heads + + self.qkv_proj = ColumnParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.qkv_proj", + input_size=self.hidden_size, + output_size=(self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim, + with_bias=False, + ) + + self.o_proj = RowParallelLinear( + fd_config, + prefix=f"{prefix}.o_proj", + input_size=self.num_attention_heads * self.head_dim, + output_size=self.hidden_size, +
with_bias=False, + layer_id=layer_id, + ) + + self.attn = Attention( + fd_config=fd_config, + layer_id=layer_id, + prefix=prefix, + use_neox_rotary_style=True, + ) + + def load_state_dict(self, state_dict): + self.qkv_proj.load_state_dict(state_dict) + self.o_proj.load_state_dict(state_dict) + self.attn.load_state_dict(state_dict) + + def forward( + self, + forward_meta: ForwardMeta, + hidden_states: paddle.Tensor, + ): + """Full attention forward.""" + q, k, v = self._compute_qkv(hidden_states) + attn_output = self.attn(q, k, v, forward_meta=forward_meta) + output = self.o_proj(attn_output) + return output + + def _compute_qkv(self, hidden_states): + """Project hidden states to Q, K, V.""" + qkv = self.qkv_proj(hidden_states) + q_size = self.num_attention_heads * self.head_dim + kv_size = self.num_key_value_heads * self.head_dim + q, k, v = qkv.split([q_size, kv_size, kv_size], axis=-1) + return q, k, v + + +class MiniMaxM1LinearAttention(nn.Layer): + """MiniMax-M1 Linear Attention (Lightning Attention)""" + + def __init__( + self, + fd_config: FDConfig, + layer_id: int, + linear_layer_id: int, + prefix: str = "", + ) -> None: + super().__init__() + + self.hidden_size = fd_config.model_config.hidden_size + self.head_dim = fd_config.model_config.head_dim + self.num_attention_heads = fd_config.model_config.num_attention_heads + + # QKV projection + self.qkv_proj = ColumnParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.qkv_proj", + input_size=self.hidden_size, + output_size=self.num_attention_heads * self.head_dim * 3, + with_bias=False, + ) + + # Output projection + self.o_proj = RowParallelLinear( + fd_config, + prefix=f"{prefix}.o_proj", + input_size=self.num_attention_heads * self.head_dim, + output_size=self.hidden_size, + with_bias=False, + layer_id=layer_id, + ) + + # Build slope tensor for exponential decay + slope_tensor = self._build_slope_tensor(self.num_attention_heads) + if fd_config.model_config.num_hidden_layers <= 1: + slope_tensor = 
slope_tensor * (1 + 1e-5) + else: + slope_tensor = slope_tensor * (1 - layer_id / (fd_config.model_config.num_hidden_layers - 1) + 1e-5) + # Register as buffer (not trainable) + self.register_buffer("slope_rate", slope_tensor) + + # KV cache shape: [heads, head_dim, head_dim] + self.kv_cache_shape = (self.num_attention_heads, self.head_dim, self.head_dim) + + def load_state_dict(self, state_dict): + self.qkv_proj.load_state_dict(state_dict) + self.o_proj.load_state_dict(state_dict) + + @staticmethod + def _build_slope_tensor(n_heads: int): + """Build ALiBi-style slope tensor for exponential decay.""" + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** (-(math.log2(n) - 3)))) + return [start * (start**i) for i in range(n)] + + if math.log2(n_heads).is_integer(): + slopes = get_slopes_power_of_2(n_heads) + else: + closest_power = 2 ** math.floor(math.log2(n_heads)) + slopes = get_slopes_power_of_2(closest_power) + slopes += get_slopes_power_of_2(2 * closest_power)[0::2][: n_heads - closest_power] + + return paddle.to_tensor(slopes, dtype=paddle.float32).reshape([n_heads, 1, 1]) + + def forward( + self, + forward_meta: ForwardMeta, + hidden_states: paddle.Tensor, + ): + """Linear attention forward.""" + # Project QKV + qkv = self.qkv_proj(hidden_states) + q, k, v = qkv.split( + [ + self.num_attention_heads * self.head_dim, + self.num_attention_heads * self.head_dim, + self.num_attention_heads * self.head_dim, + ], + axis=-1, + ) + + # Reshape for lightning attention + batch_size = q.shape[0] + q = q.reshape([batch_size, -1, self.num_attention_heads, self.head_dim]) + k = k.reshape([batch_size, -1, self.num_attention_heads, self.head_dim]) + v = v.reshape([batch_size, -1, self.num_attention_heads, self.head_dim]) + + # Transpose to [batch, heads, seq_len, dim] + q = q.transpose([0, 2, 1, 3]) + k = k.transpose([0, 2, 1, 3]) + v = v.transpose([0, 2, 1, 3]) + + # Initialize KV history if needed + kv_history = paddle.zeros( + [batch_size, self.num_attention_heads, 
self.head_dim, self.head_dim], + dtype=q.dtype, + ) + + # Apply lightning attention + attn_output, _ = lightning_attention( + q, k, v, self.slope_rate.squeeze(-1), block_size=256, kv_history=kv_history + ) + + # Reshape back + attn_output = attn_output.transpose([0, 2, 1, 3]) + attn_output = attn_output.reshape([batch_size, -1, self.num_attention_heads * self.head_dim]) + + # Output projection + output = self.o_proj(attn_output) + return output + + +class MiniMaxM1DecoderLayer(nn.Layer): + """MiniMax-M1 Decoder Layer with Hybrid Attention Dispatch""" + + @staticmethod + def _build_attn_type_list(num_layers: int): + """Build attention type list: 70 linear + 10 full (at indices 7,15,23,...).""" + attn_type_list = [0] * num_layers # Default: all linear + # Full attention every 8 layers starting at layer 7 + full_attn_indices = [7, 15, 23, 31, 39, 47, 55, 63, 71, 79] + for idx in full_attn_indices: + if idx < num_layers: + attn_type_list[idx] = 1 + return attn_type_list + + def __init__( + self, + fd_config: FDConfig, + layer_id: int, + prefix: str = "", + ) -> None: + super().__init__() + + self.hidden_size = fd_config.model_config.hidden_size + self.layer_id = layer_id + self.postnorm = getattr(fd_config.model_config, "postnorm", False) + + # Determine attention type for this layer + # attn_type_list: 70 linear (0) + 10 full (1) at specific indices + attn_type_list = getattr( + fd_config.model_config, + "attn_type_list", + self._build_attn_type_list(fd_config.model_config.num_hidden_layers), + ) + self.attention_type = attn_type_list[layer_id] if layer_id < len(attn_type_list) else 1 + + # Attention layer (dispatch based on type) + if self.attention_type == 0: # Linear attention + linear_layer_id = sum(1 for i in range(layer_id) if attn_type_list[i] == 0) + self.self_attn = MiniMaxM1LinearAttention( + fd_config, + layer_id=layer_id, + linear_layer_id=linear_layer_id, + prefix=f"{prefix}.self_attn", + ) + else: # Full attention + self.self_attn = MiniMaxM1Attention( + 
fd_config, + layer_id=layer_id, + prefix=f"{prefix}.self_attn", + ) + + # Input layernorm (pre-norm) + self.input_layernorm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{prefix}.input_layernorm", + ) + + # Post-attention layernorm + self.post_attention_layernorm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{prefix}.post_attention_layernorm", + ) + + # DeepNorm alpha/beta scaling + self.layernorm_attention_alpha = getattr( + fd_config.model_config, "layernorm_full_attention_alpha", 3.5565588200778455 + ) + self.layernorm_attention_beta = getattr(fd_config.model_config, "layernorm_full_attention_beta", 1.0) + self.layernorm_mlp_alpha = getattr(fd_config.model_config, "layernorm_mlp_alpha", 3.5565588200778455) + self.layernorm_mlp_beta = getattr(fd_config.model_config, "layernorm_mlp_beta", 1.0) + + # FFN (MLP or MoE) + if fd_config.model_config.num_local_experts > 1: + self.mlp = MiniMaxM1MoE( + fd_config, + layer_id=layer_id, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = MiniMaxM1MLP( + fd_config, + intermediate_size=fd_config.model_config.intermediate_size, + prefix=f"{prefix}.mlp", + reduce_results=True, + ) + + def load_state_dict(self, state_dict): + self.self_attn.load_state_dict(state_dict) + self.mlp.load_state_dict(state_dict) + self.input_layernorm.load_state_dict(state_dict) + self.post_attention_layernorm.load_state_dict(state_dict) + + def forward( + self, + forward_meta: ForwardMeta, + hidden_states: paddle.Tensor, + residual: paddle.Tensor = None, + ): + """Decoder layer forward with DeepNorm.""" + # Pre-norm + hidden_states, residual = self.input_layernorm( + hidden_states, + residual_input=residual, + forward_meta=forward_meta, + ) + + # Attention (dispatch based on type) + if self.attention_type == 1: # Full attention + attn_output = self.self_attn(forward_meta=forward_meta, 
hidden_states=hidden_states) + else: # Linear attention + attn_output = self.self_attn(forward_meta=forward_meta, hidden_states=hidden_states) + + # DeepNorm alpha/beta scaling + residual = residual * self.layernorm_attention_alpha + attn_output = attn_output * self.layernorm_attention_beta + + # Post-attention + hidden_states, residual = self.post_attention_layernorm(attn_output, residual) + + # FFN + mlp_output = self.mlp(hidden_states, forward_meta) + + # DeepNorm MLPalpha/beta + residual = residual * self.layernorm_mlp_alpha + mlp_output = mlp_output * self.layernorm_mlp_beta + + hidden_states = residual + mlp_output + + return hidden_states, residual + + +@support_graph_optimization +class MiniMaxM1Model(nn.Layer): + """MiniMax-M1 Transformer Model""" + + def __init__(self, fd_config: FDConfig = None): + super().__init__() + + self.num_layers = fd_config.model_config.num_hidden_layers + self.hidden_size = fd_config.model_config.hidden_size + fd_config.model_config.pretrained_config.prefix_name = "model" + + # Embedding + self.embed_tokens = VocabParallelEmbedding( + fd_config, + num_embeddings=fd_config.model_config.vocab_size, + embedding_dim=fd_config.model_config.hidden_size, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens", + ) + + # Decoder layers + self.layers = nn.LayerList( + [ + MiniMaxM1DecoderLayer( + fd_config, + layer_id=i, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}", + ) + for i in range(self.num_layers) + ] + ) + + # Final layernorm + self.norm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm", + ) + + def load_state_dict(self, state_dict): + """Load model parameters.""" + self.embed_tokens.load_state_dict(state_dict) + self.norm.load_state_dict(state_dict) + for i in range(self.num_layers): + logger.info(f"Start load layer {i}") + 
self.layers[i].load_state_dict(state_dict) + + def forward( + self, + ids_remove_padding: paddle.Tensor, + forward_meta: ForwardMeta, + ): + """Model forward pass.""" + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) + + residual = None + + # Pass through decoder layers + for i in range(self.num_layers): + hidden_states, residual = self.layers[i]( + forward_meta=forward_meta, + hidden_states=hidden_states, + residual=residual, + ) + + # Final layernorm + hidden_states = self.norm(hidden_states, residual)[0] + + return hidden_states + + +@ModelRegistry.register_model_class( + architecture="MiniMaxText01ForCausalLM", + module_name="minimax_m1", + category=ModelCategory.TEXT_GENERATION, + primary_use=ModelCategory.TEXT_GENERATION, +) +class MiniMaxM1ForCausalLM(ModelForCasualLM): + """MiniMax-M1 Causal LM Model""" + + def __init__(self, fd_config: FDConfig): + super().__init__(fd_config) + + self.model = MiniMaxM1Model(fd_config) + self.lm_head = ParallelLMHead( + fd_config, + embedding_dim=fd_config.model_config.hidden_size, + num_embeddings=fd_config.model_config.vocab_size, + prefix="lm_head", + ) + + @classmethod + def name(cls): + """Model name.""" + return "MiniMaxText01ForCausalLM" + + @paddle.no_grad() + def set_state_dict(self, state_dict: Dict): + """Load model parameters.""" + self.model.load_state_dict(state_dict) + self.lm_head.load_state_dict(state_dict) + + def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None): + """Compute logits.""" + logits = self.lm_head(hidden_states) + logits = logits.astype(paddle.float32) + return logits + + def forward( + self, + inputs: Dict, + forward_meta: ForwardMeta, + ): + """Forward pass.""" + ids_remove_padding = inputs["ids_remove_padding"] + + hidden_states = self.model(ids_remove_padding, forward_meta) + return hidden_states + + +class MiniMaxM1PretrainedModel(PretrainedModel): + """MiniMax-M1 Pretrained Model""" + + config_class = 
FDConfig + + @classmethod + def arch_name(cls): + """Architecture name.""" + return "MiniMaxText01ForCausalLM" + + @classmethod + def name(cls): + """Model name.""" + return "MiniMaxText01ForCausalLM" diff --git a/fastdeploy/model_executor/ops/triton_ops/lightning_attn.py b/fastdeploy/model_executor/ops/triton_ops/lightning_attn.py new file mode 100644 index 00000000000..df90674e747 --- /dev/null +++ b/fastdeploy/model_executor/ops/triton_ops/lightning_attn.py @@ -0,0 +1,711 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import paddle +import triton +import triton.language as tl + +from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( + enable_compat_on_triton_kernel, +) + +# ============================================================================= +# Triton JIT Kernels — framework-agnostic, operate on raw pointers +# ============================================================================= + + +@enable_compat_on_triton_kernel +@triton.jit +def _fwd_diag_kernel( + Q, + K, + V, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + CBLOCK: tl.constexpr, +): + # This kernel computes the diagonal blocks of the attention matrix + # Each diagonal block represents attention + # where queries attend to keys in the same block + off = tl.program_id(0) + off_bh = off // NUM_BLOCK # batch-head index + off_block = off % NUM_BLOCK # block index within the sequence + off_cblock = tl.program_id(1) # sub-block index within a block + + off_h = off_bh % h # head index + + # Calculate base offsets for the current batch and head + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + + # Calculate offsets for the current block + block_offset = off_block * BLOCK + qk_block_offset = block_offset * d + v_block_offset = block_offset * e + o_block_offset = block_offset * e + + # Calculate offsets for the current sub-block + cblock_offset = off_cblock * CBLOCK + q_cblock_offset = cblock_offset * d + o_cblock_offset = cblock_offset * e + + # Calculate pointers to the query, key, value, and output tensors + Q_block_ptr = ( + Q + + qk_offset + + qk_block_offset + + q_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * d + + tl.arange(0, d)[None, :] + ) + K_trans_block_ptr = K + qk_offset + qk_block_offset + tl.arange(0, CBLOCK)[None, :] * d + tl.arange(0, d)[:, None] + V_block_ptr = V + v_offset + v_block_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, e)[None, :] + 
O_block_ptr = ( + Out + + o_offset + + o_block_offset + + o_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :] + ) + + # Load the decay rate for the current head + S_block_ptr = S + off_h + s = tl.load(S_block_ptr) + + i = off_cblock + q_index = tl.arange(0, CBLOCK) + i * CBLOCK + + # Load query values + q = tl.load(Q_block_ptr, mask=block_offset + q_index[:, None] < n, other=0.0).to(tl.float32) + + # Initialize output accumulator + qkv = tl.zeros([CBLOCK, e], dtype=tl.float32) + + # Process all sub-blocks up to and + # including the current one (causal attention) + for j in range(i + 1): + kv_index = tl.arange(0, CBLOCK) + j * CBLOCK + diff = q_index[:, None] - kv_index[None, :] + s_index = s * diff + # Apply causal mask: only attend to positions before the current one + s_index = tl.where(diff >= 0, -s_index, float("-inf")) + decay = tl.exp(s_index) + + # Load key and value + k_trans = tl.load( + K_trans_block_ptr, + mask=block_offset + kv_index[None, :] < n, + other=0.0, + ).to(tl.float32) + v = tl.load( + V_block_ptr, + mask=block_offset + kv_index[:, None] < n, + other=0.0, + ).to(tl.float32) + + # Compute attention scores and apply decay + qk = tl.dot(q, k_trans) * decay + + # Compute weighted values and accumulate + qkv += tl.dot(qk, v) + + # Move to the next sub-block + K_trans_block_ptr += CBLOCK * d + V_block_ptr += CBLOCK * e + + # Store the result + tl.store( + O_block_ptr, + qkv.to(O_block_ptr.dtype.element_ty), + mask=block_offset + q_index[:, None] < n, + ) + + +@enable_compat_on_triton_kernel +@triton.jit +def _fwd_kv_parallel( + K, + V, + K_decay, + KV, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + D_FBLOCK: tl.constexpr, + E_FBLOCK: tl.constexpr, + NUM_FBLOCK: tl.constexpr, + CBLOCK: tl.constexpr, + NUM_CBLOCK: tl.constexpr, +): + # This kernel computes the key-value outer + # products for each block in parallel + off_bh = tl.program_id(0) # 
batch-head index + off_block = tl.program_id(1) # block index + + off_h = off_bh % h # head index + + block_offset = off_block * BLOCK + + # Calculate offsets for the current block + k_block_offset = block_offset * d + v_block_offset = block_offset * e + kv_block_offset = off_block * d * e + + # Calculate base offsets for the current batch and head + k_offset = off_bh * n * d + v_offset = off_bh * n * e + kv_offset = off_bh * NUM_BLOCK * d * e + + # Calculate pointers to the key, value, and key-value tensors + K_trans_block_ptr = ( + K + k_offset + k_block_offset + tl.arange(0, CBLOCK)[None, :] * d + tl.arange(0, D_FBLOCK)[:, None] + ) + V_block_ptr = V + v_offset + v_block_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :] + KV_block_ptr = ( + KV + kv_offset + kv_block_offset + tl.arange(0, D_FBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :] + ) + + # Load the decay factors for the current head and block + k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK) + + kv_index = tl.arange(0, CBLOCK) + + # Initialize the key-value outer product accumulator + kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32) + + # Handle the last block which might be smaller than BLOCK + split_n = n - (NUM_BLOCK - 1) * BLOCK if off_block == NUM_BLOCK - 1 else BLOCK + left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n + num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK) + k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK + + # Process all sub-blocks in the current block + for j in range(num_blocks): + left_bound = (1 - j) * left_shift + # Load key and value, handling boundary conditions + k_trans = tl.load( + K_trans_block_ptr - left_shift * d, + mask=kv_index[None, :] >= left_bound, + other=0.0, + ) + v = tl.load( + V_block_ptr - left_shift * e, + mask=kv_index[:, None] >= left_bound, + other=0.0, + ) + + # Load decay factor and compute weighted key-value outer product + k_decay = tl.load(k_decay_ptr) + + # NOTE: Need to add the extra 
dim here due to AMD MLIR lowering error. + # Please don't move it back until issue is resolved. + # Issue: https://github.com/ROCm/triton/issues/907 + k_decay = k_decay[None, :] + + kv += tl.dot(k_trans * k_decay, v) + + # Move to the next sub-block + K_trans_block_ptr += CBLOCK * d + V_block_ptr += CBLOCK * e + k_decay_ptr += CBLOCK + + # Store the result + tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty)) + + +@enable_compat_on_triton_kernel +@triton.jit +def _fwd_kv_reduce( + S, + KV, + KV_HISTORY, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + D_FBLOCK: tl.constexpr, + E_FBLOCK: tl.constexpr, +): + # This kernel reduces the key-value outer products + # across blocks and updates the KV history + off_bh = tl.program_id(0) # batch-head index + off_h = off_bh % h # head index + + kv_offset = off_bh * NUM_BLOCK * d * e + + # Calculate pointer to the key-value tensor + KV_block_ptr = KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :] + + # Load the decay rate for the current head + s_ptrs = S + off_h + s = tl.load(s_ptrs) + + # Calculate pointer to the key-value history tensor + kv_history_offset = off_bh * d * e + KV_HISTORY_block_ptr = ( + KV_HISTORY + kv_history_offset + tl.arange(0, D_FBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :] + ) + + # Load the previous key-value history + kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32) + + # Process all blocks in ascending order to compute the decayed exclusive prefix sum + for i in range(NUM_BLOCK): + block_size = min(n - i * BLOCK, BLOCK) + # Compute decay factor for the current block + block_decay = tl.exp(-s.to(tl.float32) * block_size) + + # Load the current key-value outer product + kv_cur = tl.load(KV_block_ptr).to(tl.float32) + # Store the previous key-value history to the current block + tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty)) + + # Update the key-value history with the current
block + kv_pre = block_decay * kv_pre + kv_cur + KV_block_ptr += d * e + + # Store the updated key-value history + tl.store(KV_HISTORY_block_ptr, kv_pre) + + +@enable_compat_on_triton_kernel +@triton.jit +def _fwd_none_diag_kernel( + Q, + Out, + S, + KV, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + E_FBLOCK: tl.constexpr, + CBLOCK: tl.constexpr, + NUM_CBLOCK: tl.constexpr, +): + # This kernel computes the non-diagonal blocks of the attention matrix + # Each non-diagonal block represents attention + # where queries attend to keys in different blocks + off_bh = tl.program_id(0) # batch-head index + off_h = off_bh % h # head index + + off_nc = tl.program_id(1) + off_n = off_nc // NUM_CBLOCK # block index + off_c = off_nc % NUM_CBLOCK # sub-block index + off_e = tl.program_id(2) # output feature block index + + n_offset = off_n * BLOCK + c_offset = off_c * CBLOCK + e_offset = off_e * E_FBLOCK + block_offset = n_offset + c_offset + + # Calculate offsets for the current batch, head, and block + q_offset = off_bh * n * d + (n_offset + c_offset) * d + o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset + kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset + + # Calculate pointers to the query, output, and key-value tensors + Q_block_ptr = Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :] + O_block_ptr = Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :] + KV_block_ptr = KV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :] + + # Load the decay rate for the current head + S_block_ptr = S + off_h + s = tl.load(S_block_ptr) + + c_array = tl.arange(0, CBLOCK) + + # Load the key-value outer product for the current block + kv = tl.load(KV_block_ptr).to(tl.float32) + q_index = block_offset + tl.arange(0, CBLOCK) + + # Load query values + q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, 
other=0.0).to(tl.float32) + + # Compute decay factors for the current sub-block + q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None])) + + # Compute non-diagonal attention output + qkv_none_diag = tl.dot(q, kv) * q_decay + + # Load diagonal attention output (computed by _fwd_diag_kernel) + qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32) + + # Combine diagonal and non-diagonal attention outputs + qkv = qkv_diag + qkv_none_diag + + # Store the result + tl.store(O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty), mask=q_index[:, None] < n) + + +@enable_compat_on_triton_kernel +@triton.jit +def _linear_attn_decode_kernel( + q_ptr, + k_ptr, + v_ptr, + kv_cache_ptr, + slope_rate, + slot_idx, + output_ptr, + D: tl.constexpr, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d0_stride, + cache_d1_stride, + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel for linear attention decoding with KV cache. + + This kernel computes attention for a single token using the KV cache. 
+ """ + pid_b = tl.program_id(0) # batch index + pid_h = tl.program_id(1) # head index + pid_d = tl.program_id(2) # dimension block index + + # Load slot index for the current batch + slot_id = tl.load(slot_idx + pid_b).to(tl.int64) + + # Skip if slot_id is -1 (padding) + if slot_id == -1: + return + + batch_id = pid_b + head_id = pid_h + + # Load decay rate for the current head + ratio = tl.load(slope_rate + pid_h) + + # Calculate offsets for dimensions + qk_d_offsets = tl.arange(0, D) + v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE + cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[None, :] * cache_d1_stride + + # Calculate offsets for the current batch and head + q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + + cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride + + # Create masks for loading tensors + qk_mask = qk_d_offsets < D + v_mask = v_d_offsets < D + + # Load query, key, and value tensors + q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0) + k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0) + v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0) + + # Compute key-value outer product + kv_outer = k[:, None] * v[None, :] + kv_mask = qk_mask[:, None] & v_mask[None, :] + + # Apply decay to previous KV cache + ratio = tl.exp(-ratio) + kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets + kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0) + kv_outer = kv_outer + ratio * kv_cache_old + + # Compute attention output + output = q[:, None].to(tl.float32) * kv_outer + output = tl.sum(output, axis=0) + + # Update KV cache and store output + tl.store(kv_ptr, kv_outer, mask=kv_mask) + tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask) + + +# 
# =============================================================================
# Python wrapper functions — Paddle API
# =============================================================================


def lightning_attention_forward(q, k, v, s, kv_history):
    """
    Run the four-stage Lightning Attention forward pass (inference only).

    Converted from vLLM's torch.autograd.Function to a plain function
    for inference-only use in FastDeploy.

    Args:
        q: Query tensor [b, h, n, d]
        k: Key tensor [b, h, n, d]
        v: Value tensor [b, h, n, e]
        s: Per-head decay rates. Any layout holding exactly ``h`` values is
            accepted ([h], [h, 1], [h, 1, 1], [1, h, 1, 1]); it is flattened
            to [h] before use.
        kv_history: KV history tensor [b, h, d, e]. ``_fwd_kv_reduce``
            overwrites it in place with the updated state and addresses it
            with raw element offsets, so pass a contiguous buffer that may be
            modified.

    Returns:
        o: Output tensor [b, h, n, e]
        kv_state: Tensor [b, h, NUM_BLOCK + 1, d, e]: the per-block KV
            prefixes written by ``_fwd_kv_reduce`` followed, as the last entry
            along axis 2, by the updated ``kv_history``.

    Raises:
        ValueError: if ``s`` does not hold exactly one value per head.
    """
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    # Get input dimensions
    b, h, n, d = q.shape
    e = v.shape[-1]

    # Flatten the decay rates to [h]. `_fwd_kv_reduce` and
    # `_fwd_none_diag_kernel` read them as `S + off_h`, so only the element
    # order matters and a flat view is memory-identical to the accepted
    # layouts. Flattening also keeps the k_decay broadcast below well defined
    # for a plain [h] input: [h] against [1, BLOCK] either failed to broadcast
    # or, when h == BLOCK, was silently multiplied element-wise.
    s = s.contiguous().reshape([-1])
    if s.shape[0] != h:
        raise ValueError(f"s must hold one decay rate per head: expected {h} values, got {s.shape[0]}")

    # Initialize output tensor
    o = paddle.empty([b, h, n, e], dtype=q.dtype)

    # Set block sizes
    BLOCK = 256
    NUM_BLOCK = triton.cdiv(n, BLOCK)

    CBLOCK = 32
    assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
    NUM_CBLOCK = BLOCK // CBLOCK

    # Compute decay factors for keys:
    #   k_decay[head, j] = exp(-s[head] * (BLOCK - (j + 1)))
    # laid out as [h, BLOCK], which is what `K_decay + off_h * BLOCK + ...`
    # in `_fwd_kv_parallel` indexes.
    array = paddle.arange(0, BLOCK).astype("float32") + 1
    k_decay = paddle.exp(-s.reshape([-1, 1]) * (BLOCK - array.reshape([1, -1])))

    # Step 1: Compute diagonal blocks of attention
    grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
    _fwd_diag_kernel[grid](
        q,
        k,
        v,
        o,
        s,
        b,
        h,
        n,
        d,
        e,
        BLOCK=BLOCK,
        NUM_BLOCK=NUM_BLOCK,
        CBLOCK=CBLOCK,
    )

    # Set feature block sizes
    NUM_FBLOCK = 1
    assert d % NUM_FBLOCK == 0
    D_FBLOCK = d // NUM_FBLOCK
    assert e % NUM_FBLOCK == 0
    E_FBLOCK = e // NUM_FBLOCK

    # The remaining kernels use a coarser sub-block than the diagonal pass.
    CBLOCK = 64
    assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
    NUM_CBLOCK = BLOCK // CBLOCK

    # Step 2: Compute key-value outer products for each block in parallel
    kv = paddle.empty([b, h, NUM_BLOCK, d, e], dtype="float32")
    grid = (b * h, NUM_BLOCK)
    _fwd_kv_parallel[grid](
        k,
        v,
        k_decay,
        kv,
        b,
        h,
        n,
        d,
        e,
        BLOCK=BLOCK,
        NUM_BLOCK=NUM_BLOCK,
        D_FBLOCK=D_FBLOCK,
        E_FBLOCK=E_FBLOCK,
        NUM_FBLOCK=NUM_FBLOCK,
        CBLOCK=CBLOCK,
        NUM_CBLOCK=NUM_CBLOCK,
    )

    # Step 3: Reduce key-value outer products
    # across blocks and update KV history (in place)
    grid = (b * h, NUM_FBLOCK)
    _fwd_kv_reduce[grid](
        s,
        kv,
        kv_history,
        b,
        h,
        n,
        d,
        e,
        BLOCK=BLOCK,
        NUM_BLOCK=NUM_BLOCK,
        D_FBLOCK=D_FBLOCK,
        E_FBLOCK=E_FBLOCK,
    )

    # Step 4: Compute non-diagonal blocks of attention
    grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
    _fwd_none_diag_kernel[grid](
        q,
        o,
        s,
        kv,
        b,
        h,
        n,
        d,
        e,
        BLOCK=BLOCK,
        NUM_BLOCK=NUM_BLOCK,
        E_FBLOCK=E_FBLOCK,
        CBLOCK=CBLOCK,
        NUM_CBLOCK=NUM_CBLOCK,
    )

    return o, paddle.concat([kv, kv_history.unsqueeze(2)], axis=2)


def lightning_attention(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    ed: paddle.Tensor,
    block_size: int = 256,
    kv_history: paddle.Tensor | None = None,
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
    Apply lightning attention algorithm to compute attention efficiently.

    Args:
        q: Query tensor of shape [batch, heads, seq_len, dim]
        k: Key tensor of shape [batch, heads, seq_len, dim]
        v: Value tensor of shape [batch, heads, seq_len, dim_v]
        ed: Decay rate tensor of shape [heads]
        block_size: Accepted for interface compatibility; it is not read here
            (``lightning_attention_forward`` uses a fixed BLOCK of 256).
        kv_history: Optional key-value history [batch, heads, dim, dim_v] from
            previous computations. It is cloned, so the caller's tensor is
            left untouched.

    Returns:
        output: Attention output, summed over the feature chunks.
        kv: The state tensor returned by the last chunk's
            ``lightning_attention_forward`` call, shaped
            [batch, heads, NUM_BLOCK + 1, dim, dim_v]. The updated history is
            the last entry along axis 2 (``kv[:, :, -1]``), not the whole
            tensor.

    Raises:
        ValueError: if ``q`` has an empty feature dimension.
    """
    d = q.shape[-1]
    e = v.shape[-1]
    if d == 0:
        # Without this the loop below would not run and `kv` would be unbound.
        raise ValueError("q must have a non-empty feature dimension")

    if ed.ndim == 1:
        ed = ed.reshape([1, -1, 1, 1])

    # Split the feature dimension into chunks of width m.
    m = 128 if d >= 128 else 64
    assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"
    chunk_bounds = [(start, min(start + m, d)) for start in range(0, d, m)]

    # Initialize or clone key-value history
    if kv_history is None:
        kv_history = paddle.zeros([q.shape[0], q.shape[1], d, e], dtype="float32")
    else:
        kv_history = kv_history.clone().contiguous()

    # NOTE(review): the history buffer is sized with the full d, while each
    # call below builds its per-block buffer with the chunk width and the
    # kernels index the history with that chunk width too. The concat at the
    # end of lightning_attention_forward therefore only lines up when there is
    # a single chunk (d == m) — confirm multi-chunk handling before using
    # d > 128.
    output = 0
    for start, end in chunk_bounds:
        o, kv = lightning_attention_forward(q[..., start:end], k[..., start:end], v, ed, kv_history)
        output = output + o
    return output, kv


def linear_decode_forward_triton(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    kv_caches: paddle.Tensor,
    slope_rate: paddle.Tensor,
    slot_idx: paddle.Tensor,
    BLOCK_SIZE: int = 32,
) -> paddle.Tensor:
    """
    Perform linear attention decoding using Triton kernels.

    Args:
        q: Query tensor of shape [B, H, 1, D]
        k: Key tensor of shape [B, H, 1, D]
        v: Value tensor of shape [B, H, 1, D]
        kv_caches: 4-D key-value cache, indexed by slot on axis 0 and head on
            axis 1; the kernel updates the addressed entries in place.
        slope_rate: Decay rate tensor, read as one value per head.
        slot_idx: Cache slot for each batch row; rows with slot -1 are treated
            as padding and skipped by the kernel.
        BLOCK_SIZE: Number of value features handled per program.

    Returns:
        output: Attention output tensor of shape [B, H*D]. Rows skipped as
            padding are zero.
    """
    B, H, _, D = q.shape
    assert k.shape == [B, H, 1, D]
    assert v.shape == [B, H, 1, D]

    # The kernel addresses q, k, v and the output with q's batch/head strides
    # and a unit stride along D, so all four must share one contiguous layout.
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    # Zero-initialise: programs whose slot index is -1 return before storing,
    # which would otherwise leave those rows as uninitialised memory.
    output = paddle.zeros_like(q)

    # Ceil-divide so the last partial block is covered (the kernel masks
    # offsets >= D). Floor division launched nothing at all when
    # D < BLOCK_SIZE.
    # NOTE(review): the kernel builds `tl.arange(0, D)` and
    # `tl.arange(0, BLOCK_SIZE)`; Triton requires those extents to be powers
    # of two — confirm for new head sizes.
    grid = (B, H, triton.cdiv(D, BLOCK_SIZE))

    # Calculate strides for tensors
    qkv_b_stride = q.strides[0]
    qkv_h_stride = q.strides[1]

    cache_b_stride = kv_caches.strides[0]
    cache_h_stride = kv_caches.strides[1]
    cache_d0_stride = kv_caches.strides[2]
    cache_d1_stride = kv_caches.strides[3]

    # Launch the kernel
    _linear_attn_decode_kernel[grid](
        q,
        k,
        v,
        kv_caches,
        slope_rate,
        slot_idx,
        output,
        D,
        qkv_b_stride,
        qkv_h_stride,
        cache_b_stride,
        cache_h_stride,
        cache_d0_stride,
        cache_d1_stride,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Reshape output: "b h n d -> b n (h d)"
    # output shape: [B, H, 1, D] -> transpose to [B, 1, H, D] -> reshape to [B, 1, H*D]
    output = output.transpose([0, 2, 1, 3]).reshape([B, 1, -1])
    return output.squeeze(1).contiguous()


# ---------------------------------------------------------------------------
# Patch framing that shared the last physical line of this span (start of the
# next file in the diff), kept verbatim:
#
# diff --git a/tests/model_executor/test_minimax_m1.py b/tests/model_executor/test_minimax_m1.py
# new file mode 100644
# index 00000000000..9333d7a7e00
# --- /dev/null
# +++ b/tests/model_executor/test_minimax_m1.py
# @@ -0,0 +1,349 @@
# +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for MiniMax-M1 model scaffold.
Validates architecture dispatch, slope construction, registration, and forward paths.

Uses importlib to load minimax_m1.py directly, bypassing fastdeploy/__init__.py
which pulls in the full inference engine (etcd, Redis, GPU ops, etc.).
All heavy submodules are replaced with lightweight stubs so tests run on CPU.
"""

import importlib
import importlib.util
import math
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock

import numpy as np
import paddle
import pytest

# ---------------------------------------------------------------------------
# Module-level setup: load minimax_m1 with stub dependencies
# ---------------------------------------------------------------------------
# NOTE: everything below runs at import time and edits sys.modules, so the
# statements are order-sensitive: the stubs must be registered before
# minimax_m1.py is executed further down in this file.

# 1) paddleformers stubs
# Empty stand-in class; only its identity as a base class is needed.
_PretrainedModel = type("PretrainedModel", (), {})


class _PretrainedConfig:
    """Minimal replacement for paddleformers' PretrainedConfig used by the stubs."""

    prefix_name = ""

    @classmethod
    def get_config_dict(cls, model_path, **kw):
        """Read ``<model_path>/config.json`` and return ``(parsed_json, {})``.

        Extra keyword arguments are accepted and ignored.
        """
        import json as _j
        import os as _o

        with open(_o.path.join(model_path, "config.json")) as f:
            return _j.load(f), {}

    @classmethod
    def from_dict(cls, d):
        """Expose the keys of ``d`` as attributes of a SimpleNamespace.

        ``prefix_name`` is always reset to ``""``, overriding any such key in ``d``.
        """
        ns = SimpleNamespace(**d)
        ns.prefix_name = ""
        return ns


# Both the package-level and the configuration_utils-level import paths hand
# out the same stub classes.
_cfg_mod = MagicMock()
_cfg_mod.PretrainedConfig = _PretrainedConfig
_transf = MagicMock()
_transf.PretrainedModel = _PretrainedModel
_transf.configuration_utils = _cfg_mod
_transf.PretrainedConfig = _PretrainedConfig

# setdefault keeps an already-imported paddleformers package in place, whereas
# the transformers submodule entry is overwritten unconditionally.
sys.modules.setdefault("paddleformers", MagicMock())
sys.modules["paddleformers.transformers"] = _transf
sys.modules["paddleformers.transformers.configuration_utils"] = _cfg_mod
sys.modules.setdefault("paddleformers.utils", MagicMock())
sys.modules["paddleformers.utils.log"] = MagicMock()

# 2) Lightweight fastdeploy namespace (bypass __init__.py)
# type(sys) is the module type: this creates an empty module object and
# registers it under the package name so the real __init__.py never runs.
_fd_ns = type(sys)("fastdeploy")
_fd_ns.__path__ = ["fastdeploy"]
_fd_ns.__file__ = "fastdeploy/__init__.py"
sys.modules["fastdeploy"] = _fd_ns

# Same trick for the intermediate packages on the way to the model module.
for _pkg, _path in [
    ("fastdeploy.model_executor", "fastdeploy/model_executor"),
    ("fastdeploy.model_executor.models", "fastdeploy/model_executor/models"),
]:
    _m = type(sys)(_pkg)
    _m.__path__ = [_path]
    sys.modules[_pkg] = _m

# 3) Mock all heavy fastdeploy submodules
# Entries that are already present in sys.modules are left untouched.
for _mod_name in [
    "fastdeploy.config",
    "fastdeploy.distributed",
    "fastdeploy.distributed.communication",
    "fastdeploy.model_executor.forward_meta",
    "fastdeploy.model_executor.graph_optimization",
    "fastdeploy.model_executor.graph_optimization.decorator",
    "fastdeploy.model_executor.layers",
    "fastdeploy.model_executor.layers.activation",
    "fastdeploy.model_executor.layers.attention",
    "fastdeploy.model_executor.layers.attention.attention",
    "fastdeploy.model_executor.layers.embeddings",
    "fastdeploy.model_executor.layers.linear",
    "fastdeploy.model_executor.layers.lm_head",
    "fastdeploy.model_executor.layers.moe",
    "fastdeploy.model_executor.layers.moe.moe",
    "fastdeploy.model_executor.layers.normalization",
    "fastdeploy.model_executor.models.model_base",
    "fastdeploy.model_executor.ops",
    "fastdeploy.model_executor.ops.triton_ops",
    "fastdeploy.model_executor.ops.triton_ops.lightning_attn",
]:
    if _mod_name not in sys.modules:
        sys.modules[_mod_name] = MagicMock()


# 4) Real ModelRegistry so @register_model_class works
class _ModelCategory:
    """Stand-in exposing the one category constant the model module references."""

    TEXT_GENERATION = "text_generation"


class _ModelRegistry:
    """Tiny registry that records decorated model classes for the tests to inspect."""

    # Class-level dicts: shared by every user of this stub for the whole session.
    _arch_to_model_cls = {}
    _enhanced_models = {}

    @classmethod
    def register_model_class(cls, model_class=None, **kw):
        """Register a model class; usable bare or with keyword arguments.

        The class is stored under ``model_class.name()``. Keyword arguments
        (for example ``architecture=``) are accepted but not used for the key.
        """

        def _register(mc):
            cls._arch_to_model_cls[mc.name()] = mc
            return mc

        # Bare use passes the class directly; parametrised use returns the decorator.
        return _register(model_class) if model_class is not None else _register


# Base class stub whose name() classmethod returns "base".
_ModelForCasualLM = type("ModelForCasualLM", (), {"name": classmethod(lambda cls: "base")})

# Attach the real stubs to the mocked model_base module so that importing
# these names from it yields them instead of auto-generated MagicMock attributes.
_mb = sys.modules["fastdeploy.model_executor.models.model_base"]
_mb.ModelCategory = _ModelCategory
_mb.ModelRegistry = _ModelRegistry
_mb.ModelForCasualLM = _ModelForCasualLM

# support_graph_optimization → identity
sys.modules["fastdeploy.model_executor.graph_optimization.decorator"].support_graph_optimization = lambda cls: cls

# 5) Load minimax_m1.py via importlib
# The file path is relative — presumably the tests are expected to run from
# the repository root; TODO confirm.
_spec = importlib.util.spec_from_file_location(
    "fastdeploy.model_executor.models.minimax_m1",
    "fastdeploy/model_executor/models/minimax_m1.py",
    submodule_search_locations=[],
)
_mod = importlib.util.module_from_spec(_spec)
# Register before executing so the module can be found under its own name.
sys.modules[_spec.name] = _mod
_spec.loader.exec_module(_mod)

# Import symbols
MiniMaxM1DecoderLayer = _mod.MiniMaxM1DecoderLayer
MiniMaxM1LinearAttention = _mod.MiniMaxM1LinearAttention
MiniMaxM1Attention = _mod.MiniMaxM1Attention
MiniMaxM1MoE = _mod.MiniMaxM1MoE
MiniMaxM1MLP = _mod.MiniMaxM1MLP
MiniMaxM1ForCausalLM = _mod.MiniMaxM1ForCausalLM
MiniMaxM1PretrainedModel = _mod.MiniMaxM1PretrainedModel
MiniMaxM1Model = _mod.MiniMaxM1Model
ModelRegistry = _ModelRegistry


# ===================================================================
# 1.
# Pure-logic tests
# ===================================================================


class TestBuildAttnTypeList:
    """Layout of attention kinds produced by ``_build_attn_type_list`` (1 marks full attention)."""

    def test_80_layers_has_10_full_attention(self):
        layout = MiniMaxM1DecoderLayer._build_attn_type_list(80)
        assert len(layout) == 80
        # Every eighth layer, starting at index 7, is expected to be full attention.
        full_positions = [idx for idx, kind in enumerate(layout) if kind == 1]
        assert full_positions == list(range(7, 80, 8))

    def test_short_model_clips_indices(self):
        layout = MiniMaxM1DecoderLayer._build_attn_type_list(10)
        assert len(layout) == 10
        assert layout[7] == 1
        assert sum(layout) == 1

    def test_single_layer_all_linear(self):
        layout = MiniMaxM1DecoderLayer._build_attn_type_list(1)
        assert layout == [0]

    def test_all_linear_below_first_full_index(self):
        layout = MiniMaxM1DecoderLayer._build_attn_type_list(7)
        assert not any(kind != 0 for kind in layout)


class TestBuildSlopeTensor:
    """Shape and sign checks for the per-head decay slopes."""

    def test_power_of_two_heads(self):
        slope_values = MiniMaxM1LinearAttention._build_slope_tensor(8)
        assert slope_values.shape == [8, 1, 1]
        flat = slope_values.flatten().numpy()
        assert (flat > 0).all()

    def test_non_power_of_two_heads(self):
        slope_values = MiniMaxM1LinearAttention._build_slope_tensor(12)
        assert slope_values.shape == [12, 1, 1]
        flat = slope_values.flatten().numpy()
        assert (flat > 0).all()

    def test_64_heads_first_slope(self):
        slope_values = MiniMaxM1LinearAttention._build_slope_tensor(64)
        assert slope_values.shape == [64, 1, 1]
        # Same closed form as before, split into named steps.
        exponent = -(2 ** (-(math.log2(64) - 3)))
        expected_start = 2**exponent
        first_slope = slope_values.flatten().numpy()[0]
        np.testing.assert_allclose(first_slope, expected_start, rtol=1e-5)

    @pytest.mark.parametrize("n", [1, 2, 4, 8, 16, 32, 64])
    def test_slopes_all_positive(self, n):
        flat = MiniMaxM1LinearAttention._build_slope_tensor(n).flatten().numpy()
        assert (flat > 0).all()


# ===================================================================
# 2.
# Model registration
# ===================================================================


class TestModelRegistration:
    """The model class must be discoverable under its HuggingFace architecture name."""

    def test_architecture_registered(self):
        registered = ModelRegistry._arch_to_model_cls
        assert "MiniMaxText01ForCausalLM" in registered

    def test_registered_class(self):
        registered_cls = ModelRegistry._arch_to_model_cls["MiniMaxText01ForCausalLM"]
        assert registered_cls is MiniMaxM1ForCausalLM

    def test_name_method(self):
        reported = MiniMaxM1ForCausalLM.name()
        assert reported == "MiniMaxText01ForCausalLM"

    def test_pretrained_name(self):
        assert MiniMaxM1PretrainedModel.arch_name() == "MiniMaxText01ForCausalLM"
        assert MiniMaxM1PretrainedModel.name() == "MiniMaxText01ForCausalLM"


# ===================================================================
# 3. Layer construction (lightweight fd_config mock)
# ===================================================================


def _make_fd_config(num_layers=4, attn_type_list=None, num_local_experts=4):
    """Return a SimpleNamespace tree carrying the config fields set below.

    When ``attn_type_list`` is omitted, the pattern ``[0, 0, 0, 1]`` truncated
    to ``num_layers`` entries is used.
    """
    layer_kinds = [0, 0, 0, 1][:num_layers] if attn_type_list is None else attn_type_list
    model_fields = dict(
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=num_layers,
        num_attention_heads=8,
        num_key_value_heads=2,
        head_dim=32,
        vocab_size=1024,
        rms_norm_eps=1e-6,
        hidden_act="silu",
        num_local_experts=num_local_experts,
        num_experts_per_tok=2,
        norm_topk_prob=False,
        postnorm=False,
        attn_type_list=layer_kinds,
        layernorm_full_attention_alpha=3.556,
        layernorm_full_attention_beta=1.0,
        layernorm_mlp_alpha=3.556,
        layernorm_mlp_beta=1.0,
        pretrained_config=SimpleNamespace(prefix_name="model"),
    )
    model_cfg = SimpleNamespace(**model_fields)
    parallel_cfg = SimpleNamespace(tensor_parallel_size=1, tp_group=None)
    return SimpleNamespace(model_config=model_cfg, parallel_config=parallel_cfg)


class TestDecoderLayerConstruction:
    """Constructor-level checks: attention dispatch, DeepNorm defaults, FFN choice."""

    def test_linear_attention_layer(self):
        cfg = _make_fd_config()
        decoder = MiniMaxM1DecoderLayer(cfg, layer_id=0, prefix="model.layers.0")
        assert decoder.attention_type == 0
        attn = decoder.self_attn
        assert isinstance(attn, MiniMaxM1LinearAttention)
        assert hasattr(attn, "slope_rate")

    def test_full_attention_layer(self):
        cfg = _make_fd_config()
        decoder = MiniMaxM1DecoderLayer(cfg, layer_id=3, prefix="model.layers.3")
        assert decoder.attention_type == 1
        assert isinstance(decoder.self_attn, MiniMaxM1Attention)

    def test_deepnorm_defaults(self):
        cfg = _make_fd_config()
        decoder = MiniMaxM1DecoderLayer(cfg, layer_id=0, prefix="model.layers.0")
        assert decoder.layernorm_attention_alpha == 3.556
        assert decoder.layernorm_mlp_alpha == 3.556

    def test_moe_when_experts_gt_1(self):
        cfg = _make_fd_config(num_local_experts=4)
        decoder = MiniMaxM1DecoderLayer(cfg, layer_id=0, prefix="model.layers.0")
        assert isinstance(decoder.mlp, MiniMaxM1MoE)

    def test_dense_mlp_when_single_expert(self):
        cfg = _make_fd_config(num_local_experts=1)
        decoder = MiniMaxM1DecoderLayer(cfg, layer_id=0, prefix="model.layers.0")
        assert isinstance(decoder.mlp, MiniMaxM1MLP)

    def test_fallback_attn_type_when_no_config(self):
        cfg = _make_fd_config(num_layers=80)
        # Drop the explicit list so the layer has to derive its attention kind itself.
        del cfg.model_config.attn_type_list
        decoder = MiniMaxM1DecoderLayer(cfg, layer_id=7, prefix="model.layers.7")
        assert decoder.attention_type == 1


# ===================================================================
# 4.
# Forward-pass smoke tests
# ===================================================================


class TestDecoderLayerForward:
    """Smoke tests that run the decoder layer with all sub-layers mocked out."""

    @staticmethod
    def _patch_layer(layer, hidden_size):
        """Swap the norms, attention and FFN of ``layer`` for lightweight mocks."""

        def _norm_fn(x, residual_input=None, forward_meta=None):
            # Mimics a fused norm: passes x through and returns x + residual.
            base = paddle.zeros_like(x) if residual_input is None else residual_input
            return x, base + x

        # object.__setattr__ is kept so the assignment bypasses any custom
        # __setattr__ on the layer class.
        for norm_name in ("input_layernorm", "post_attention_layernorm"):
            object.__setattr__(layer, norm_name, MagicMock(side_effect=_norm_fn))
        # Order matters for the RNG stream: self_attn's tensor is drawn first.
        for sub_name in ("self_attn", "mlp"):
            object.__setattr__(layer, sub_name, MagicMock(return_value=paddle.randn([4, hidden_size])))

    def test_linear_layer_returns_tuple(self):
        cfg = _make_fd_config()
        decoder = MiniMaxM1DecoderLayer(cfg, layer_id=0, prefix="model.layers.0")
        self._patch_layer(decoder, 256)

        result = decoder(forward_meta=SimpleNamespace(), hidden_states=paddle.randn([4, 256]))
        assert isinstance(result, tuple) and len(result) == 2
        assert result[0].shape == [4, 256]

    def test_full_attn_layer_returns_tuple(self):
        cfg = _make_fd_config()
        decoder = MiniMaxM1DecoderLayer(cfg, layer_id=3, prefix="model.layers.3")
        self._patch_layer(decoder, 256)

        result = decoder(forward_meta=SimpleNamespace(), hidden_states=paddle.randn([4, 256]))
        assert isinstance(result, tuple) and len(result) == 2

    def test_deepnorm_scaling_applied(self):
        cfg = _make_fd_config()
        model_cfg = cfg.model_config
        model_cfg.layernorm_full_attention_alpha = 2.0
        model_cfg.layernorm_mlp_alpha = 3.0
        decoder = MiniMaxM1DecoderLayer(cfg, layer_id=0, prefix="model.layers.0")
        assert decoder.layernorm_attention_alpha == 2.0
        assert decoder.layernorm_mlp_alpha == 3.0


# ---------------------------------------------------------------------------
# Patch framing that shared the last physical line of this span (mail header of
# the next commit in the series), kept verbatim:
#
# From d10614727767069f56b9dbf759b07291f18bce0f Mon Sep 17 00:00:00 2001
# From: cloudforge1
# Date: Tue, 24 Mar 2026 19:32:55 +0800
# Subject: [PATCH 2/3] fix: align HF weight keys, add output_gate/norm to linear attn, implement load_weights
# MIME-Version: 1.0
# Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit - LinearAttention: add output_gate (sigmoid gating), norm (RMSNorm), rename o_proj → out_proj. Forward: SiLU on QKV → lightning_attn → norm → gate → out_proj - DecoderLayer: rename self.mlp → self.block_sparse_moe to match HF config - DeepNorm: branch alpha/beta on attention_type (linear vs full) - Postnorm: add two code paths following vLLM reference - KV state: persist _kv_history across forward calls - Dual registration: MiniMaxM1ForCausalLM + MiniMaxText01ForCausalLM - set_state_dict: preprocess HF keys (w1→gate_proj, w3→up_proj, w2→down_proj, q/k/v→qkv_proj concatenation) - load_weights: v1 loader with stacked_params_mapping + expert_params_mapping - Tests: 29/29 passing --- .../model_executor/models/minimax_m1.py | 284 ++++++++++++++---- tests/model_executor/test_minimax_m1.py | 39 ++- 2 files changed, 261 insertions(+), 62 deletions(-) diff --git a/fastdeploy/model_executor/models/minimax_m1.py b/fastdeploy/model_executor/models/minimax_m1.py index 8b558d56c52..bad38c304f4 100644 --- a/fastdeploy/model_executor/models/minimax_m1.py +++ b/fastdeploy/model_executor/models/minimax_m1.py @@ -21,8 +21,10 @@ from __future__ import annotations import math -from typing import Dict +import re +from typing import Dict, Union +import numpy as np import paddle from paddle import nn from paddleformers.transformers import PretrainedModel @@ -223,26 +225,44 @@ def __init__( self.hidden_size = fd_config.model_config.hidden_size self.head_dim = fd_config.model_config.head_dim self.num_attention_heads = fd_config.model_config.num_attention_heads + hidden_inner = self.num_attention_heads * self.head_dim # QKV projection self.qkv_proj = ColumnParallelLinear( fd_config=fd_config, prefix=f"{prefix}.qkv_proj", input_size=self.hidden_size, - output_size=self.num_attention_heads * self.head_dim * 3, + output_size=hidden_inner * 3, with_bias=False, ) - # Output projection - self.o_proj = RowParallelLinear( + # Output gate (sigmoid gating on attention 
output) + self.output_gate = ColumnParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.output_gate", + input_size=self.hidden_size, + output_size=hidden_inner, + with_bias=False, + ) + + # Output projection (HF name: out_proj) + self.out_proj = RowParallelLinear( fd_config, - prefix=f"{prefix}.o_proj", - input_size=self.num_attention_heads * self.head_dim, + prefix=f"{prefix}.out_proj", + input_size=hidden_inner, output_size=self.hidden_size, with_bias=False, layer_id=layer_id, ) + # RMSNorm on attention output before gating + self.norm = RMSNorm( + fd_config, + hidden_size=hidden_inner, + eps=1e-5, + prefix=f"{prefix}.norm", + ) + # Build slope tensor for exponential decay slope_tensor = self._build_slope_tensor(self.num_attention_heads) if fd_config.model_config.num_hidden_layers <= 1: @@ -257,7 +277,9 @@ def __init__( def load_state_dict(self, state_dict): self.qkv_proj.load_state_dict(state_dict) - self.o_proj.load_state_dict(state_dict) + self.output_gate.load_state_dict(state_dict) + self.out_proj.load_state_dict(state_dict) + self.norm.load_state_dict(state_dict) @staticmethod def _build_slope_tensor(n_heads: int): @@ -281,17 +303,16 @@ def forward( forward_meta: ForwardMeta, hidden_states: paddle.Tensor, ): - """Linear attention forward.""" + """Linear attention forward with output gating.""" # Project QKV qkv = self.qkv_proj(hidden_states) - q, k, v = qkv.split( - [ - self.num_attention_heads * self.head_dim, - self.num_attention_heads * self.head_dim, - self.num_attention_heads * self.head_dim, - ], - axis=-1, - ) + hidden_inner = self.num_attention_heads * self.head_dim + q, k, v = qkv.split([hidden_inner, hidden_inner, hidden_inner], axis=-1) + + # Apply SiLU activation (matches HF MiniMax convention) + q = paddle.nn.functional.silu(q.astype("float32")) + k = paddle.nn.functional.silu(k.astype("float32")) + v = paddle.nn.functional.silu(v.astype("float32")) # Reshape for lightning attention batch_size = q.shape[0] @@ -304,23 +325,29 @@ def 
forward( k = k.transpose([0, 2, 1, 3]) v = v.transpose([0, 2, 1, 3]) - # Initialize KV history if needed - kv_history = paddle.zeros( - [batch_size, self.num_attention_heads, self.head_dim, self.head_dim], - dtype=q.dtype, - ) + # Retrieve or initialize KV history for recurrent state persistence + if not hasattr(self, "_kv_history") or self._kv_history is None: + self._kv_history = paddle.zeros( + [batch_size, self.num_attention_heads, self.head_dim, self.head_dim], + dtype=q.dtype, + ) # Apply lightning attention - attn_output, _ = lightning_attention( - q, k, v, self.slope_rate.squeeze(-1), block_size=256, kv_history=kv_history + attn_output, new_kv_history = lightning_attention( + q, k, v, self.slope_rate.squeeze(-1), block_size=256, kv_history=self._kv_history ) + # Update persisted KV state for next token generation + self._kv_history = new_kv_history - # Reshape back + # Reshape back to [batch, seq, hidden_inner] attn_output = attn_output.transpose([0, 2, 1, 3]) attn_output = attn_output.reshape([batch_size, -1, self.num_attention_heads * self.head_dim]) - # Output projection - output = self.o_proj(attn_output) + # Norm → gate → output projection (matches vLLM/HF forward) + attn_output = self.norm(attn_output)[0] + gate = self.output_gate(hidden_states) + attn_output = paddle.nn.functional.sigmoid(gate) * attn_output.astype(hidden_states.dtype) + output = self.out_proj(attn_output) return output @@ -391,23 +418,29 @@ def __init__( prefix=f"{prefix}.post_attention_layernorm", ) - # DeepNorm alpha/beta scaling - self.layernorm_attention_alpha = getattr( - fd_config.model_config, "layernorm_full_attention_alpha", 3.5565588200778455 - ) - self.layernorm_attention_beta = getattr(fd_config.model_config, "layernorm_full_attention_beta", 1.0) + # DeepNorm alpha/beta scaling — separate coefficients for linear vs full attention + if self.attention_type == 0: # Linear attention + self.layernorm_attention_alpha = getattr( + fd_config.model_config, 
"layernorm_linear_attention_alpha", 3.5565588200778455 + ) + self.layernorm_attention_beta = getattr(fd_config.model_config, "layernorm_linear_attention_beta", 1.0) + else: # Full attention + self.layernorm_attention_alpha = getattr( + fd_config.model_config, "layernorm_full_attention_alpha", 3.5565588200778455 + ) + self.layernorm_attention_beta = getattr(fd_config.model_config, "layernorm_full_attention_beta", 1.0) self.layernorm_mlp_alpha = getattr(fd_config.model_config, "layernorm_mlp_alpha", 3.5565588200778455) self.layernorm_mlp_beta = getattr(fd_config.model_config, "layernorm_mlp_beta", 1.0) # FFN (MLP or MoE) if fd_config.model_config.num_local_experts > 1: - self.mlp = MiniMaxM1MoE( + self.block_sparse_moe = MiniMaxM1MoE( fd_config, layer_id=layer_id, - prefix=f"{prefix}.mlp", + prefix=f"{prefix}.block_sparse_moe", ) else: - self.mlp = MiniMaxM1MLP( + self.block_sparse_moe = MiniMaxM1MLP( fd_config, intermediate_size=fd_config.model_config.intermediate_size, prefix=f"{prefix}.mlp", @@ -416,7 +449,7 @@ def __init__( def load_state_dict(self, state_dict): self.self_attn.load_state_dict(state_dict) - self.mlp.load_state_dict(state_dict) + self.block_sparse_moe.load_state_dict(state_dict) self.input_layernorm.load_state_dict(state_dict) self.post_attention_layernorm.load_state_dict(state_dict) @@ -426,31 +459,49 @@ def forward( hidden_states: paddle.Tensor, residual: paddle.Tensor = None, ): - """Decoder layer forward with DeepNorm.""" - # Pre-norm + """Decoder layer forward with DeepNorm. + + When postnorm=True (MiniMax-M1 default), the residual stream carries the + *normed* activations rather than the pre-norm sum. This follows the + vLLM reference: ``residual = layernorm_output if postnorm else layernorm_input``. 
+ """ + # Input layernorm (fused: x + residual → norm) hidden_states, residual = self.input_layernorm( hidden_states, residual_input=residual, forward_meta=forward_meta, ) + # hidden_states = norm(input + prev_residual) + # residual = input + prev_residual (pre-norm) + if self.postnorm: + residual = hidden_states # postnorm: residual = normed output # Attention (dispatch based on type) - if self.attention_type == 1: # Full attention - attn_output = self.self_attn(forward_meta=forward_meta, hidden_states=hidden_states) - else: # Linear attention - attn_output = self.self_attn(forward_meta=forward_meta, hidden_states=hidden_states) + attn_output = self.self_attn(forward_meta=forward_meta, hidden_states=hidden_states) # DeepNorm alpha/beta scaling residual = residual * self.layernorm_attention_alpha attn_output = attn_output * self.layernorm_attention_beta - # Post-attention - hidden_states, residual = self.post_attention_layernorm(attn_output, residual) + # Post-attention layernorm + if self.postnorm: + layernorm_input = residual + attn_output + hidden_states, residual = self.post_attention_layernorm( + layernorm_input, + forward_meta=forward_meta, + ) + residual = hidden_states # postnorm: residual = normed output + else: + hidden_states, residual = self.post_attention_layernorm( + attn_output, + residual_input=residual, + forward_meta=forward_meta, + ) # FFN - mlp_output = self.mlp(hidden_states, forward_meta) + mlp_output = self.block_sparse_moe(hidden_states, forward_meta) - # DeepNorm MLPalpha/beta + # DeepNorm MLP alpha/beta residual = residual * self.layernorm_mlp_alpha mlp_output = mlp_output * self.layernorm_mlp_beta @@ -530,6 +581,12 @@ def forward( return hidden_states +@ModelRegistry.register_model_class( + architecture="MiniMaxM1ForCausalLM", + module_name="minimax_m1", + category=ModelCategory.TEXT_GENERATION, + primary_use=ModelCategory.TEXT_GENERATION, +) @ModelRegistry.register_model_class( architecture="MiniMaxText01ForCausalLM", 
module_name="minimax_m1", @@ -539,6 +596,16 @@ def forward( class MiniMaxM1ForCausalLM(ModelForCasualLM): """MiniMax-M1 Causal LM Model""" + # Mapping HF checkpoint names → FD merged parameter names. + # For full attention layers: separate q/k/v → merged qkv_proj + # For MoE: gate_proj/up_proj → merged gate_up_proj (dense MLP fallback) + _STACKED_PARAMS_MAPPING = [ + # (fd_param_name, hf_weight_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + def __init__(self, fd_config: FDConfig): super().__init__(fd_config) @@ -553,13 +620,126 @@ def __init__(self, fd_config: FDConfig): @classmethod def name(cls): """Model name.""" - return "MiniMaxText01ForCausalLM" + return "MiniMaxM1ForCausalLM" @paddle.no_grad() - def set_state_dict(self, state_dict: Dict): - """Load model parameters.""" - self.model.load_state_dict(state_dict) - self.lm_head.load_state_dict(state_dict) + def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]): + """Load model parameters (v0 loader path). + + Pre-processes HF weight keys to match FD naming conventions, then + delegates to sub-layer ``load_state_dict`` calls. + """ + renamed: Dict[str, Union[np.ndarray, paddle.Tensor]] = {} + # Collect full-attention q/k/v weights for merging into qkv_proj + qkv_buffers: Dict[str, Dict[str, Union[np.ndarray, paddle.Tensor]]] = {} + + for name, weight in list(state_dict.items()): + # Expert weights: w1→gate_proj, w3→up_proj, w2→down_proj + if "block_sparse_moe.experts." in name: + name = re.sub(r"\.w1\.weight$", ".gate_proj.weight", name) + name = re.sub(r"\.w3\.weight$", ".up_proj.weight", name) + name = re.sub(r"\.w2\.weight$", ".down_proj.weight", name) + renamed[name] = weight + # Full attention: merge separate q/k/v into qkv_proj + elif ".self_attn.q_proj." in name or ".self_attn.k_proj." in name or ".self_attn.v_proj." in name: + # Extract layer prefix: e.g. 
"model.layers.7.self_attn" + prefix_match = re.match(r"(.*\.self_attn)\.(q|k|v)_proj\.weight$", name) + if prefix_match: + attn_prefix = prefix_match.group(1) + proj_type = prefix_match.group(2) + if attn_prefix not in qkv_buffers: + qkv_buffers[attn_prefix] = {} + qkv_buffers[attn_prefix][proj_type] = weight + else: + renamed[name] = weight + else: + renamed[name] = weight + + # Merge q/k/v into qkv_proj for full attention layers + for attn_prefix, projections in qkv_buffers.items(): + if "q" in projections and "k" in projections and "v" in projections: + q_w = projections["q"] + k_w = projections["k"] + v_w = projections["v"] + if isinstance(q_w, np.ndarray): + merged = np.concatenate([q_w, k_w, v_w], axis=0) + else: + merged = paddle.concat([q_w, k_w, v_w], axis=0) + renamed[f"{attn_prefix}.qkv_proj.weight"] = merged + + self.model.load_state_dict(renamed) + self.lm_head.load_state_dict(renamed) + + @paddle.no_grad() + def load_weights(self, weights_iterator) -> None: + """Load model parameters from a weights iterator (v1 loader path). 
+ + Handles HF→FD name mapping for: + - Full attention: q_proj/k_proj/v_proj → qkv_proj (stacked) + - MoE experts: w1/w3 → up_gate_proj, w2 → down_proj + """ + from fastdeploy.model_executor.utils import ( + default_weight_loader, + process_weights_after_loading, + ) + + stacked_params_mapping = list(self._STACKED_PARAMS_MAPPING) + + # Expert weight mapping: HF w1/w2/w3 → FD up_gate_proj/down_proj + n_experts = getattr(self.fd_config.model_config, "num_local_experts", 1) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + num_experts=n_experts, + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + param_gate_up_proj_name="experts.up_gate_proj_", + param_down_proj_name="experts.down_proj_", + ) + + params_dict = dict(self.named_parameters()) + process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config) + + for loaded_weight_name, loaded_weight in weights_iterator: + logger.debug(f"Loading weight: {loaded_weight_name}") + + # Stacked params (q/k/v → qkv_proj) + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in loaded_weight_name: + continue + # Skip expert weights — handled separately + if "block_sparse_moe.experts." 
in loaded_weight_name: + continue + model_param_name = loaded_weight_name.replace(weight_name, param_name) + if model_param_name not in params_dict: + continue + param = params_dict[model_param_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + weight_loader(param, loaded_weight, shard_id) + break + else: + # Expert params (w1/w2/w3 → up_gate_proj/down_proj) + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in loaded_weight_name: + continue + model_param_name = loaded_weight_name.replace(weight_name, param_name) + if model_param_name not in params_dict: + continue + param = params_dict[model_param_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id) + break + else: + # Direct loading (norm, embed, lm_head, output_gate, out_proj, etc.) + model_param_name = loaded_weight_name + if model_param_name not in params_dict: + continue + param = params_dict[model_param_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + weight_loader(param, loaded_weight) + + model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) + process_weights_after_loading_fn(model_sublayer_name, param) def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None): """Compute logits.""" @@ -587,9 +767,9 @@ class MiniMaxM1PretrainedModel(PretrainedModel): @classmethod def arch_name(cls): """Architecture name.""" - return "MiniMaxText01ForCausalLM" + return "MiniMaxM1ForCausalLM" @classmethod def name(cls): """Model name.""" - return "MiniMaxText01ForCausalLM" + return "MiniMaxM1ForCausalLM" diff --git a/tests/model_executor/test_minimax_m1.py b/tests/model_executor/test_minimax_m1.py index 9333d7a7e00..5de24a8a65c 100644 --- a/tests/model_executor/test_minimax_m1.py +++ 
b/tests/model_executor/test_minimax_m1.py @@ -124,7 +124,8 @@ class _ModelRegistry: @classmethod def register_model_class(cls, model_class=None, **kw): def _register(mc): - cls._arch_to_model_cls[mc.name()] = mc + arch = kw.get("architecture", mc.name()) + cls._arch_to_model_cls[arch] = mc return mc return _register(model_class) if model_class is not None else _register @@ -219,18 +220,21 @@ def test_slopes_all_positive(self, n): class TestModelRegistration: - def test_architecture_registered(self): + def test_primary_architecture_registered(self): + assert "MiniMaxM1ForCausalLM" in ModelRegistry._arch_to_model_cls + + def test_alias_architecture_registered(self): assert "MiniMaxText01ForCausalLM" in ModelRegistry._arch_to_model_cls def test_registered_class(self): - assert ModelRegistry._arch_to_model_cls["MiniMaxText01ForCausalLM"] is MiniMaxM1ForCausalLM + assert ModelRegistry._arch_to_model_cls["MiniMaxM1ForCausalLM"] is MiniMaxM1ForCausalLM def test_name_method(self): - assert MiniMaxM1ForCausalLM.name() == "MiniMaxText01ForCausalLM" + assert MiniMaxM1ForCausalLM.name() == "MiniMaxM1ForCausalLM" def test_pretrained_name(self): - assert MiniMaxM1PretrainedModel.arch_name() == "MiniMaxText01ForCausalLM" - assert MiniMaxM1PretrainedModel.name() == "MiniMaxText01ForCausalLM" + assert MiniMaxM1PretrainedModel.arch_name() == "MiniMaxM1ForCausalLM" + assert MiniMaxM1PretrainedModel.name() == "MiniMaxM1ForCausalLM" # =================================================================== @@ -258,6 +262,8 @@ def _make_fd_config(num_layers=4, attn_type_list=None, num_local_experts=4): attn_type_list=attn_type_list, layernorm_full_attention_alpha=3.556, layernorm_full_attention_beta=1.0, + layernorm_linear_attention_alpha=3.556, + layernorm_linear_attention_beta=1.0, layernorm_mlp_alpha=3.556, layernorm_mlp_beta=1.0, pretrained_config=SimpleNamespace(prefix_name="model"), @@ -274,6 +280,9 @@ def test_linear_attention_layer(self): assert layer.attention_type == 0 assert 
isinstance(layer.self_attn, MiniMaxM1LinearAttention) assert hasattr(layer.self_attn, "slope_rate") + assert hasattr(layer.self_attn, "output_gate") + assert hasattr(layer.self_attn, "norm") + assert hasattr(layer.self_attn, "out_proj") def test_full_attention_layer(self): fd = _make_fd_config() @@ -290,12 +299,12 @@ def test_deepnorm_defaults(self): def test_moe_when_experts_gt_1(self): fd = _make_fd_config(num_local_experts=4) layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0") - assert isinstance(layer.mlp, MiniMaxM1MoE) + assert isinstance(layer.block_sparse_moe, MiniMaxM1MoE) def test_dense_mlp_when_single_expert(self): fd = _make_fd_config(num_local_experts=1) layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0") - assert isinstance(layer.mlp, MiniMaxM1MLP) + assert isinstance(layer.block_sparse_moe, MiniMaxM1MLP) def test_fallback_attn_type_when_no_config(self): fd = _make_fd_config(num_layers=80) @@ -321,7 +330,7 @@ def _norm_fn(x, residual_input=None, forward_meta=None): object.__setattr__(layer, "input_layernorm", MagicMock(side_effect=_norm_fn)) object.__setattr__(layer, "post_attention_layernorm", MagicMock(side_effect=_norm_fn)) object.__setattr__(layer, "self_attn", MagicMock(return_value=paddle.randn([4, hidden_size]))) - object.__setattr__(layer, "mlp", MagicMock(return_value=paddle.randn([4, hidden_size]))) + object.__setattr__(layer, "block_sparse_moe", MagicMock(return_value=paddle.randn([4, hidden_size]))) def test_linear_layer_returns_tuple(self): fd = _make_fd_config() @@ -342,8 +351,18 @@ def test_full_attn_layer_returns_tuple(self): def test_deepnorm_scaling_applied(self): fd = _make_fd_config() - fd.model_config.layernorm_full_attention_alpha = 2.0 + fd.model_config.layernorm_linear_attention_alpha = 2.0 fd.model_config.layernorm_mlp_alpha = 3.0 layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0") assert layer.layernorm_attention_alpha == 2.0 assert layer.layernorm_mlp_alpha == 3.0 + + 
def test_postnorm_forward(self): + fd = _make_fd_config() + fd.model_config.postnorm = True + layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0") + self._patch_layer(layer, 256) + + out = layer(forward_meta=SimpleNamespace(), hidden_states=paddle.randn([4, 256])) + assert isinstance(out, tuple) and len(out) == 2 + assert out[0].shape == [4, 256] From e068f019ad451351295b99a5cbedf4f22156e009 Mon Sep 17 00:00:00 2001 From: cloudforge1 Date: Wed, 25 Mar 2026 01:33:39 +0800 Subject: [PATCH 3/3] Add low-bit quantization support for MoE layer - Quantization-aware weight_key_map in MiniMaxM1MoE (w4a8, w4afp8 static/dynamic, tensor_wise_fp8, block_wise_fp8) mirroring Ernie4_5_MoE - Gate layer uses skip_quant=True, weight_dtype='float32' - set_state_dict v0 loader: quant-aware regex for expert weights (.quant_weight, .weight_scale, .activation_scale) - set_state_dict v0 loader: quant-aware qkv merge (suffix-keyed buffers) - 3 new tests: default/w4a8/w4afp8-dynamic weight_key_map branches --- .../model_executor/models/minimax_m1.py | 67 +++++++++++++++---- tests/model_executor/test_minimax_m1.py | 37 ++++++++++ 2 files changed, 90 insertions(+), 14 deletions(-) diff --git a/fastdeploy/model_executor/models/minimax_m1.py b/fastdeploy/model_executor/models/minimax_m1.py index bad38c304f4..f7148e0bfb0 100644 --- a/fastdeploy/model_executor/models/minimax_m1.py +++ b/fastdeploy/model_executor/models/minimax_m1.py @@ -104,7 +104,7 @@ def forward(self, x, forward_meta=None): class MiniMaxM1MoE(nn.Layer): - """MiniMax-M1 MoE Layer""" + """MiniMax-M1 MoE Layer with low-bit quantization support.""" def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None: super().__init__() @@ -112,10 +112,43 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None: self.tp_size = fd_config.parallel_config.tensor_parallel_size self.norm_topk_prob = getattr(fd_config.model_config, "norm_topk_prob", False) - weight_key_map = { - 
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", - "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", - } + # Build quantization-aware weight key map (mirrors Ernie4_5_MoE pattern) + moe_quant_type = "" + quant_config = getattr(fd_config, "quant_config", None) + if quant_config and hasattr(quant_config, "moe_quant_type"): + moe_quant_type = quant_config.moe_quant_type or "" + + is_quantized = getattr(fd_config.model_config, "is_quantized", False) + moe_dynamic_quant = getattr(quant_config, "moe_dynamic_quant", False) if quant_config else False + + if moe_quant_type in ("w4a8", "tensor_wise_fp8", "block_wise_fp8") or ( + moe_quant_type == "w4afp8" and is_quantized and not moe_dynamic_quant + ): + weight_key_map = { + "gate_weight_key": f"{prefix}.gate.weight", + "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight", + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight", + "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale", + "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale", + "up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale", + "down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale", + } + elif moe_quant_type == "w4afp8" and is_quantized: + # Dynamic w4afp8: no activation scales + weight_key_map = { + "gate_weight_key": f"{prefix}.gate.weight", + "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight", + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight", + "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale", + "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale", + } + else: + # Default: unquantized + weight_key_map = { + "gate_weight_key": f"{prefix}.gate.weight", + 
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", + } self.gate = ReplicatedLinear( fd_config=fd_config, @@ -635,29 +668,35 @@ def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]] for name, weight in list(state_dict.items()): # Expert weights: w1→gate_proj, w3→up_proj, w2→down_proj + # Handles both .weight (FP) and .quant_weight / .weight_scale / .activation_scale (quantized) if "block_sparse_moe.experts." in name: - name = re.sub(r"\.w1\.weight$", ".gate_proj.weight", name) - name = re.sub(r"\.w3\.weight$", ".up_proj.weight", name) - name = re.sub(r"\.w2\.weight$", ".down_proj.weight", name) + name = re.sub(r"\.w1\.", ".gate_proj.", name) + name = re.sub(r"\.w3\.", ".up_proj.", name) + name = re.sub(r"\.w2\.", ".down_proj.", name) renamed[name] = weight # Full attention: merge separate q/k/v into qkv_proj elif ".self_attn.q_proj." in name or ".self_attn.k_proj." in name or ".self_attn.v_proj." in name: # Extract layer prefix: e.g. 
"model.layers.7.self_attn" - prefix_match = re.match(r"(.*\.self_attn)\.(q|k|v)_proj\.weight$", name) + prefix_match = re.match( + r"(.*\.self_attn)\.(q|k|v)_proj\.(weight|quant_weight|weight_scale|activation_scale)$", name + ) if prefix_match: attn_prefix = prefix_match.group(1) proj_type = prefix_match.group(2) - if attn_prefix not in qkv_buffers: - qkv_buffers[attn_prefix] = {} - qkv_buffers[attn_prefix][proj_type] = weight + suffix = prefix_match.group(3) + buf_key = f"{attn_prefix}|{suffix}" + if buf_key not in qkv_buffers: + qkv_buffers[buf_key] = {} + qkv_buffers[buf_key][proj_type] = weight else: renamed[name] = weight else: renamed[name] = weight # Merge q/k/v into qkv_proj for full attention layers - for attn_prefix, projections in qkv_buffers.items(): + for buf_key, projections in qkv_buffers.items(): if "q" in projections and "k" in projections and "v" in projections: + attn_prefix, suffix = buf_key.split("|", 1) q_w = projections["q"] k_w = projections["k"] v_w = projections["v"] @@ -665,7 +704,7 @@ def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]] merged = np.concatenate([q_w, k_w, v_w], axis=0) else: merged = paddle.concat([q_w, k_w, v_w], axis=0) - renamed[f"{attn_prefix}.qkv_proj.weight"] = merged + renamed[f"{attn_prefix}.qkv_proj.{suffix}"] = merged self.model.load_state_dict(renamed) self.lm_head.load_state_dict(renamed) diff --git a/tests/model_executor/test_minimax_m1.py b/tests/model_executor/test_minimax_m1.py index 5de24a8a65c..c2da39d07d2 100644 --- a/tests/model_executor/test_minimax_m1.py +++ b/tests/model_executor/test_minimax_m1.py @@ -312,6 +312,43 @@ def test_fallback_attn_type_when_no_config(self): layer = MiniMaxM1DecoderLayer(fd, layer_id=7, prefix="model.layers.7") assert layer.attention_type == 1 + def test_moe_default_weight_key_map(self): + """Unquantized config → weight_key_map has plain .weight keys.""" + fd = _make_fd_config(num_local_experts=4) + FusedMoE = 
sys.modules["fastdeploy.model_executor.layers.moe.moe"].FusedMoE + FusedMoE.reset_mock() + MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe") + wkm = FusedMoE.call_args[1]["weight_key_map"] + assert "gate_weight_key" in wkm + assert wkm["up_gate_proj_expert_weight_key"].endswith(".up_gate_proj.weight") + assert "weight_scale" not in str(wkm) + + def test_moe_w4a8_weight_key_map(self): + """w4a8 quant config → weight_key_map has .quant_weight + scales.""" + fd = _make_fd_config(num_local_experts=4) + fd.quant_config = SimpleNamespace(moe_quant_type="w4a8") + fd.model_config.is_quantized = True + FusedMoE = sys.modules["fastdeploy.model_executor.layers.moe.moe"].FusedMoE + FusedMoE.reset_mock() + MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe") + wkm = FusedMoE.call_args[1]["weight_key_map"] + assert "quant_weight" in wkm["up_gate_proj_expert_weight_key"] + assert "weight_scale" in wkm["up_gate_proj_expert_weight_scale_key"] + assert "activation_scale" in wkm["up_gate_proj_expert_in_scale_key"] + + def test_moe_w4afp8_dynamic_weight_key_map(self): + """Dynamic w4afp8 → quant_weight + weight_scale but no activation_scale.""" + fd = _make_fd_config(num_local_experts=4) + fd.quant_config = SimpleNamespace(moe_quant_type="w4afp8", moe_dynamic_quant=True) + fd.model_config.is_quantized = True + FusedMoE = sys.modules["fastdeploy.model_executor.layers.moe.moe"].FusedMoE + FusedMoE.reset_mock() + MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe") + wkm = FusedMoE.call_args[1]["weight_key_map"] + assert "quant_weight" in wkm["up_gate_proj_expert_weight_key"] + assert "weight_scale" in wkm["up_gate_proj_expert_weight_scale_key"] + assert "in_scale_key" not in str(wkm) + # =================================================================== # 4. Forward-pass smoke tests