diff --git a/docs/best_practices/MiniMax-M1.md b/docs/best_practices/MiniMax-M1.md
new file mode 100644
index 00000000000..0a1c334cf83
--- /dev/null
+++ b/docs/best_practices/MiniMax-M1.md
@@ -0,0 +1,45 @@
+[简体中文](../zh/best_practices/MiniMax-M1.md)
+
+# MiniMax-M1 Model
+
+## I. Environment Preparation
+
+### 1.1 Support Requirements
+
+MiniMax-M1 support in FastDeploy uses a hybrid decoder stack:
+
+- Standard full-attention layers run through the existing FastDeploy attention backend.
+- Linear-attention layers use the Lightning Attention Triton kernels in `fastdeploy/model_executor/ops/triton_ops/lightning_attn.py`.
+- Current first-pass support targets BF16 inference.
+
+### 1.2 Installing FastDeploy
+
+For the installation process, see [FastDeploy GPU Installation](../get_started/installation/nvidia_gpu.md).
+
+## II. How to Use
+
+### 2.1 Basics: Starting the Service
+
+```shell
+MODEL_PATH=/models/MiniMax-Text-01
+
+python -m fastdeploy.entrypoints.openai.api_server \
+ --model "$MODEL_PATH" \
+ --port 8180 \
+ --metrics-port 8181 \
+ --engine-worker-queue-port 8182 \
+ --max-model-len 32768 \
+ --max-num-seqs 32
+```
+
+### 2.2 Model Notes
+
+- HuggingFace architecture: `MiniMaxText01ForCausalLM`
+- Hybrid layer layout: 70 linear-attention layers and 10 full-attention layers
+- MoE routing: 32 experts, top-2 experts per token
+
+## III. Known Limitations
+
+- This initial integration is focused on model structure and backend wiring.
+- Low-bit quantization support still requires follow-up validation against MiniMax-M1 weights.
+- Production validation should include GPU runtime checks for Lightning Attention decode/prefill paths.
diff --git a/docs/supported_models.md b/docs/supported_models.md
index b0684affc11..fcf4b651f8c 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -38,6 +38,7 @@ These models accept text input.
|⭐QWEN2.5|BF16/WINT8/FP8|Qwen/qwen2.5-72B;
Qwen/qwen2.5-32B;
Qwen/qwen2.5-14B;
Qwen/qwen2.5-7B;
Qwen/qwen2.5-3B;
Qwen/qwen2.5-1.5B;
Qwen/qwen2.5-0.5B, etc.|
|⭐QWEN2|BF16/WINT8/FP8|Qwen/Qwen/qwen2-72B;
Qwen/Qwen/qwen2-7B;
Qwen/qwen2-1.5B;
Qwen/qwen2-0.5B;
Qwen/QwQ-32, etc.|
|⭐DEEPSEEK|BF16/WINT4|unsloth/DeepSeek-V3.1-BF16;
unsloth/DeepSeek-V3-0324-BF16;
unsloth/DeepSeek-R1-BF16, etc.|
+|MINIMAX-M1|BF16|[MiniMaxAI/MiniMax-Text-01](./best_practices/MiniMax-M1.md);
MiniMaxAI/MiniMax-Text-01-Large, etc.|
|⭐GPT-OSS|BF16/WINT8|unsloth/gpt-oss-20b-BF16, etc.|
|⭐GLM-4.5/4.6|BF16/wfp8afp8|zai-org/GLM-4.5-Air;
zai-org/GLM-4.6
[最佳实践](./best_practices/GLM-4-MoE-Text.md) etc.|
diff --git a/docs/zh/best_practices/MiniMax-M1.md b/docs/zh/best_practices/MiniMax-M1.md
new file mode 100644
index 00000000000..4ba5f695a83
--- /dev/null
+++ b/docs/zh/best_practices/MiniMax-M1.md
@@ -0,0 +1,45 @@
+[English](../../best_practices/MiniMax-M1.md)
+
+# MiniMax-M1 模型
+
+## 一、环境准备
+
+### 1.1 支持说明
+
+FastDeploy 中的 MiniMax-M1 采用混合解码器结构:
+
+- 全注意力层复用 FastDeploy 现有 Attention 后端。
+- 线性注意力层使用 `fastdeploy/model_executor/ops/triton_ops/lightning_attn.py` 中的 Lightning Attention Triton kernel。
+- 当前首版支持以 BF16 推理为主。
+
+### 1.2 安装 FastDeploy
+
+安装流程可参考 [FastDeploy GPU 安装文档](../get_started/installation/nvidia_gpu.md)
+
+## 二、使用方式
+
+### 2.1 基础启动命令
+
+```shell
+MODEL_PATH=/models/MiniMax-Text-01
+
+python -m fastdeploy.entrypoints.openai.api_server \
+ --model "$MODEL_PATH" \
+ --port 8180 \
+ --metrics-port 8181 \
+ --engine-worker-queue-port 8182 \
+ --max-model-len 32768 \
+ --max-num-seqs 32
+```
+
+### 2.2 模型特性
+
+- HuggingFace 架构名:`MiniMaxText01ForCausalLM`
+- 层类型分布:70 层线性注意力 + 10 层全注意力
+- MoE 路由:32 个专家,每个 token 选择 top-2 专家
+
+## 三、当前限制
+
+- 当前版本优先完成模型组网与后端接线。
+- 各类低比特量化推理能力还需要结合真实权重进一步验证。
+- Lightning Attention 的 prefill/decode 路径仍需在 GPU 环境完成端到端验证。
diff --git a/docs/zh/supported_models.md b/docs/zh/supported_models.md
index 1424d2320fb..71cbfa09150 100644
--- a/docs/zh/supported_models.md
+++ b/docs/zh/supported_models.md
@@ -36,6 +36,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
|⭐QWEN2.5|BF16/WINT8/FP8|Qwen/qwen2.5-72B;
Qwen/qwen2.5-32B;
Qwen/qwen2.5-14B;
Qwen/qwen2.5-7B;
Qwen/qwen2.5-3B;
Qwen/qwen2.5-1.5B;
Qwen/qwen2.5-0.5B, etc.|
|⭐QWEN2|BF16/WINT8/FP8|Qwen/Qwen/qwen2-72B;
Qwen/Qwen/qwen2-7B;
Qwen/qwen2-1.5B;
Qwen/qwen2-0.5B;
Qwen/QwQ-32, etc.|
|⭐DEEPSEEK|BF16/WINT4|unsloth/DeepSeek-V3.1-BF16;
unsloth/DeepSeek-V3-0324-BF16;
unsloth/DeepSeek-R1-BF16, etc.|
+|MINIMAX-M1|BF16|[MiniMaxAI/MiniMax-Text-01](./best_practices/MiniMax-M1.md);
MiniMaxAI/MiniMax-Text-01-Large, etc.|
|⭐GPT-OSS|BF16/WINT8|unsloth/gpt-oss-20b-BF16, etc.|
|⭐GLM-4.5/4.6|BF16/wfp8afp8|zai-org/GLM-4.5-Air;
zai-org/GLM-4.6
[最佳实践](./best_practices/GLM-4-MoE-Text.md) etc.|
diff --git a/fastdeploy/model_executor/models/minimax_m1.py b/fastdeploy/model_executor/models/minimax_m1.py
new file mode 100644
index 00000000000..f7148e0bfb0
--- /dev/null
+++ b/fastdeploy/model_executor/models/minimax_m1.py
@@ -0,0 +1,814 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+MiniMax-M1 Model for FastDeploy
+Hybrid architecture: 70 linear attention layers + 10 full attention layers
+MoE: 32 experts, top-2 routing per token
+"""
+
+from __future__ import annotations
+
+import math
+import re
+from typing import Dict, Union
+
+import numpy as np
+import paddle
+from paddle import nn
+from paddleformers.transformers import PretrainedModel
+from paddleformers.utils.log import logger
+
+from fastdeploy.config import FDConfig
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
+from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+ support_graph_optimization,
+)
+from fastdeploy.model_executor.layers.activation import SiluAndMul
+from fastdeploy.model_executor.layers.attention.attention import Attention
+from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
+from fastdeploy.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
+from fastdeploy.model_executor.layers.moe.moe import FusedMoE
+from fastdeploy.model_executor.layers.normalization import RMSNorm
+from fastdeploy.model_executor.models.model_base import (
+ ModelCategory,
+ ModelForCasualLM,
+ ModelRegistry,
+)
+from fastdeploy.model_executor.ops.triton_ops.lightning_attn import lightning_attention
+
+
class MiniMaxM1MLP(nn.Layer):
    """Dense feed-forward block for MiniMax-M1.

    Projects hidden states up to ``intermediate_size * 2`` (gate and up
    halves fused), applies the SiLU-and-multiply activation, then projects
    back down to the model hidden size.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        intermediate_size: int,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()

        hidden_size = fd_config.model_config.hidden_size
        act_method = fd_config.model_config.hidden_act

        # Gate and up projections fused into one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.gate_up_proj",
            input_size=hidden_size,
            output_size=intermediate_size * 2,
            with_bias=False,
            activation=act_method,
        )

        # Down projection back to hidden size; optionally all-reduces across TP ranks.
        self.down_proj = RowParallelLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.down_proj",
            input_size=intermediate_size,
            output_size=hidden_size,
            with_bias=False,
            reduce_results=reduce_results,
        )

        # Fused SiLU(gate) * up activation.
        self.act_fn = SiluAndMul(
            fd_config=fd_config,
            bias=getattr(self.gate_up_proj, "bias", None),
            act_method=act_method,
        )

    def load_state_dict(self, state_dict):
        """Load weights for both projections from *state_dict*."""
        for sublayer in (self.gate_up_proj, self.down_proj):
            sublayer.load_state_dict(state_dict)

    def forward(self, x, forward_meta=None):
        """Apply gate/up projection, activation, then down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
+
+
class MiniMaxM1MoE(nn.Layer):
    """MiniMax-M1 sparse MoE block: a replicated router gate plus fused experts."""

    @staticmethod
    def _build_weight_key_map(fd_config: FDConfig, prefix: str):
        """Return the checkpoint key map for expert weights, per quantization mode.

        Mirrors the Ernie4_5_MoE pattern: quantized checkpoints use
        ``quant_weight`` / ``weight_scale`` keys (plus ``activation_scale``
        for static quantization); unquantized checkpoints use plain ``weight``.
        """
        quant_config = getattr(fd_config, "quant_config", None)
        moe_quant_type = ""
        if quant_config and hasattr(quant_config, "moe_quant_type"):
            moe_quant_type = quant_config.moe_quant_type or ""

        is_quantized = getattr(fd_config.model_config, "is_quantized", False)
        moe_dynamic_quant = getattr(quant_config, "moe_dynamic_quant", False) if quant_config else False

        # Static quantization: quantized weights + weight scales + activation scales.
        static_quant = moe_quant_type in ("w4a8", "tensor_wise_fp8", "block_wise_fp8") or (
            moe_quant_type == "w4afp8" and is_quantized and not moe_dynamic_quant
        )
        # Dynamic w4afp8: quantized weights + weight scales, no activation scales.
        dynamic_w4afp8 = moe_quant_type == "w4afp8" and is_quantized

        key_map = {"gate_weight_key": f"{prefix}.gate.weight"}
        if static_quant or dynamic_w4afp8:
            key_map.update(
                {
                    "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
                    "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
                    "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
                    "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
                }
            )
            if static_quant:
                key_map.update(
                    {
                        "up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
                        "down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
                    }
                )
        else:
            # Default: unquantized checkpoint layout.
            key_map.update(
                {
                    "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
                    "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
                }
            )
        return key_map

    def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
        super().__init__()

        self.tp_size = fd_config.parallel_config.tensor_parallel_size
        self.norm_topk_prob = getattr(fd_config.model_config, "norm_topk_prob", False)

        weight_key_map = self._build_weight_key_map(fd_config, prefix)

        # Router gate is replicated (float32, no quantization) across ranks.
        self.gate = ReplicatedLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.gate",
            input_size=fd_config.model_config.hidden_size,
            output_size=fd_config.model_config.num_local_experts,
            with_bias=False,
            skip_quant=True,
            weight_dtype="float32",
        )

        self.experts = FusedMoE(
            fd_config=fd_config,
            reduce_results=True,
            renormalize=self.norm_topk_prob,
            moe_intermediate_size=fd_config.model_config.intermediate_size,
            num_experts=fd_config.model_config.num_local_experts,
            top_k=fd_config.model_config.num_experts_per_tok,
            layer_idx=layer_id,
            weight_key_map=weight_key_map,
        )

    def load_state_dict(self, state_dict):
        """Load router gate and expert weights."""
        for sublayer in (self.gate, self.experts):
            sublayer.load_state_dict(state_dict)

    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
        """Route tokens through the fused experts; all-reduce across TP ranks if needed."""
        out = self.experts(hidden_states, self.gate, forward_meta)
        if self.tp_size > 1:
            out = tensor_model_parallel_all_reduce(out)
        return out
+
+
class MiniMaxM1Attention(nn.Layer):
    """Standard full (GQA) attention layer for the MiniMax-M1 hybrid stack."""

    def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None:
        super().__init__()

        model_config = fd_config.model_config
        self.hidden_size = model_config.hidden_size
        self.num_attention_heads = model_config.num_attention_heads
        self.head_dim = model_config.head_dim
        self.num_key_value_heads = model_config.num_key_value_heads

        # Fused Q/K/V projection: query heads plus two sets of kv heads.
        qkv_output_size = (self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim
        self.qkv_proj = ColumnParallelLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.qkv_proj",
            input_size=self.hidden_size,
            output_size=qkv_output_size,
            with_bias=False,
        )

        self.o_proj = RowParallelLinear(
            fd_config,
            prefix=f"{prefix}.o_proj",
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            with_bias=False,
            layer_id=layer_id,
        )

        # Backend attention op (neox-style rotary embedding).
        self.attn = Attention(
            fd_config=fd_config,
            layer_id=layer_id,
            prefix=prefix,
            use_neox_rotary_style=True,
        )

    def load_state_dict(self, state_dict):
        """Load QKV/output projections and the attention backend state."""
        for sublayer in (self.qkv_proj, self.o_proj, self.attn):
            sublayer.load_state_dict(state_dict)

    def forward(
        self,
        forward_meta: ForwardMeta,
        hidden_states: paddle.Tensor,
    ):
        """Run full attention and project the result back to the hidden size."""
        query, key, value = self._compute_qkv(hidden_states)
        context = self.attn(query, key, value, forward_meta=forward_meta)
        return self.o_proj(context)

    def _compute_qkv(self, hidden_states):
        """Split the fused QKV projection into query, key, and value tensors."""
        fused = self.qkv_proj(hidden_states)
        q_width = self.num_attention_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        return fused.split([q_width, kv_width, kv_width], axis=-1)
+
+
class MiniMaxM1LinearAttention(nn.Layer):
    """MiniMax-M1 Linear Attention (Lightning Attention).

    Projects hidden states to Q/K/V, runs the Lightning Attention Triton
    kernel with per-head exponential-decay slopes, applies RMSNorm and a
    sigmoid output gate, then projects back to the hidden size.

    Args:
        fd_config: Global FastDeploy configuration.
        layer_id: Absolute decoder-layer index (used to scale decay slopes).
        linear_layer_id: Index among linear-attention layers only; kept for
            interface compatibility, not used internally.
        prefix: Parameter-name prefix for checkpoint loading.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        layer_id: int,
        linear_layer_id: int,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.hidden_size = fd_config.model_config.hidden_size
        self.head_dim = fd_config.model_config.head_dim
        self.num_attention_heads = fd_config.model_config.num_attention_heads
        hidden_inner = self.num_attention_heads * self.head_dim

        # QKV projection
        self.qkv_proj = ColumnParallelLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.qkv_proj",
            input_size=self.hidden_size,
            output_size=hidden_inner * 3,
            with_bias=False,
        )

        # Output gate (sigmoid gating on attention output)
        self.output_gate = ColumnParallelLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.output_gate",
            input_size=self.hidden_size,
            output_size=hidden_inner,
            with_bias=False,
        )

        # Output projection (HF name: out_proj)
        self.out_proj = RowParallelLinear(
            fd_config,
            prefix=f"{prefix}.out_proj",
            input_size=hidden_inner,
            output_size=self.hidden_size,
            with_bias=False,
            layer_id=layer_id,
        )

        # RMSNorm on attention output before gating
        self.norm = RMSNorm(
            fd_config,
            hidden_size=hidden_inner,
            eps=1e-5,
            prefix=f"{prefix}.norm",
        )

        # Per-head decay slopes, scaled down with layer depth (deeper layers
        # decay more slowly); the +1e-5 keeps the deepest layer's slope non-zero.
        slope_tensor = self._build_slope_tensor(self.num_attention_heads)
        if fd_config.model_config.num_hidden_layers <= 1:
            slope_tensor = slope_tensor * (1 + 1e-5)
        else:
            slope_tensor = slope_tensor * (1 - layer_id / (fd_config.model_config.num_hidden_layers - 1) + 1e-5)
        # Register as buffer (not trainable)
        self.register_buffer("slope_rate", slope_tensor)

        # Recurrent state shape per sequence: [heads, head_dim, head_dim]
        self.kv_cache_shape = (self.num_attention_heads, self.head_dim, self.head_dim)
        # Lazily (re)created in forward(); see the batch-size check there.
        self._kv_history = None

    def load_state_dict(self, state_dict):
        """Load all projection and norm weights for this layer."""
        self.qkv_proj.load_state_dict(state_dict)
        self.output_gate.load_state_dict(state_dict)
        self.out_proj.load_state_dict(state_dict)
        self.norm.load_state_dict(state_dict)

    @staticmethod
    def _build_slope_tensor(n_heads: int):
        """Build an ALiBi-style slope tensor of shape [n_heads, 1, 1].

        For non-power-of-two head counts, the standard ALiBi interpolation is
        used: slopes for the nearest lower power of two, topped up with every
        other slope from the doubled head count.
        """

        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** (-(math.log2(n) - 3))))
            return [start * (start**i) for i in range(n)]

        if math.log2(n_heads).is_integer():
            slopes = get_slopes_power_of_2(n_heads)
        else:
            closest_power = 2 ** math.floor(math.log2(n_heads))
            slopes = get_slopes_power_of_2(closest_power)
            slopes += get_slopes_power_of_2(2 * closest_power)[0::2][: n_heads - closest_power]

        return paddle.to_tensor(slopes, dtype=paddle.float32).reshape([n_heads, 1, 1])

    def forward(
        self,
        forward_meta: ForwardMeta,
        hidden_states: paddle.Tensor,
    ):
        """Linear attention forward with output gating.

        NOTE(review): hidden_states is assumed to be [batch, seq, hidden] here
        (the reshape below infers seq from -1) — confirm against the runner's
        packing of ids_remove_padding.
        """
        # Project QKV
        qkv = self.qkv_proj(hidden_states)
        hidden_inner = self.num_attention_heads * self.head_dim
        q, k, v = qkv.split([hidden_inner, hidden_inner, hidden_inner], axis=-1)

        # Apply SiLU activation (matches HF MiniMax convention)
        q = paddle.nn.functional.silu(q.astype("float32"))
        k = paddle.nn.functional.silu(k.astype("float32"))
        v = paddle.nn.functional.silu(v.astype("float32"))

        # Reshape for lightning attention
        batch_size = q.shape[0]
        q = q.reshape([batch_size, -1, self.num_attention_heads, self.head_dim])
        k = k.reshape([batch_size, -1, self.num_attention_heads, self.head_dim])
        v = v.reshape([batch_size, -1, self.num_attention_heads, self.head_dim])

        # Transpose to [batch, heads, seq_len, dim]
        q = q.transpose([0, 2, 1, 3])
        k = k.transpose([0, 2, 1, 3])
        v = v.transpose([0, 2, 1, 3])

        # (Re)initialize the recurrent KV state when first used or when the
        # batch size changes. Previously the state was created once with the
        # first batch size and never reset, so a later call with a different
        # batch size shape-mismatched and stale state leaked across batches.
        kv_history = self._kv_history
        if kv_history is None or kv_history.shape[0] != batch_size:
            kv_history = paddle.zeros(
                [batch_size, self.num_attention_heads, self.head_dim, self.head_dim],
                dtype=q.dtype,
            )

        # Apply lightning attention
        attn_output, new_kv_history = lightning_attention(
            q, k, v, self.slope_rate.squeeze(-1), block_size=256, kv_history=kv_history
        )
        # Update persisted KV state for next token generation
        self._kv_history = new_kv_history

        # Reshape back to [batch, seq, hidden_inner]
        attn_output = attn_output.transpose([0, 2, 1, 3])
        attn_output = attn_output.reshape([batch_size, -1, self.num_attention_heads * self.head_dim])

        # Norm → gate → output projection (matches vLLM/HF forward)
        attn_output = self.norm(attn_output)[0]
        gate = self.output_gate(hidden_states)
        attn_output = paddle.nn.functional.sigmoid(gate) * attn_output.astype(hidden_states.dtype)
        output = self.out_proj(attn_output)
        return output
