Skip to content
4 changes: 3 additions & 1 deletion docs/CN/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ lightllm 支持大多数的主流的开源大语言模型以及多模态模型
-
* - `Qwen3-Moe <https://github.com/QwenLM/Qwen3>`_
-
* - `GLM-5.2 <https://huggingface.co/zai-org/GLM-5.2>`_
- 支持 BF16/FP8 和 MTP。


多模态模型
Expand Down Expand Up @@ -93,4 +95,4 @@ Reward模型
* - `internLM-reward <https://huggingface.co/internlm/internlm2-1_8b-reward>`_
- :code:`--use_reward_model`
* - `Qwen2-Reward <https://huggingface.co/Qwen/Qwen2-Reward>`_
- :code:`--use_reward_model`
- :code:`--use_reward_model`
3 changes: 2 additions & 1 deletion docs/EN/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ Large Language Models
-
* - `DeepSeek-V3.2 `_
-
* - `GLM-5.2 <https://huggingface.co/zai-org/GLM-5.2>`_
- Supports BF16/FP8 and MTP.

Multimodal Models
^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -94,4 +96,3 @@ Reward Models
- :code:`--use_reward_model`
* - `Qwen2-Reward <https://huggingface.co/Qwen/Qwen2-Reward>`_
- :code:`--use_reward_model`

9 changes: 8 additions & 1 deletion lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import dataclasses
import torch
import torch.nn.functional as F
from typing import Tuple, TYPE_CHECKING