+
+
class MiniMaxM1DecoderLayer(nn.Layer):
    """MiniMax-M1 Decoder Layer with Hybrid Attention Dispatch.

    Each layer hosts either linear (Lightning) attention or full attention,
    selected via ``attn_type_list``, followed by an MoE (or dense MLP) FFN.
    Residual connections use DeepNorm-style alpha/beta scaling, with separate
    coefficients for the attention and MLP sub-blocks.
    """

    @staticmethod
    def _build_attn_type_list(num_layers: int):
        """Build attention type list: 70 linear + 10 full (at indices 7,15,23,...).

        Returns a list of flags per layer: 0 = linear attention, 1 = full
        attention. Used as a fallback when the config has no ``attn_type_list``.
        """
        attn_type_list = [0] * num_layers  # Default: all linear
        # Full attention every 8 layers starting at layer 7
        full_attn_indices = [7, 15, 23, 31, 39, 47, 55, 63, 71, 79]
        for idx in full_attn_indices:
            if idx < num_layers:
                attn_type_list[idx] = 1
        return attn_type_list

    def __init__(
        self,
        fd_config: FDConfig,
        layer_id: int,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.hidden_size = fd_config.model_config.hidden_size
        self.layer_id = layer_id
        # postnorm switches the residual-stream semantics in forward().
        self.postnorm = getattr(fd_config.model_config, "postnorm", False)

        # Determine attention type for this layer
        # attn_type_list: 70 linear (0) + 10 full (1) at specific indices
        attn_type_list = getattr(
            fd_config.model_config,
            "attn_type_list",
            self._build_attn_type_list(fd_config.model_config.num_hidden_layers),
        )
        # Layers beyond the list fall back to full attention.
        self.attention_type = attn_type_list[layer_id] if layer_id < len(attn_type_list) else 1

        # Attention layer (dispatch based on type)
        if self.attention_type == 0:  # Linear attention
            # linear_layer_id counts only the linear-attention layers before this one.
            linear_layer_id = sum(1 for i in range(layer_id) if attn_type_list[i] == 0)
            self.self_attn = MiniMaxM1LinearAttention(
                fd_config,
                layer_id=layer_id,
                linear_layer_id=linear_layer_id,
                prefix=f"{prefix}.self_attn",
            )
        else:  # Full attention
            self.self_attn = MiniMaxM1Attention(
                fd_config,
                layer_id=layer_id,
                prefix=f"{prefix}.self_attn",
            )

        # Input layernorm (pre-norm)
        self.input_layernorm = RMSNorm(
            fd_config,
            hidden_size=fd_config.model_config.hidden_size,
            eps=fd_config.model_config.rms_norm_eps,
            prefix=f"{prefix}.input_layernorm",
        )

        # Post-attention layernorm
        self.post_attention_layernorm = RMSNorm(
            fd_config,
            hidden_size=fd_config.model_config.hidden_size,
            eps=fd_config.model_config.rms_norm_eps,
            prefix=f"{prefix}.post_attention_layernorm",
        )

        # DeepNorm alpha/beta scaling — separate coefficients for linear vs full attention
        if self.attention_type == 0:  # Linear attention
            self.layernorm_attention_alpha = getattr(
                fd_config.model_config, "layernorm_linear_attention_alpha", 3.5565588200778455
            )
            self.layernorm_attention_beta = getattr(fd_config.model_config, "layernorm_linear_attention_beta", 1.0)
        else:  # Full attention
            self.layernorm_attention_alpha = getattr(
                fd_config.model_config, "layernorm_full_attention_alpha", 3.5565588200778455
            )
            self.layernorm_attention_beta = getattr(fd_config.model_config, "layernorm_full_attention_beta", 1.0)
        self.layernorm_mlp_alpha = getattr(fd_config.model_config, "layernorm_mlp_alpha", 3.5565588200778455)
        self.layernorm_mlp_beta = getattr(fd_config.model_config, "layernorm_mlp_beta", 1.0)

        # FFN (MLP or MoE)
        if fd_config.model_config.num_local_experts > 1:
            self.block_sparse_moe = MiniMaxM1MoE(
                fd_config,
                layer_id=layer_id,
                prefix=f"{prefix}.block_sparse_moe",
            )
        else:
            # Dense fallback keeps the attribute name so forward() is uniform.
            self.block_sparse_moe = MiniMaxM1MLP(
                fd_config,
                intermediate_size=fd_config.model_config.intermediate_size,
                prefix=f"{prefix}.mlp",
                reduce_results=True,
            )

    def load_state_dict(self, state_dict):
        """Load attention, FFN, and both layernorm weights for this layer."""
        self.self_attn.load_state_dict(state_dict)
        self.block_sparse_moe.load_state_dict(state_dict)
        self.input_layernorm.load_state_dict(state_dict)
        self.post_attention_layernorm.load_state_dict(state_dict)

    def forward(
        self,
        forward_meta: ForwardMeta,
        hidden_states: paddle.Tensor,
        residual: paddle.Tensor = None,
    ):
        """Decoder layer forward with DeepNorm.

        When postnorm=True (MiniMax-M1 default), the residual stream carries the
        *normed* activations rather than the pre-norm sum. This follows the
        vLLM reference: ``residual = layernorm_output if postnorm else layernorm_input``.

        Returns a ``(hidden_states, residual)`` pair that the next layer
        consumes; ``residual`` feeds the fused add-norm of the next layer's
        input layernorm.
        """
        # Input layernorm (fused: x + residual → norm)
        hidden_states, residual = self.input_layernorm(
            hidden_states,
            residual_input=residual,
            forward_meta=forward_meta,
        )
        # hidden_states = norm(input + prev_residual)
        # residual = input + prev_residual (pre-norm)
        if self.postnorm:
            residual = hidden_states  # postnorm: residual = normed output

        # Attention (dispatch based on type)
        attn_output = self.self_attn(forward_meta=forward_meta, hidden_states=hidden_states)

        # DeepNorm alpha/beta scaling
        residual = residual * self.layernorm_attention_alpha
        attn_output = attn_output * self.layernorm_attention_beta

        # Post-attention layernorm
        if self.postnorm:
            # Explicit add before the norm; residual then carries the normed output.
            layernorm_input = residual + attn_output
            hidden_states, residual = self.post_attention_layernorm(
                layernorm_input,
                forward_meta=forward_meta,
            )
            residual = hidden_states  # postnorm: residual = normed output
        else:
            # Fused add-norm: hidden = norm(attn + residual), residual = attn + residual.
            hidden_states, residual = self.post_attention_layernorm(
                attn_output,
                residual_input=residual,
                forward_meta=forward_meta,
            )

        # FFN
        mlp_output = self.block_sparse_moe(hidden_states, forward_meta)

        # DeepNorm MLP alpha/beta
        residual = residual * self.layernorm_mlp_alpha
        mlp_output = mlp_output * self.layernorm_mlp_beta

        hidden_states = residual + mlp_output

        return hidden_states, residual
+
+
@support_graph_optimization
class MiniMaxM1Model(nn.Layer):
    """MiniMax-M1 decoder-only transformer backbone (embeddings → layers → norm)."""

    def __init__(self, fd_config: FDConfig = None):
        super().__init__()

        model_config = fd_config.model_config
        self.num_layers = model_config.num_hidden_layers
        self.hidden_size = model_config.hidden_size
        # All checkpoint keys for this backbone live under the "model." prefix.
        model_config.pretrained_config.prefix_name = "model"
        base_prefix = model_config.pretrained_config.prefix_name

        # Token embedding (vocab-parallel across TP ranks).
        self.embed_tokens = VocabParallelEmbedding(
            fd_config,
            num_embeddings=model_config.vocab_size,
            embedding_dim=model_config.hidden_size,
            prefix=f"{base_prefix}.embed_tokens",
        )

        # Hybrid decoder stack (linear/full attention chosen per layer).
        self.layers = nn.LayerList(
            [
                MiniMaxM1DecoderLayer(
                    fd_config,
                    layer_id=layer_id,
                    prefix=f"{base_prefix}.layers.{layer_id}",
                )
                for layer_id in range(self.num_layers)
            ]
        )

        # Final RMSNorm.
        self.norm = RMSNorm(
            fd_config,
            hidden_size=model_config.hidden_size,
            eps=model_config.rms_norm_eps,
            prefix=f"{base_prefix}.norm",
        )

    def load_state_dict(self, state_dict):
        """Load embeddings, the final norm, and every decoder layer."""
        self.embed_tokens.load_state_dict(state_dict)
        self.norm.load_state_dict(state_dict)
        for i, layer in enumerate(self.layers):
            logger.info(f"Start load layer {i}")
            layer.load_state_dict(state_dict)

    def forward(
        self,
        ids_remove_padding: paddle.Tensor,
        forward_meta: ForwardMeta,
    ):
        """Embed tokens, run all decoder layers, and apply the final norm."""
        hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)

        # The residual stream threads through the fused add-norms of each layer.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                forward_meta=forward_meta,
                hidden_states=hidden_states,
                residual=residual,
            )

        return self.norm(hidden_states, residual)[0]
+
+
@ModelRegistry.register_model_class(
    architecture="MiniMaxM1ForCausalLM",
    module_name="minimax_m1",
    category=ModelCategory.TEXT_GENERATION,
    primary_use=ModelCategory.TEXT_GENERATION,
)
@ModelRegistry.register_model_class(
    architecture="MiniMaxText01ForCausalLM",
    module_name="minimax_m1",
    category=ModelCategory.TEXT_GENERATION,
    primary_use=ModelCategory.TEXT_GENERATION,
)
class MiniMaxM1ForCausalLM(ModelForCasualLM):
    """MiniMax-M1 Causal LM Model.

    Wraps the MiniMaxM1Model backbone with a parallel LM head and implements
    both checkpoint-loading paths: v0 ``set_state_dict`` (whole dict) and
    v1 ``load_weights`` (streaming iterator), each performing HF→FD
    weight-name remapping. Registered for both the MiniMaxM1 and
    MiniMaxText01 HF architecture names.
    """

    # Mapping HF checkpoint names → FD merged parameter names.
    # For full attention layers: separate q/k/v → merged qkv_proj
    # For MoE: gate_proj/up_proj → merged gate_up_proj (dense MLP fallback)
    _STACKED_PARAMS_MAPPING = [
        # (fd_param_name, hf_weight_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    def __init__(self, fd_config: FDConfig):
        super().__init__(fd_config)

        # Backbone plus vocab-parallel output head.
        self.model = MiniMaxM1Model(fd_config)
        self.lm_head = ParallelLMHead(
            fd_config,
            embedding_dim=fd_config.model_config.hidden_size,
            num_embeddings=fd_config.model_config.vocab_size,
            prefix="lm_head",
        )

    @classmethod
    def name(cls):
        """Model name."""
        return "MiniMaxM1ForCausalLM"

    @paddle.no_grad()
    def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
        """Load model parameters (v0 loader path).

        Pre-processes HF weight keys to match FD naming conventions, then
        delegates to sub-layer ``load_state_dict`` calls. Two remaps are
        performed: expert weights w1/w3/w2 → gate_proj/up_proj/down_proj, and
        per-projection q/k/v full-attention weights merged into ``qkv_proj``.
        """
        renamed: Dict[str, Union[np.ndarray, paddle.Tensor]] = {}
        # Collect full-attention q/k/v weights for merging into qkv_proj
        qkv_buffers: Dict[str, Dict[str, Union[np.ndarray, paddle.Tensor]]] = {}

        for name, weight in list(state_dict.items()):
            # Expert weights: w1→gate_proj, w3→up_proj, w2→down_proj
            # Handles both .weight (FP) and .quant_weight / .weight_scale / .activation_scale (quantized)
            if "block_sparse_moe.experts." in name:
                name = re.sub(r"\.w1\.", ".gate_proj.", name)
                name = re.sub(r"\.w3\.", ".up_proj.", name)
                name = re.sub(r"\.w2\.", ".down_proj.", name)
                renamed[name] = weight
            # Full attention: merge separate q/k/v into qkv_proj
            elif ".self_attn.q_proj." in name or ".self_attn.k_proj." in name or ".self_attn.v_proj." in name:
                # Extract layer prefix: e.g. "model.layers.7.self_attn"
                prefix_match = re.match(
                    r"(.*\.self_attn)\.(q|k|v)_proj\.(weight|quant_weight|weight_scale|activation_scale)$", name
                )
                if prefix_match:
                    attn_prefix = prefix_match.group(1)
                    proj_type = prefix_match.group(2)
                    suffix = prefix_match.group(3)
                    # Buffer keyed per (layer, tensor kind) so weights and
                    # scales are each merged independently.
                    buf_key = f"{attn_prefix}|{suffix}"
                    if buf_key not in qkv_buffers:
                        qkv_buffers[buf_key] = {}
                    qkv_buffers[buf_key][proj_type] = weight
                else:
                    # Unexpected q/k/v key shape: pass through unchanged.
                    renamed[name] = weight
            else:
                renamed[name] = weight

        # Merge q/k/v into qkv_proj for full attention layers
        for buf_key, projections in qkv_buffers.items():
            # Only merge when all three projections arrived; partial sets are
            # silently dropped.
            if "q" in projections and "k" in projections and "v" in projections:
                attn_prefix, suffix = buf_key.split("|", 1)
                q_w = projections["q"]
                k_w = projections["k"]
                v_w = projections["v"]
                # NOTE(review): assumes q/k/v stack along axis 0 to match the
                # fused qkv_proj layout — confirm against checkpoint shapes.
                if isinstance(q_w, np.ndarray):
                    merged = np.concatenate([q_w, k_w, v_w], axis=0)
                else:
                    merged = paddle.concat([q_w, k_w, v_w], axis=0)
                renamed[f"{attn_prefix}.qkv_proj.{suffix}"] = merged

        self.model.load_state_dict(renamed)
        self.lm_head.load_state_dict(renamed)

    @paddle.no_grad()
    def load_weights(self, weights_iterator) -> None:
        """Load model parameters from a weights iterator (v1 loader path).

        Handles HF→FD name mapping for:
        - Full attention: q_proj/k_proj/v_proj → qkv_proj (stacked)
        - MoE experts: w1/w3 → up_gate_proj, w2 → down_proj
        Each incoming tensor is routed to exactly one of three branches:
        stacked params, expert params, or direct (everything else).
        """
        from fastdeploy.model_executor.utils import (
            default_weight_loader,
            process_weights_after_loading,
        )

        stacked_params_mapping = list(self._STACKED_PARAMS_MAPPING)

        # Expert weight mapping: HF w1/w2/w3 → FD up_gate_proj/down_proj
        n_experts = getattr(self.fd_config.model_config, "num_local_experts", 1)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            num_experts=n_experts,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            param_gate_up_proj_name="experts.up_gate_proj_",
            param_down_proj_name="experts.down_proj_",
        )

        params_dict = dict(self.named_parameters())
        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)

        for loaded_weight_name, loaded_weight in weights_iterator:
            logger.debug(f"Loading weight: {loaded_weight_name}")

            # Stacked params (q/k/v → qkv_proj)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in loaded_weight_name:
                    continue
                # Skip expert weights — handled separately
                if "block_sparse_moe.experts." in loaded_weight_name:
                    continue
                model_param_name = loaded_weight_name.replace(weight_name, param_name)
                if model_param_name not in params_dict:
                    continue
                param = params_dict[model_param_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Expert params (w1/w2/w3 → up_gate_proj/down_proj)
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in loaded_weight_name:
                        continue
                    model_param_name = loaded_weight_name.replace(weight_name, param_name)
                    if model_param_name not in params_dict:
                        continue
                    param = params_dict[model_param_name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
                    break
                else:
                    # Direct loading (norm, embed, lm_head, output_gate, out_proj, etc.)
                    model_param_name = loaded_weight_name
                    if model_param_name not in params_dict:
                        # Weight has no matching FD parameter; skip it (and
                        # skip post-processing below via continue).
                        continue
                    param = params_dict[model_param_name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                    weight_loader(param, loaded_weight)

            # Strip the trailing weight suffix to get the owning sublayer name.
            model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
            process_weights_after_loading_fn(model_sublayer_name, param)

    def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None):
        """Compute logits from final hidden states via the LM head (float32)."""
        logits = self.lm_head(hidden_states)
        logits = logits.astype(paddle.float32)
        return logits

    def forward(
        self,
        inputs: Dict,
        forward_meta: ForwardMeta,
    ):
        """Forward pass: run the backbone on the de-padded token ids."""
        ids_remove_padding = inputs["ids_remove_padding"]

        hidden_states = self.model(ids_remove_padding, forward_meta)
        return hidden_states
+
+
class MiniMaxM1PretrainedModel(PretrainedModel):
    """Pretrained-model shim exposing the MiniMax-M1 architecture name."""

    # FDConfig drives layer construction in FastDeploy.
    config_class = FDConfig

    # Shared identifier returned by both classmethods below.
    _ARCH = "MiniMaxM1ForCausalLM"

    @classmethod
    def arch_name(cls):
        """Return the architecture identifier."""
        return cls._ARCH

    @classmethod
    def name(cls):
        """Return the model name (same as the architecture identifier)."""
        return cls._ARCH
diff --git a/fastdeploy/model_executor/ops/triton_ops/lightning_attn.py b/fastdeploy/model_executor/ops/triton_ops/lightning_attn.py
new file mode 100644
index 00000000000..df90674e747
--- /dev/null
+++ b/fastdeploy/model_executor/ops/triton_ops/lightning_attn.py
@@ -0,0 +1,711 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import paddle
+import triton
+import triton.language as tl
+
+from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
+ enable_compat_on_triton_kernel,
+)
+
+# =============================================================================
+# Triton JIT Kernels — framework-agnostic, operate on raw pointers
+# =============================================================================
+
+
+@enable_compat_on_triton_kernel
+@triton.jit
+def _fwd_diag_kernel(
+ Q,
+ K,
+ V,
+ Out,
+ S,
+ b: tl.constexpr,
+ h: tl.constexpr,
+ n,
+ d: tl.constexpr,
+ e: tl.constexpr,
+ BLOCK: tl.constexpr,
+ NUM_BLOCK,
+ CBLOCK: tl.constexpr,
+):
+ """
+ Compute the block-diagonal (intra-block) part of causal lightning attention.
+
+ Grid: (b * h * NUM_BLOCK, NUM_CBLOCK). Each program handles one CBLOCK-row
+ sub-tile of one BLOCK x BLOCK diagonal tile for a single (batch, head).
+ The partial result is written to Out; the inter-block contribution is added
+ on top later by _fwd_none_diag_kernel.
+ """
+ # This kernel computes the diagonal blocks of the attention matrix
+ # Each diagonal block represents attention
+ # where queries attend to keys in the same block
+ off = tl.program_id(0)
+ off_bh = off // NUM_BLOCK # batch-head index
+ off_block = off % NUM_BLOCK # block index within the sequence
+ off_cblock = tl.program_id(1) # sub-block index within a block
+
+ off_h = off_bh % h # head index
+
+ # Calculate base offsets for the current batch and head
+ qk_offset = off_bh * n * d
+ v_offset = off_bh * n * e
+ o_offset = off_bh * n * e
+
+ # Calculate offsets for the current block
+ block_offset = off_block * BLOCK
+ qk_block_offset = block_offset * d
+ v_block_offset = block_offset * e
+ o_block_offset = block_offset * e
+
+ # Calculate offsets for the current sub-block
+ cblock_offset = off_cblock * CBLOCK
+ q_cblock_offset = cblock_offset * d
+ o_cblock_offset = cblock_offset * e
+
+ # Calculate pointers to the query, key, value, and output tensors
+ Q_block_ptr = (
+ Q
+ + qk_offset
+ + qk_block_offset
+ + q_cblock_offset
+ + tl.arange(0, CBLOCK)[:, None] * d
+ + tl.arange(0, d)[None, :]
+ )
+ K_trans_block_ptr = K + qk_offset + qk_block_offset + tl.arange(0, CBLOCK)[None, :] * d + tl.arange(0, d)[:, None]
+ V_block_ptr = V + v_offset + v_block_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, e)[None, :]
+ O_block_ptr = (
+ Out
+ + o_offset
+ + o_block_offset
+ + o_cblock_offset
+ + tl.arange(0, CBLOCK)[:, None] * e
+ + tl.arange(0, e)[None, :]
+ )
+
+ # Load the decay rate for the current head
+ S_block_ptr = S + off_h
+ s = tl.load(S_block_ptr)
+
+ i = off_cblock
+ q_index = tl.arange(0, CBLOCK) + i * CBLOCK
+
+ # Load query values (masked against n to guard the final partial block)
+ q = tl.load(Q_block_ptr, mask=block_offset + q_index[:, None] < n, other=0.0).to(tl.float32)
+
+ # Initialize output accumulator
+ qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)
+
+ # Process all sub-blocks up to and
+ # including the current one (causal attention)
+ for j in range(i + 1):
+ kv_index = tl.arange(0, CBLOCK) + j * CBLOCK
+ diff = q_index[:, None] - kv_index[None, :]
+ s_index = s * diff
+ # Apply causal mask: only attend to positions before the current one
+ s_index = tl.where(diff >= 0, -s_index, float("-inf"))
+ # decay[i, j] = exp(-s * (i - j)) for i >= j, 0 otherwise
+ decay = tl.exp(s_index)
+
+ # Load key and value
+ k_trans = tl.load(
+ K_trans_block_ptr,
+ mask=block_offset + kv_index[None, :] < n,
+ other=0.0,
+ ).to(tl.float32)
+ v = tl.load(
+ V_block_ptr,
+ mask=block_offset + kv_index[:, None] < n,
+ other=0.0,
+ ).to(tl.float32)
+
+ # Compute attention scores and apply decay
+ qk = tl.dot(q, k_trans) * decay
+
+ # Compute weighted values and accumulate
+ qkv += tl.dot(qk, v)
+
+ # Move to the next sub-block
+ K_trans_block_ptr += CBLOCK * d
+ V_block_ptr += CBLOCK * e
+
+ # Store the result
+ tl.store(
+ O_block_ptr,
+ qkv.to(O_block_ptr.dtype.element_ty),
+ mask=block_offset + q_index[:, None] < n,
+ )
+
+
+@enable_compat_on_triton_kernel
+@triton.jit
+def _fwd_kv_parallel(
+ K,
+ V,
+ K_decay,
+ KV,
+ b: tl.constexpr,
+ h: tl.constexpr,
+ n,
+ d: tl.constexpr,
+ e: tl.constexpr,
+ BLOCK: tl.constexpr,
+ NUM_BLOCK,
+ D_FBLOCK: tl.constexpr,
+ E_FBLOCK: tl.constexpr,
+ NUM_FBLOCK: tl.constexpr,
+ CBLOCK: tl.constexpr,
+ NUM_CBLOCK: tl.constexpr,
+):
+ """
+ Compute, per BLOCK of tokens, the decay-weighted key-value outer product
+ KV[b, h, block] = sum_j k_decay[j] * k_j^T @ v_j.
+
+ Grid: (b * h, NUM_BLOCK). The last (possibly partial) block is handled by
+ shifting pointers back by left_shift and masking out the phantom leading
+ rows, so the K_decay weights stay aligned with the real tokens.
+ """
+ # This kernel computes the key-value outer
+ # products for each block in parallel
+ off_bh = tl.program_id(0) # batch-head index
+ off_block = tl.program_id(1) # block index
+
+ off_h = off_bh % h # head index
+
+ block_offset = off_block * BLOCK
+
+ # Calculate offsets for the current block
+ k_block_offset = block_offset * d
+ v_block_offset = block_offset * e
+ kv_block_offset = off_block * d * e
+
+ # Calculate base offsets for the current batch and head
+ k_offset = off_bh * n * d
+ v_offset = off_bh * n * e
+ kv_offset = off_bh * NUM_BLOCK * d * e
+
+ # Calculate pointers to the key, value, and key-value tensors
+ K_trans_block_ptr = (
+ K + k_offset + k_block_offset + tl.arange(0, CBLOCK)[None, :] * d + tl.arange(0, D_FBLOCK)[:, None]
+ )
+ V_block_ptr = V + v_offset + v_block_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+ KV_block_ptr = (
+ KV + kv_offset + kv_block_offset + tl.arange(0, D_FBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+ )
+
+ # Load the decay factors for the current head and block
+ k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)
+
+ kv_index = tl.arange(0, CBLOCK)
+
+ # Initialize the key-value outer product accumulator
+ kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
+
+ # Handle the last block which might be smaller than BLOCK
+ split_n = n - (NUM_BLOCK - 1) * BLOCK if off_block == NUM_BLOCK - 1 else BLOCK
+ left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
+ num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
+ # Skip decay entries for sub-blocks that do not exist in a short block
+ k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
+
+ # Process all sub-blocks in the current block
+ for j in range(num_blocks):
+ # Only the first iteration (j == 0) masks the shifted-in leading rows
+ left_bound = (1 - j) * left_shift
+ # Load key and value, handling boundary conditions
+ k_trans = tl.load(
+ K_trans_block_ptr - left_shift * d,
+ mask=kv_index[None, :] >= left_bound,
+ other=0.0,
+ )
+ v = tl.load(
+ V_block_ptr - left_shift * e,
+ mask=kv_index[:, None] >= left_bound,
+ other=0.0,
+ )
+
+ # Load decay factor and compute weighted key-value outer product
+ k_decay = tl.load(k_decay_ptr)
+
+ # NOTE: Need to add the extra dim here due to AMD MLIR lowering error.
+ # Please don't move it back until issue is resolved.
+ # Issue: https://github.com/ROCm/triton/issues/907
+ k_decay = k_decay[None, :]
+
+ kv += tl.dot(k_trans * k_decay, v)
+
+ # Move to the next sub-block
+ K_trans_block_ptr += CBLOCK * d
+ V_block_ptr += CBLOCK * e
+ k_decay_ptr += CBLOCK
+
+ # Store the result
+ tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))
+
+
+@enable_compat_on_triton_kernel
+@triton.jit
+def _fwd_kv_reduce(
+ S,
+ KV,
+ KV_HISTORY,
+ b: tl.constexpr,
+ h: tl.constexpr,
+ n,
+ d: tl.constexpr,
+ e: tl.constexpr,
+ BLOCK: tl.constexpr,
+ NUM_BLOCK,
+ D_FBLOCK: tl.constexpr,
+ E_FBLOCK: tl.constexpr,
+):
+ """
+ Sequentially scan the per-block KV sums produced by _fwd_kv_parallel.
+
+ After this kernel, KV[b, h, i] holds the decayed KV state accumulated from
+ all blocks *before* i (an exclusive prefix, seeded from KV_HISTORY), and
+ KV_HISTORY is overwritten in place with the state after the final block.
+
+ Grid: (b * h, NUM_FBLOCK).
+ """
+ # This kernel reduces the key-value outer products
+ # across blocks and updates the KV history
+ off_bh = tl.program_id(0) # batch-head index
+ off_h = off_bh % h # head index
+
+ kv_offset = off_bh * NUM_BLOCK * d * e
+
+ # Calculate pointer to the key-value tensor
+ KV_block_ptr = KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+
+ # Load the decay rate for the current head
+ s_ptrs = S + off_h
+ s = tl.load(s_ptrs)
+
+ # Calculate pointer to the key-value history tensor
+ kv_history_offset = off_bh * d * e
+ KV_HISTORY_block_ptr = (
+ KV_HISTORY + kv_history_offset + tl.arange(0, D_FBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+ )
+
+ # Load the previous key-value history
+ kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)
+
+ # Walk blocks front-to-back, storing the exclusive prefix state into each slot
+ for i in range(NUM_BLOCK):
+ # Last block may be shorter than BLOCK, which shrinks its decay exponent
+ block_size = min(n - i * BLOCK, BLOCK)
+ # Compute decay factor for the current block
+ block_decay = tl.exp(-s.to(tl.float32) * block_size)
+
+ # Load the current key-value outer product
+ kv_cur = tl.load(KV_block_ptr).to(tl.float32)
+ # Store the previous key-value history to the current block
+ tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))
+
+ # Update the key-value history with the current block
+ kv_pre = block_decay * kv_pre + kv_cur
+ KV_block_ptr += d * e
+
+ # Store the updated key-value history
+ tl.store(KV_HISTORY_block_ptr, kv_pre)
+
+
+@enable_compat_on_triton_kernel
+@triton.jit
+def _fwd_none_diag_kernel(
+ Q,
+ Out,
+ S,
+ KV,
+ b: tl.constexpr,
+ h: tl.constexpr,
+ n,
+ d: tl.constexpr,
+ e: tl.constexpr,
+ BLOCK: tl.constexpr,
+ NUM_BLOCK,
+ E_FBLOCK: tl.constexpr,
+ CBLOCK: tl.constexpr,
+ NUM_CBLOCK: tl.constexpr,
+):
+ """
+ Add the inter-block (non-diagonal) contribution q @ KV_prefix, scaled by a
+ per-row decay, onto the diagonal partial results already stored in Out by
+ _fwd_diag_kernel. KV must hold the exclusive prefix states produced by
+ _fwd_kv_reduce.
+
+ Grid: (b * h, NUM_BLOCK * NUM_CBLOCK, [e-feature blocks; 1 when absent]).
+ """
+ # This kernel computes the non-diagonal blocks of the attention matrix
+ # Each non-diagonal block represents attention
+ # where queries attend to keys in different blocks
+ off_bh = tl.program_id(0) # batch-head index
+ off_h = off_bh % h # head index
+
+ off_nc = tl.program_id(1)
+ off_n = off_nc // NUM_CBLOCK # block index
+ off_c = off_nc % NUM_CBLOCK # sub-block index
+ off_e = tl.program_id(2) # output feature block index
+
+ n_offset = off_n * BLOCK
+ c_offset = off_c * CBLOCK
+ e_offset = off_e * E_FBLOCK
+ block_offset = n_offset + c_offset
+
+ # Calculate offsets for the current batch, head, and block
+ q_offset = off_bh * n * d + (n_offset + c_offset) * d
+ o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset
+ kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset
+
+ # Calculate pointers to the query, output, and key-value tensors
+ Q_block_ptr = Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :]
+ O_block_ptr = Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+ KV_block_ptr = KV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+
+ # Load the decay rate for the current head
+ S_block_ptr = S + off_h
+ s = tl.load(S_block_ptr)
+
+ c_array = tl.arange(0, CBLOCK)
+
+ # Load the key-value outer product for the current block
+ kv = tl.load(KV_block_ptr).to(tl.float32)
+ q_index = block_offset + tl.arange(0, CBLOCK)
+
+ # Load query values
+ q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32)
+
+ # Compute decay factors for the current sub-block:
+ # exp(-s * offset-within-block) for each query row
+ q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))
+
+ # Compute non-diagonal attention output
+ qkv_none_diag = tl.dot(q, kv) * q_decay
+
+ # Load diagonal attention output (computed by _fwd_diag_kernel)
+ qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32)
+
+ # Combine diagonal and non-diagonal attention outputs
+ qkv = qkv_diag + qkv_none_diag
+
+ # Store the result
+ tl.store(O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty), mask=q_index[:, None] < n)
+
+
+@enable_compat_on_triton_kernel
+@triton.jit
+def _linear_attn_decode_kernel(
+ q_ptr,
+ k_ptr,
+ v_ptr,
+ kv_cache_ptr,
+ slope_rate,
+ slot_idx,
+ output_ptr,
+ D: tl.constexpr,
+ qkv_b_stride,
+ qkv_h_stride,
+ cache_b_stride,
+ cache_h_stride,
+ cache_d0_stride,
+ cache_d1_stride,
+ BLOCK_SIZE: tl.constexpr,
+):
+ """
+ Kernel for linear attention decoding with KV cache.
+
+ This kernel computes attention for a single token using the KV cache:
+ it decays the cached state by exp(-slope), adds the new k^T v outer
+ product, writes the updated state back in place, and emits q @ state.
+
+ Grid: (batch, heads, D // BLOCK_SIZE). Entries with slot_idx == -1 are
+ treated as padding and skipped.
+ """
+ pid_b = tl.program_id(0) # batch index
+ pid_h = tl.program_id(1) # head index
+ pid_d = tl.program_id(2) # dimension block index
+
+ # Load slot index for the current batch
+ slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
+
+ # Skip if slot_id is -1 (padding)
+ if slot_id == -1:
+ return
+
+ batch_id = pid_b
+ head_id = pid_h
+
+ # Load decay rate for the current head
+ ratio = tl.load(slope_rate + pid_h)
+
+ # Calculate offsets for dimensions
+ qk_d_offsets = tl.arange(0, D)
+ v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
+ cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[None, :] * cache_d1_stride
+
+ # Calculate offsets for the current batch and head
+ q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
+ k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
+ v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
+
+ cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride
+
+ # Create masks for loading tensors
+ qk_mask = qk_d_offsets < D
+ # NOTE(review): v_mask bounds-checks against D, which assumes the value
+ # head dim equals the q/k head dim — confirm if e != d is ever used.
+ v_mask = v_d_offsets < D
+
+ # Load query, key, and value tensors
+ q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)
+ k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)
+ v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)
+
+ # Compute key-value outer product
+ kv_outer = k[:, None] * v[None, :]
+ kv_mask = qk_mask[:, None] & v_mask[None, :]
+
+ # Apply decay to previous KV cache
+ ratio = tl.exp(-ratio)
+ kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets
+ kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)
+ kv_outer = kv_outer + ratio * kv_cache_old
+
+ # Compute attention output: reduce q * state over the qk dimension
+ output = q[:, None].to(tl.float32) * kv_outer
+ output = tl.sum(output, axis=0)
+
+ # Update KV cache and store output
+ tl.store(kv_ptr, kv_outer, mask=kv_mask)
+ tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)
+
+
+# =============================================================================
+# Python wrapper functions — Paddle API
+# =============================================================================
+
+
+def lightning_attention_forward(q, k, v, s, kv_history):
+ """
+ Forward pass of the lightning attention algorithm.
+ Converted from vLLM's torch.autograd.Function to a plain function
+ for inference-only use in FastDeploy.
+
+ Pipeline: (1) intra-block diagonal attention, (2) per-block KV outer
+ products, (3) sequential prefix scan over blocks (mutates kv_history in
+ place), (4) inter-block contribution added onto the diagonal output.
+
+ Args:
+ q: Query tensor [b, h, n, d]
+ k: Key tensor [b, h, n, d]
+ v: Value tensor [b, h, n, e]
+ s: Decay rate tensor [1, h, 1, 1] or [h]
+ kv_history: KV history tensor [b, h, d, e]; updated in place by step 3
+
+ Returns:
+ o: Output tensor [b, h, n, e]
+ kv_state: Per-block prefix states with the final history appended on
+ the block axis, shape [b, h, NUM_BLOCK + 1, d, e]
+ """
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ s = s.contiguous()
+
+ # Get input dimensions
+ b, h, n, d = q.shape
+ e = v.shape[-1]
+
+ # Initialize output tensor
+ o = paddle.empty([b, h, n, e], dtype=q.dtype)
+
+ # Set block sizes
+ BLOCK = 256
+ NUM_BLOCK = triton.cdiv(n, BLOCK)
+
+ CBLOCK = 32
+ NUM_CBLOCK = BLOCK // CBLOCK
+ assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
+
+ # Compute decay factors for keys: exp(-s * (BLOCK - position)) per slot
+ array = paddle.arange(0, BLOCK).astype("float32") + 1
+ k_decay = paddle.exp(-s * (BLOCK - array.reshape([1, -1])))
+
+ # Step 1: Compute diagonal blocks of attention
+ grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
+ _fwd_diag_kernel[grid](
+ q,
+ k,
+ v,
+ o,
+ s,
+ b,
+ h,
+ n,
+ d,
+ e,
+ BLOCK=BLOCK,
+ NUM_BLOCK=NUM_BLOCK,
+ CBLOCK=CBLOCK,
+ )
+
+ # Set feature block sizes
+ NUM_FBLOCK = 1
+ D_FBLOCK = d // NUM_FBLOCK
+ assert d % NUM_FBLOCK == 0
+ E_FBLOCK = e // NUM_FBLOCK
+ assert e % NUM_FBLOCK == 0
+
+ # Steps 2-4 use a coarser sub-block size than step 1
+ CBLOCK = 64
+ NUM_CBLOCK = BLOCK // CBLOCK
+ assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
+
+ # Step 2: Compute key-value outer products for each block in parallel
+ kv = paddle.empty([b, h, NUM_BLOCK, d, e], dtype="float32")
+ grid = (b * h, NUM_BLOCK)
+ _fwd_kv_parallel[grid](
+ k,
+ v,
+ k_decay,
+ kv,
+ b,
+ h,
+ n,
+ d,
+ e,
+ BLOCK=BLOCK,
+ NUM_BLOCK=NUM_BLOCK,
+ D_FBLOCK=D_FBLOCK,
+ E_FBLOCK=E_FBLOCK,
+ NUM_FBLOCK=NUM_FBLOCK,
+ CBLOCK=CBLOCK,
+ NUM_CBLOCK=NUM_CBLOCK,
+ )
+
+ # Step 3: Reduce key-value outer products
+ # across blocks and update KV history (kv and kv_history are
+ # overwritten in place by the kernel)
+ grid = (b * h, NUM_FBLOCK)
+ _fwd_kv_reduce[grid](
+ s,
+ kv,
+ kv_history,
+ b,
+ h,
+ n,
+ d,
+ e,
+ BLOCK=BLOCK,
+ NUM_BLOCK=NUM_BLOCK,
+ D_FBLOCK=D_FBLOCK,
+ E_FBLOCK=E_FBLOCK,
+ )
+
+ # Step 4: Compute non-diagonal blocks of attention
+ # 2-D launch: the kernel's e-block axis (program_id(2)) is implicitly 1,
+ # which is consistent with NUM_FBLOCK == 1 above
+ grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
+ _fwd_none_diag_kernel[grid](
+ q,
+ o,
+ s,
+ kv,
+ b,
+ h,
+ n,
+ d,
+ e,
+ BLOCK=BLOCK,
+ NUM_BLOCK=NUM_BLOCK,
+ E_FBLOCK=E_FBLOCK,
+ CBLOCK=CBLOCK,
+ NUM_CBLOCK=NUM_CBLOCK,
+ )
+
+ # Append the final history to the per-block prefix states on the block axis
+ return o, paddle.concat([kv, kv_history.unsqueeze(2)], axis=2)
+
+
+def lightning_attention(
+ q: paddle.Tensor,
+ k: paddle.Tensor,
+ v: paddle.Tensor,
+ ed: paddle.Tensor,
+ block_size: int = 256,
+ kv_history: paddle.Tensor | None = None,
+) -> tuple[paddle.Tensor, paddle.Tensor]:
+ """
+ Apply lightning attention algorithm to compute attention efficiently.
+
+ Args:
+ q: Query tensor of shape [batch, heads, seq_len, dim]
+ k: Key tensor of shape [batch, heads, seq_len, dim]
+ v: Value tensor of shape [batch, heads, seq_len, dim_v]
+ ed: Decay rate tensor of shape [heads]
+ block_size: Size of blocks for block-sparse attention
+ kv_history: Optional key-value history from previous computations
+
+ Returns:
+ output: Attention output
+ kv: Updated key-value history
+ """
+ d = q.shape[-1]
+ e = v.shape[-1]
+
+ if ed.ndim == 1:
+ ed = ed.reshape([1, -1, 1, 1])
+
+ # Split the computation into chunks for better parallelism
+ m = 128 if d >= 128 else 64
+ assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"
+ arr = [m * i for i in range(d // m + 1)]
+ if arr[-1] != d:
+ arr.append(d)
+ n = len(arr)
+ output = 0
+
+ # Initialize or clone key-value history
+ if kv_history is None:
+ kv_history = paddle.zeros([q.shape[0], q.shape[1], d, e], dtype="float32")
+ else:
+ kv_history = kv_history.clone().contiguous()
+
+ # Process each chunk and accumulate results
+ for i in range(n - 1):
+ s = arr[i]
+ e = arr[i + 1]
+ q1 = q[..., s:e]
+ k1 = k[..., s:e]
+ o, kv = lightning_attention_forward(q1, k1, v, ed, kv_history)
+ output = output + o
+ return output, kv
+
+
+def linear_decode_forward_triton(
+ q: paddle.Tensor,
+ k: paddle.Tensor,
+ v: paddle.Tensor,
+ kv_caches: paddle.Tensor,
+ slope_rate: paddle.Tensor,
+ slot_idx: paddle.Tensor,
+ BLOCK_SIZE: int = 32,
+) -> paddle.Tensor:
+ """
+ Perform linear attention decoding using Triton kernels.
+
+ Updates kv_caches in place (per slot_idx) and returns the single-token
+ attention output.
+
+ Args:
+ q: Query tensor of shape [B, H, 1, D]
+ k: Key tensor of shape [B, H, 1, D]
+ v: Value tensor of shape [B, H, 1, D]
+ kv_caches: Key-value cache tensor; mutated in place
+ slope_rate: Decay rate tensor
+ slot_idx: Slot indices for batches (-1 marks padding rows)
+ BLOCK_SIZE: Size of blocks for processing
+
+ Returns:
+ output: Attention output tensor of shape [B, H*D]
+ """
+ B, H, _, D = q.shape
+ assert k.shape == [B, H, 1, D]
+ assert v.shape == [B, H, 1, D]
+
+ # Initialize output tensor (same layout as q, so q's strides apply)
+ output = paddle.empty_like(q)
+
+ # Set grid dimensions for the kernel
+ grid = (B, H, D // BLOCK_SIZE)
+
+ # Calculate strides for tensors
+ # NOTE(review): assumes paddle Tensor.strides are element counts (not
+ # bytes), matching the kernel's element-offset arithmetic — confirm.
+ qkv_b_stride = q.strides[0]
+ qkv_h_stride = q.strides[1]
+
+ cache_b_stride = kv_caches.strides[0]
+ cache_h_stride = kv_caches.strides[1]
+ cache_d0_stride = kv_caches.strides[2]
+ cache_d1_stride = kv_caches.strides[3]
+
+ # Launch the kernel
+ _linear_attn_decode_kernel[grid](
+ q,
+ k,
+ v,
+ kv_caches,
+ slope_rate,
+ slot_idx,
+ output,
+ D,
+ qkv_b_stride,
+ qkv_h_stride,
+ cache_b_stride,
+ cache_h_stride,
+ cache_d0_stride,
+ cache_d1_stride,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+
+ # Reshape output: "b h n d -> b n (h d)"
+ # output shape: [B, H, 1, D] -> transpose to [B, 1, H, D] -> reshape to [B, 1, H*D]
+ output = output.transpose([0, 2, 1, 3]).reshape([B, 1, -1])
+ return output.squeeze(1).contiguous()
diff --git a/tests/model_executor/test_minimax_m1.py b/tests/model_executor/test_minimax_m1.py
new file mode 100644
index 00000000000..c2da39d07d2
--- /dev/null
+++ b/tests/model_executor/test_minimax_m1.py
@@ -0,0 +1,405 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Tests for MiniMax-M1 model scaffold.
+Validates architecture dispatch, slope construction, registration, and forward paths.
+
+Uses importlib to load minimax_m1.py directly, bypassing fastdeploy/__init__.py
+which pulls in the full inference engine (etcd, Redis, GPU ops, etc.).
+All heavy submodules are replaced with lightweight stubs so tests run on CPU.
+"""
+
+import importlib
+import importlib.util
+import math
+import sys
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import numpy as np
+import paddle
+import pytest
+
+# ---------------------------------------------------------------------------
+# Module-level setup: load minimax_m1 with stub dependencies
+# ---------------------------------------------------------------------------
+
+# 1) paddleformers stubs
+_PretrainedModel = type("PretrainedModel", (), {})
+
+
+class _PretrainedConfig:
+ prefix_name = ""
+
+ @classmethod
+ def get_config_dict(cls, model_path, **kw):
+ import json as _j
+ import os as _o
+
+ with open(_o.path.join(model_path, "config.json")) as f:
+ return _j.load(f), {}
+
+ @classmethod
+ def from_dict(cls, d):
+ ns = SimpleNamespace(**d)
+ ns.prefix_name = ""
+ return ns
+
+
+# Wire the stub config/transformers modules into sys.modules so that
+# `import paddleformers...` inside minimax_m1.py resolves to these stand-ins.
+_cfg_mod = MagicMock()
+_cfg_mod.PretrainedConfig = _PretrainedConfig
+_transf = MagicMock()
+_transf.PretrainedModel = _PretrainedModel
+_transf.configuration_utils = _cfg_mod
+_transf.PretrainedConfig = _PretrainedConfig
+
+sys.modules.setdefault("paddleformers", MagicMock())
+sys.modules["paddleformers.transformers"] = _transf
+sys.modules["paddleformers.transformers.configuration_utils"] = _cfg_mod
+sys.modules.setdefault("paddleformers.utils", MagicMock())
+sys.modules["paddleformers.utils.log"] = MagicMock()
+
+# 2) Lightweight fastdeploy namespace (bypass __init__.py)
+# type(sys) is the module type, so this builds empty package shells without
+# executing fastdeploy's real __init__.py.
+_fd_ns = type(sys)("fastdeploy")
+_fd_ns.__path__ = ["fastdeploy"]
+_fd_ns.__file__ = "fastdeploy/__init__.py"
+sys.modules["fastdeploy"] = _fd_ns
+
+for _pkg, _path in [
+ ("fastdeploy.model_executor", "fastdeploy/model_executor"),
+ ("fastdeploy.model_executor.models", "fastdeploy/model_executor/models"),
+]:
+ _m = type(sys)(_pkg)
+ _m.__path__ = [_path]
+ sys.modules[_pkg] = _m
+
+# 3) Mock all heavy fastdeploy submodules
+for _mod_name in [
+ "fastdeploy.config",
+ "fastdeploy.distributed",
+ "fastdeploy.distributed.communication",
+ "fastdeploy.model_executor.forward_meta",
+ "fastdeploy.model_executor.graph_optimization",
+ "fastdeploy.model_executor.graph_optimization.decorator",
+ "fastdeploy.model_executor.layers",
+ "fastdeploy.model_executor.layers.activation",
+ "fastdeploy.model_executor.layers.attention",
+ "fastdeploy.model_executor.layers.attention.attention",
+ "fastdeploy.model_executor.layers.embeddings",
+ "fastdeploy.model_executor.layers.linear",
+ "fastdeploy.model_executor.layers.lm_head",
+ "fastdeploy.model_executor.layers.moe",
+ "fastdeploy.model_executor.layers.moe.moe",
+ "fastdeploy.model_executor.layers.normalization",
+ "fastdeploy.model_executor.models.model_base",
+ "fastdeploy.model_executor.ops",
+ "fastdeploy.model_executor.ops.triton_ops",
+ "fastdeploy.model_executor.ops.triton_ops.lightning_attn",
+]:
+ # Guard so any module registered earlier in this file is not clobbered
+ if _mod_name not in sys.modules:
+ sys.modules[_mod_name] = MagicMock()
+
+
+# 4) Real ModelRegistry so @register_model_class works
+class _ModelCategory:
+ TEXT_GENERATION = "text_generation"
+
+
+class _ModelRegistry:
+ _arch_to_model_cls = {}
+ _enhanced_models = {}
+
+ @classmethod
+ def register_model_class(cls, model_class=None, **kw):
+ def _register(mc):
+ arch = kw.get("architecture", mc.name())
+ cls._arch_to_model_cls[arch] = mc
+ return mc
+
+ return _register(model_class) if model_class is not None else _register
+
+
+# NOTE: "Casual" (sic) matches the project's actual class name in model_base.
+_ModelForCasualLM = type("ModelForCasualLM", (), {"name": classmethod(lambda cls: "base")})
+
+# Attach the real-enough stubs onto the mocked model_base module
+_mb = sys.modules["fastdeploy.model_executor.models.model_base"]
+_mb.ModelCategory = _ModelCategory
+_mb.ModelRegistry = _ModelRegistry
+_mb.ModelForCasualLM = _ModelForCasualLM
+
+# support_graph_optimization → identity (decorator becomes a no-op)
+sys.modules["fastdeploy.model_executor.graph_optimization.decorator"].support_graph_optimization = lambda cls: cls
+
+# 5) Load minimax_m1.py via importlib
+# 5) Load minimax_m1.py via importlib (direct file load; bypasses the package
+# __init__ chain, relying on the sys.modules stubs registered above)
+_spec = importlib.util.spec_from_file_location(
+ "fastdeploy.model_executor.models.minimax_m1",
+ "fastdeploy/model_executor/models/minimax_m1.py",
+ submodule_search_locations=[],
+)
+_mod = importlib.util.module_from_spec(_spec)
+# Register before exec so self-imports inside the module resolve
+sys.modules[_spec.name] = _mod
+_spec.loader.exec_module(_mod)
+
+# Import symbols
+MiniMaxM1DecoderLayer = _mod.MiniMaxM1DecoderLayer
+MiniMaxM1LinearAttention = _mod.MiniMaxM1LinearAttention
+MiniMaxM1Attention = _mod.MiniMaxM1Attention
+MiniMaxM1MoE = _mod.MiniMaxM1MoE
+MiniMaxM1MLP = _mod.MiniMaxM1MLP
+MiniMaxM1ForCausalLM = _mod.MiniMaxM1ForCausalLM
+MiniMaxM1PretrainedModel = _mod.MiniMaxM1PretrainedModel
+MiniMaxM1Model = _mod.MiniMaxM1Model
+ModelRegistry = _ModelRegistry
+
+
+# ===================================================================
+# 1. Pure-logic tests
+# ===================================================================
+
+
+class TestBuildAttnTypeList:
+
+ def test_80_layers_has_10_full_attention(self):
+ attn_list = MiniMaxM1DecoderLayer._build_attn_type_list(80)
+ assert len(attn_list) == 80
+ full_indices = [i for i, t in enumerate(attn_list) if t == 1]
+ assert full_indices == [7, 15, 23, 31, 39, 47, 55, 63, 71, 79]
+
+ def test_short_model_clips_indices(self):
+ attn_list = MiniMaxM1DecoderLayer._build_attn_type_list(10)
+ assert len(attn_list) == 10
+ assert attn_list[7] == 1
+ assert sum(attn_list) == 1
+
+ def test_single_layer_all_linear(self):
+ assert MiniMaxM1DecoderLayer._build_attn_type_list(1) == [0]
+
+ def test_all_linear_below_first_full_index(self):
+ assert all(t == 0 for t in MiniMaxM1DecoderLayer._build_attn_type_list(7))
+
+
+class TestBuildSlopeTensor:
+
+ def test_power_of_two_heads(self):
+ slopes = MiniMaxM1LinearAttention._build_slope_tensor(8)
+ assert slopes.shape == [8, 1, 1]
+ assert (slopes.flatten().numpy() > 0).all()
+
+ def test_non_power_of_two_heads(self):
+ slopes = MiniMaxM1LinearAttention._build_slope_tensor(12)
+ assert slopes.shape == [12, 1, 1]
+ assert (slopes.flatten().numpy() > 0).all()
+
+ def test_64_heads_first_slope(self):
+ slopes = MiniMaxM1LinearAttention._build_slope_tensor(64)
+ assert slopes.shape == [64, 1, 1]
+ expected_start = 2 ** (-(2 ** (-(math.log2(64) - 3))))
+ np.testing.assert_allclose(slopes.flatten().numpy()[0], expected_start, rtol=1e-5)
+
+ @pytest.mark.parametrize("n", [1, 2, 4, 8, 16, 32, 64])
+ def test_slopes_all_positive(self, n):
+ slopes = MiniMaxM1LinearAttention._build_slope_tensor(n)
+ assert (slopes.flatten().numpy() > 0).all()
+
+
+# ===================================================================
+# 2. Model registration
+# ===================================================================
+
+
+class TestModelRegistration:
+
+ def test_primary_architecture_registered(self):
+ assert "MiniMaxM1ForCausalLM" in ModelRegistry._arch_to_model_cls
+
+ def test_alias_architecture_registered(self):
+ assert "MiniMaxText01ForCausalLM" in ModelRegistry._arch_to_model_cls
+
+ def test_registered_class(self):
+ assert ModelRegistry._arch_to_model_cls["MiniMaxM1ForCausalLM"] is MiniMaxM1ForCausalLM
+
+ def test_name_method(self):
+ assert MiniMaxM1ForCausalLM.name() == "MiniMaxM1ForCausalLM"
+
+ def test_pretrained_name(self):
+ assert MiniMaxM1PretrainedModel.arch_name() == "MiniMaxM1ForCausalLM"
+ assert MiniMaxM1PretrainedModel.name() == "MiniMaxM1ForCausalLM"
+
+
+# ===================================================================
+# 3. Layer construction (lightweight fd_config mock)
+# ===================================================================
+
+
+def _make_fd_config(num_layers=4, attn_type_list=None, num_local_experts=4):
+ if attn_type_list is None:
+ attn_type_list = [0, 0, 0, 1][:num_layers]
+ mc = SimpleNamespace(
+ hidden_size=256,
+ intermediate_size=512,
+ num_hidden_layers=num_layers,
+ num_attention_heads=8,
+ num_key_value_heads=2,
+ head_dim=32,
+ vocab_size=1024,
+ rms_norm_eps=1e-6,
+ hidden_act="silu",
+ num_local_experts=num_local_experts,
+ num_experts_per_tok=2,
+ norm_topk_prob=False,
+ postnorm=False,
+ attn_type_list=attn_type_list,
+ layernorm_full_attention_alpha=3.556,
+ layernorm_full_attention_beta=1.0,
+ layernorm_linear_attention_alpha=3.556,
+ layernorm_linear_attention_beta=1.0,
+ layernorm_mlp_alpha=3.556,
+ layernorm_mlp_beta=1.0,
+ pretrained_config=SimpleNamespace(prefix_name="model"),
+ )
+ pc = SimpleNamespace(tensor_parallel_size=1, tp_group=None)
+ return SimpleNamespace(model_config=mc, parallel_config=pc)
+
+
+class TestDecoderLayerConstruction:
+ """Builds decoder layers against the stubbed fd_config and checks wiring."""
+
+ def test_linear_attention_layer(self):
+ """attn_type 0 → linear-attention sublayer with its expected attributes."""
+ fd = _make_fd_config()
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ assert layer.attention_type == 0
+ assert isinstance(layer.self_attn, MiniMaxM1LinearAttention)
+ assert hasattr(layer.self_attn, "slope_rate")
+ assert hasattr(layer.self_attn, "output_gate")
+ assert hasattr(layer.self_attn, "norm")
+ assert hasattr(layer.self_attn, "out_proj")
+
+ def test_full_attention_layer(self):
+ """attn_type 1 → standard full-attention sublayer."""
+ fd = _make_fd_config()
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=3, prefix="model.layers.3")
+ assert layer.attention_type == 1
+ assert isinstance(layer.self_attn, MiniMaxM1Attention)
+
+ def test_deepnorm_defaults(self):
+ """DeepNorm-style alpha values are carried from the config."""
+ fd = _make_fd_config()
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ assert layer.layernorm_attention_alpha == 3.556
+ assert layer.layernorm_mlp_alpha == 3.556
+
+ def test_moe_when_experts_gt_1(self):
+ fd = _make_fd_config(num_local_experts=4)
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ assert isinstance(layer.block_sparse_moe, MiniMaxM1MoE)
+
+ def test_dense_mlp_when_single_expert(self):
+ fd = _make_fd_config(num_local_experts=1)
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ assert isinstance(layer.block_sparse_moe, MiniMaxM1MLP)
+
+ def test_fallback_attn_type_when_no_config(self):
+ """Without attn_type_list the layer falls back to the built-in layout."""
+ fd = _make_fd_config(num_layers=80)
+ delattr(fd.model_config, "attn_type_list")
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=7, prefix="model.layers.7")
+ assert layer.attention_type == 1
+
+ def test_moe_default_weight_key_map(self):
+ """Unquantized config → weight_key_map has plain .weight keys."""
+ fd = _make_fd_config(num_local_experts=4)
+ # FusedMoE is a MagicMock (module stubbed above); call_args[1] = kwargs
+ FusedMoE = sys.modules["fastdeploy.model_executor.layers.moe.moe"].FusedMoE
+ FusedMoE.reset_mock()
+ MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe")
+ wkm = FusedMoE.call_args[1]["weight_key_map"]
+ assert "gate_weight_key" in wkm
+ assert wkm["up_gate_proj_expert_weight_key"].endswith(".up_gate_proj.weight")
+ assert "weight_scale" not in str(wkm)
+
+ def test_moe_w4a8_weight_key_map(self):
+ """w4a8 quant config → weight_key_map has .quant_weight + scales."""
+ fd = _make_fd_config(num_local_experts=4)
+ fd.quant_config = SimpleNamespace(moe_quant_type="w4a8")
+ fd.model_config.is_quantized = True
+ FusedMoE = sys.modules["fastdeploy.model_executor.layers.moe.moe"].FusedMoE
+ FusedMoE.reset_mock()
+ MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe")
+ wkm = FusedMoE.call_args[1]["weight_key_map"]
+ assert "quant_weight" in wkm["up_gate_proj_expert_weight_key"]
+ assert "weight_scale" in wkm["up_gate_proj_expert_weight_scale_key"]
+ assert "activation_scale" in wkm["up_gate_proj_expert_in_scale_key"]
+
+ def test_moe_w4afp8_dynamic_weight_key_map(self):
+ """Dynamic w4afp8 → quant_weight + weight_scale but no activation_scale."""
+ fd = _make_fd_config(num_local_experts=4)
+ fd.quant_config = SimpleNamespace(moe_quant_type="w4afp8", moe_dynamic_quant=True)
+ fd.model_config.is_quantized = True
+ FusedMoE = sys.modules["fastdeploy.model_executor.layers.moe.moe"].FusedMoE
+ FusedMoE.reset_mock()
+ MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe")
+ wkm = FusedMoE.call_args[1]["weight_key_map"]
+ assert "quant_weight" in wkm["up_gate_proj_expert_weight_key"]
+ assert "weight_scale" in wkm["up_gate_proj_expert_weight_scale_key"]
+ assert "in_scale_key" not in str(wkm)
+
+
+# ===================================================================
+# 4. Forward-pass smoke tests
+# ===================================================================
+
+
+class TestDecoderLayerForward:
+
+ @staticmethod
+ def _patch_layer(layer, hidden_size):
+ def _norm_fn(x, residual_input=None, forward_meta=None):
+ if residual_input is None:
+ residual_input = paddle.zeros_like(x)
+ return x, residual_input + x
+
+ object.__setattr__(layer, "input_layernorm", MagicMock(side_effect=_norm_fn))
+ object.__setattr__(layer, "post_attention_layernorm", MagicMock(side_effect=_norm_fn))
+ object.__setattr__(layer, "self_attn", MagicMock(return_value=paddle.randn([4, hidden_size])))
+ object.__setattr__(layer, "block_sparse_moe", MagicMock(return_value=paddle.randn([4, hidden_size])))
+
+ def test_linear_layer_returns_tuple(self):
+ fd = _make_fd_config()
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ self._patch_layer(layer, 256)
+
+ out = layer(forward_meta=SimpleNamespace(), hidden_states=paddle.randn([4, 256]))
+ assert isinstance(out, tuple) and len(out) == 2
+ assert out[0].shape == [4, 256]
+
+ def test_full_attn_layer_returns_tuple(self):
+ fd = _make_fd_config()
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=3, prefix="model.layers.3")
+ self._patch_layer(layer, 256)
+
+ out = layer(forward_meta=SimpleNamespace(), hidden_states=paddle.randn([4, 256]))
+ assert isinstance(out, tuple) and len(out) == 2
+
+ def test_deepnorm_scaling_applied(self):
+ fd = _make_fd_config()
+ fd.model_config.layernorm_linear_attention_alpha = 2.0
+ fd.model_config.layernorm_mlp_alpha = 3.0
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ assert layer.layernorm_attention_alpha == 2.0
+ assert layer.layernorm_mlp_alpha == 3.0
+
+ def test_postnorm_forward(self):
+ fd = _make_fd_config()
+ fd.model_config.postnorm = True
+ layer = MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ self._patch_layer(layer, 256)
+
+ out = layer(forward_meta=SimpleNamespace(), hidden_states=paddle.randn([4, 256]))
+ assert isinstance(out, tuple) and len(out) == 2
+ assert out[0].shape == [4, 256]