from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
Expand Down Expand Up @@ -86,14 +87,20 @@ def _nsa_prefill_att(
if topk_mem_indices.ndim == 2:
topk_mem_indices = topk_mem_indices.unsqueeze(1)

real_head_num = q.shape[1]
head_block_size = 64
pad_head_num = (-real_head_num) % head_block_size
if pad_head_num:
q = F.pad(q, (0, 0, 0, pad_head_num))

mla_out, _, _ = flash_mla_sparse_fwd(
q=q,
kv=kv,
indices=topk_mem_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
)
return mla_out
return mla_out[:, :real_head_num, :]


@dataclasses.dataclass
Expand Down
66 changes: 50 additions & 16 deletions lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import torch
import torch.nn.functional as F
from typing import TYPE_CHECKING, Tuple

from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState
Expand Down Expand Up @@ -70,10 +71,9 @@ def _nsa_prefill_att(
packed_kv: torch.Tensor,
att_control: AttControl,
) -> torch.Tensor:
import flash_mla
from sgl_kernel.flash_mla import flash_mla_sparse_fwd

nsa_dict = att_control.nsa_prefill_dict
topk_indices = nsa_dict["topk_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]
topk_mem_indices = nsa_dict["topk_mem_indices"]
Expand All @@ -91,18 +91,25 @@ def _nsa_prefill_att(
)
else:
kv = prefill_cache_kv
topk_indices = nsa_dict["topk_indices"]

if topk_indices.ndim == 2:
topk_indices = topk_indices.unsqueeze(1)

mla_out, _, _ = flash_mla.flash_mla_sparse_fwd(
real_head_num = q.shape[1]
head_block_size = 64
pad_head_num = (-real_head_num) % head_block_size
if pad_head_num:
q = F.pad(q, (0, 0, 0, pad_head_num))

mla_out, _, _ = flash_mla_sparse_fwd(
q=q,
kv=kv,
indices=topk_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
)
return mla_out
return mla_out[:, :real_head_num, :]


@dataclasses.dataclass
Expand All @@ -111,7 +118,6 @@ class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
ke: torch.Tensor = None
lengths: torch.Tensor = None
ragged_mem_index: torch.Tensor = None
flashmla_sched_meta: object = None

def init_state(self):
self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
Expand Down Expand Up @@ -141,9 +147,6 @@ def init_state(self):
ragged_mem_index=self.ragged_mem_index,
hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
)
import flash_mla

self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata()
return

def decode_att(
Expand All @@ -164,7 +167,7 @@ def _nsa_decode_att(
packed_kv: torch.Tensor,
att_control: AttControl,
) -> torch.Tensor:
import flash_mla
from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata

nsa_dict = att_control.nsa_decode_dict
topk_mem_indices = nsa_dict["topk_mem_indices"]
Expand All @@ -177,22 +180,53 @@ def _nsa_decode_att(

q_nope, q_rope = q
q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous()

real_head_num = q_all.shape[2]
if real_head_num <= 64:
padded_head_num = 64
elif real_head_num <= 128:
padded_head_num = 128
else:
padded_head_num = real_head_num
if padded_head_num != real_head_num:
q_all = F.pad(q_all, (0, 0, 0, padded_head_num - real_head_num))

cache_seqlens = self.infer_state.b_seq_len.to(dtype=torch.int32)
page_block_size = 64
num_heads_k = 1
num_heads_q = q_all.shape[2]
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=cache_seqlens,
num_q_tokens_per_head_k=num_heads_q // num_heads_k,
num_heads_k=num_heads_k,
num_heads_q=num_heads_q,
is_fp8_kvcache=True,
topk=topk_mem_indices.shape[-1],
)
kv = torch.as_strided(
packed_kv,
size=(packed_kv.shape[0], 1, 1, packed_kv.shape[-1]),
stride=(packed_kv.stride(0), packed_kv.shape[-1], packed_kv.shape[-1], packed_kv.stride(-1)),
size=(packed_kv.shape[0] // page_block_size, page_block_size, 1, packed_kv.shape[-1]),
stride=(
packed_kv.stride(0) * page_block_size,
packed_kv.stride(0),
packed_kv.shape[-1],
packed_kv.stride(-1),
),
)
block_table = torch.empty((cache_seqlens.shape[0], 0), dtype=torch.int32, device=q_all.device)

o_tensor, _ = flash_mla.flash_mla_with_kvcache(
o_tensor, _ = flash_mla_with_kvcache(
q=q_all,
k_cache=kv,
block_table=None,
cache_seqlens=None,
block_table=block_table,
cache_seqlens=cache_seqlens,
head_dim_v=kv_lora_rank,
tile_scheduler_metadata=self.flashmla_sched_meta,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=softmax_scale,
causal=False,
is_fp8_kvcache=True,
indices=topk_mem_indices,
indices=topk_mem_indices.to(dtype=torch.int32),
)
Comment on lines +218 to 230

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In _nsa_decode_att, topk_mem_indices is passed directly to flash_mla_with_kvcache as indices=topk_mem_indices.to(dtype=torch.int32).

However, topk_mem_indices is a 2D tensor of shape [batch_size, topk], whereas flash_mla_with_kvcache expects indices to be a 3D tensor of shape [batch_size, q_seqlen, topk] (where q_seqlen = 1 in decode). Passing a 2D tensor can cause shape mismatch errors or silent correctness issues in the kernel.

We should unsqueeze indices to 3D if it is 2D, similar to how it is handled in _nsa_prefill_att.

Suggested change
o_tensor, _ = flash_mla_with_kvcache(
q=q_all,
k_cache=kv,
block_table=None,
cache_seqlens=None,
block_table=block_table,
cache_seqlens=cache_seqlens,
head_dim_v=kv_lora_rank,
tile_scheduler_metadata=self.flashmla_sched_meta,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=softmax_scale,
causal=False,
is_fp8_kvcache=True,
indices=topk_mem_indices,
indices=topk_mem_indices.to(dtype=torch.int32),
)
indices = topk_mem_indices.to(dtype=torch.int32)
if indices.ndim == 2:
indices = indices.unsqueeze(1)
o_tensor, _ = flash_mla_with_kvcache(
q=q_all,
k_cache=kv,
block_table=block_table,
cache_seqlens=cache_seqlens,
head_dim_v=kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=softmax_scale,
causal=False,
is_fp8_kvcache=True,
indices=indices,
)

o_tensor = o_tensor[:, :, :real_head_num, :]
return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d]
7 changes: 1 addition & 6 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,12 +1171,7 @@ def _init_padded_req(self):
def _gen_special_model_input(self, token_num: int):
special_model_input = {}

is_mtp_draft_model = (
"Deepseek3MTPModel" in str(self.__class__)
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
)
is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False)
if is_mtp_draft_model:
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .base_weight import BaseWeightTpl
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size
from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward
from lightllm.common.basemodel.triton_kernel.norm.fused_add_rmsnorm import fused_add_rmsnorm_forward
from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward
from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward
from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward
Expand Down Expand Up @@ -71,6 +72,21 @@ def __call__(
) -> torch.Tensor:
return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func)

def fused_add_forward(
self,
residual: torch.Tensor,
x: torch.Tensor,
eps: float,
out: Optional[torch.Tensor] = None,
alloc_func=torch.empty,
) -> torch.Tensor:
"""Fused residual-add + RMSNorm: ``residual <- residual + x`` (in place) and return
``rmsnorm(residual) * weight``. Bit-identical to a plain ``residual.add_(x)`` followed
by ``__call__`` but in a single Triton launch. CUDA/MUSA (Triton) only."""
if out is None:
out = alloc_func(residual.shape, dtype=residual.dtype, device=residual.device)
return fused_add_rmsnorm_forward(residual=residual, x=x, weight=self.weight, eps=eps, out=out)


class GatedRMSNormWeight(RMSNormWeight):
def _triton_forward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import triton
import triton.language as tl
Expand All @@ -26,6 +27,7 @@
from lightllm.utils.vllm_utils import vllm_ops
from lightllm.utils.device_utils import triton_support_tensor_descriptor
from .moe_silu_and_mul import silu_and_mul_fwd
from .moe_silu_and_mul_group_quant import silu_and_mul_group_quant_fwd
from .moe_sum_reduce import moe_sum_reduce
from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8
from lightllm.utils.torch_ops_utils import direct_register_custom_op
Expand All @@ -35,6 +37,11 @@

logger = init_logger(__name__)

# Fuse the down-projection's per-token-group fp8 quant into the silu_and_mul that produces its
# input (eliminates a separate per_token_group_quant launch per MoE layer). Numerically matches
# the unfused path; gate exists for A/B and rollback.
_ENABLE_FUSED_SILU_QUANT = os.environ.get("LIGHTLLM_FUSED_SILU_QUANT", "1") == "1"


@triton.jit
def moe_align_kernel(
Expand Down Expand Up @@ -762,8 +769,12 @@ def grouped_matmul(
assert BLOCK_SIZE_K == triton.next_power_of_2(BLOCK_SIZE_K)

if use_fp8_w8a8:
if token_inputs.dtype == expert_weights.dtype:
# token_inputs were already fp8-quantized by a fused producer (e.g. the fused
# silu+mul+group-quant on the down projection); token_input_scale is its scale.
assert token_input_scale is not None, "pre-quantized fp8 input requires its scale"
# 当权重使用 block wise 量化时,激活也使用 per token, group size 量化
if block_size_k == 0:
elif block_size_k == 0:
# input 使用 per token 量化
token_inputs, token_input_scale = vllm_ops.scaled_fp8_quant(
token_inputs, token_input_scale, use_per_token_if_dynamic=True
Expand Down Expand Up @@ -984,18 +995,44 @@ def fused_experts_impl(
bias=w1_bias,
)

silu_and_mul_fwd(
intermediate_cache1.view(-1, N),
intermediate_cache2.view(-1, N // 2),
limit=limit,
alpha=alpha,
layout=layout,
)
# Fuse the down-projection's per-token-group fp8 quant into silu_and_mul when the down
# weight is block-wise fp8 quantized: the silu output is emitted directly as fp8 + scales,
# so grouped_matmul below skips its internal per_token_group_quant launch.
use_fused_silu_quant = _ENABLE_FUSED_SILU_QUANT and use_fp8_w8a8 and w2_scale is not None and w2_scale.ndim == 3
if use_fused_silu_quant:
down_token_num = curr_topk_ids.numel()
down_k = N // 2
down_group_size = down_k // w2_scale.shape[2]
down_inputs = alloc_tensor_func(
(down_token_num, down_k), device=hidden_states.device, dtype=torch.float8_e4m3fn
)
down_input_scale = alloc_tensor_func(
(down_token_num, down_k // down_group_size), device=hidden_states.device, dtype=torch.float32
)
silu_and_mul_group_quant_fwd(
intermediate_cache1.view(-1, N),
down_inputs,
down_input_scale,
down_group_size,
layout=layout,
limit=limit,
alpha=alpha,
)
else:
silu_and_mul_fwd(
intermediate_cache1.view(-1, N),
intermediate_cache2.view(-1, N // 2),
limit=limit,
alpha=alpha,
layout=layout,
)
down_inputs = intermediate_cache2.view(-1, N // 2)
down_input_scale = a2_scale

grouped_matmul(
curr_topk_ids.numel(),
intermediate_cache2.view(-1, N // 2),
a2_scale,
down_inputs,
down_input_scale,
expert_to_token_num,
expert_to_tokens,
expert_to_weights=expert_to_weights,
Expand Down
Loading
Loading