From 703de087d2f7e06f90616163610a2f125956b686 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 19 Jun 2026 00:13:53 +0800 Subject: [PATCH 1/8] support glm52 --- .../attention/nsa/flashmla_sparse.py | 9 +- .../attention/nsa/fp8_flashmla_sparse.py | 64 ++- lightllm/common/basemodel/basemodel.py | 2 +- ...oken_group_quant_deepseek3_2mem_manager.py | 40 +- ...=32,dtype=torch.bfloat16}_NVIDIA_H200.json | 98 ++++ ...=64,dtype=torch.bfloat16}_NVIDIA_H200.json | 98 ++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 98 ++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 518 ++++++++++++++++++ lightllm/models/__init__.py | 1 + .../layer_infer/transformer_layer_infer.py | 23 +- .../layer_weights/transformer_layer_weight.py | 5 +- lightllm/models/glm5_2/__init__.py | 3 + lightllm/models/glm5_2/indexshare.py | 12 + .../models/glm5_2/layer_infer/__init__.py | 3 + .../layer_infer/transformer_layer_infer.py | 101 ++++ .../models/glm5_2/layer_weights/__init__.py | 3 + .../layer_weights/transformer_layer_weight.py | 46 ++ lightllm/models/glm5_2/model.py | 54 ++ lightllm/models/glm5_2_mtp/__init__.py | 3 + lightllm/models/glm5_2_mtp/model.py | 98 ++++ lightllm/server/build_prompt.py | 34 +- .../model_infer/mode_backend/base_backend.py | 6 +- lightllm/server/tokenizer.py | 41 ++ lightllm/utils/config_utils.py | 8 + 24 files changed, 1312 insertions(+), 56 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=32,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=64,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=12288,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2048,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/models/glm5_2/__init__.py create mode 100644 lightllm/models/glm5_2/indexshare.py create mode 100644 lightllm/models/glm5_2/layer_infer/__init__.py create mode 100644 lightllm/models/glm5_2/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/glm5_2/layer_weights/__init__.py create mode 100644 lightllm/models/glm5_2/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/glm5_2/model.py create mode 100644 lightllm/models/glm5_2_mtp/__init__.py create mode 100644 lightllm/models/glm5_2_mtp/model.py diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index c3456f4b7..49fb1c2dd 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -3,6 +3,7 @@ import dataclasses import torch +import torch.nn.functional as F from typing import Tuple, TYPE_CHECKING from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl @@ -86,6 +87,12 @@ def _nsa_prefill_att( if topk_mem_indices.ndim == 2: topk_mem_indices = topk_mem_indices.unsqueeze(1) + real_head_num = q.shape[1] + head_block_size = 64 + pad_head_num = (-real_head_num) % head_block_size + if pad_head_num: + q = F.pad(q, (0, 0, 0, pad_head_num)) + mla_out, _, _ = flash_mla_sparse_fwd( q=q, kv=kv, @@ -93,7 +100,7 @@ def _nsa_prefill_att( sm_scale=softmax_scale, d_v=kv_lora_rank, ) - return mla_out + return mla_out[:, :real_head_num, :] @dataclasses.dataclass diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py index 539ade769..1ce896051 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -1,5 +1,6 @@ import dataclasses import torch +import torch.nn.functional as F from typing import TYPE_CHECKING, Tuple from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState @@ -70,10 +71,9 @@ def _nsa_prefill_att( packed_kv: torch.Tensor, att_control: AttControl, ) -> torch.Tensor: - import flash_mla + from sgl_kernel.flash_mla import flash_mla_sparse_fwd nsa_dict = att_control.nsa_prefill_dict - topk_indices = nsa_dict["topk_indices"] softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] topk_mem_indices = nsa_dict["topk_mem_indices"] @@ -91,18 +91,25 @@ def _nsa_prefill_att( ) else: kv = prefill_cache_kv + topk_indices = topk_mem_indices if topk_indices.ndim == 2: topk_indices = topk_indices.unsqueeze(1) - mla_out, _, _ = flash_mla.flash_mla_sparse_fwd( + real_head_num = q.shape[1] + head_block_size = 64 + pad_head_num = (-real_head_num) % head_block_size + if pad_head_num: + q = F.pad(q, (0, 0, 0, pad_head_num)) + + mla_out, _, _ = flash_mla_sparse_fwd( q=q, kv=kv, indices=topk_indices, sm_scale=softmax_scale, d_v=kv_lora_rank, ) - return mla_out + return mla_out[:, :real_head_num, :] @dataclasses.dataclass @@ -141,9 +148,6 @@ def init_state(self): ragged_mem_index=self.ragged_mem_index, hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, ) - import flash_mla - - self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() return def decode_att( @@ -164,35 +168,65 @@ def _nsa_decode_att( packed_kv: torch.Tensor, att_control: AttControl, ) -> torch.Tensor: - import flash_mla + from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata nsa_dict = att_control.nsa_decode_dict topk_mem_indices = nsa_dict["topk_mem_indices"] + topk_indices = nsa_dict.get("topk_indices", topk_mem_indices) softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] if topk_mem_indices.ndim == 2: topk_mem_indices = topk_mem_indices.unsqueeze(1) + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) assert topk_mem_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" + assert topk_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" q_nope, q_rope = q q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous() + cache_seqlens = self.infer_state.b_seq_len.to(dtype=torch.int32) + page_block_size = 64 + max_seq_len = int(cache_seqlens.max().item()) + max_block_num = (max_seq_len + page_block_size - 1) // page_block_size + block_table = ( + self.infer_state.req_manager.req_to_token_indexs[ + self.infer_state.b_req_idx, : max_block_num * page_block_size : page_block_size + ] + // page_block_size + ).to(dtype=torch.int32) + num_heads_k = 1 + num_heads_q = q_all.shape[2] + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens=cache_seqlens, + num_q_tokens_per_head_k=num_heads_q // num_heads_k, + num_heads_k=num_heads_k, + num_heads_q=num_heads_q, + is_fp8_kvcache=True, + topk=topk_mem_indices.shape[-1], + ) kv = torch.as_strided( packed_kv, - size=(packed_kv.shape[0], 1, 1, packed_kv.shape[-1]), - stride=(packed_kv.stride(0), packed_kv.shape[-1], packed_kv.shape[-1], packed_kv.stride(-1)), + size=(packed_kv.shape[0] // page_block_size, page_block_size, 1, packed_kv.shape[-1]), + stride=( + packed_kv.stride(0) * page_block_size, + packed_kv.stride(0), + packed_kv.shape[-1], + packed_kv.stride(-1), + ), ) - o_tensor, _ = flash_mla.flash_mla_with_kvcache( + o_tensor, _ = flash_mla_with_kvcache( q=q_all, k_cache=kv, - block_table=None, - cache_seqlens=None, + block_table=block_table, + cache_seqlens=cache_seqlens, head_dim_v=kv_lora_rank, - tile_scheduler_metadata=self.flashmla_sched_meta, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, softmax_scale=softmax_scale, causal=False, is_fp8_kvcache=True, - indices=topk_mem_indices, + indices=topk_indices.to(dtype=torch.int32), ) return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d] diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a..3c500fa7c 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1171,7 +1171,7 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} - is_mtp_draft_model = ( + is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False) or ( "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(self.__class__) or "MistralMTPModel" in str(self.__class__) diff --git a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py index b72587545..fd7234805 100644 --- a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py @@ -27,24 +27,42 @@ class FP8PerTokenGroupQuantDeepseek3_2MemoryManager(Deepseek2MemoryManager): # 132 bytes = 128 + 4 indexer_bytes_per_token = indexer_head_dim + 4 - # 16-byte 对齐,满足FlashMLA的对齐要求 - alignment = 16 - total_bytes_per_token = ( - (flashmla_bytes_per_token + indexer_bytes_per_token + alignment - 1) // alignment * alignment - ) - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): assert head_num == 1, "DeepSeek-V3.2 DSA FP8 path expects MQA-style head_num == 1" self.prefill_dtype = dtype - super().__init__(size, torch.uint8, head_num, self.total_bytes_per_token, layer_num, always_copy, mem_fraction) + super().__init__(size, torch.uint8, head_num, self.flashmla_bytes_per_token, layer_num, always_copy, mem_fraction) + + def get_cell_size(self): + return self.layer_num * (self.flashmla_bytes_per_token + self.indexer_bytes_per_token) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + self.kv_buffer = torch.empty( + (layer_num, size + 1, head_num, self.flashmla_bytes_per_token), + dtype=dtype, + device="cuda", + ) + self.indexer_k_buffer = torch.empty( + (layer_num, size + 1, head_num, self.indexer_bytes_per_token), + dtype=dtype, + device="cuda", + ) def get_att_input_params(self, layer_index: int) -> Any: - return self.kv_buffer[layer_index][:, :, : self.flashmla_bytes_per_token] + return self.kv_buffer[layer_index] def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: - begin = self.flashmla_bytes_per_token - end = begin + self.indexer_bytes_per_token - return self.kv_buffer[layer_index][:, :, begin:end] + return self.indexer_k_buffer[layer_index] + + def _free_buffers(self): + self.kv_buffer = None + self.indexer_k_buffer = None + + def get_index_kv_buffer(self, index): + return {"kv_buffer": self.kv_buffer[:, index], "indexer_k_buffer": self.indexer_k_buffer[:, index]} + + def load_index_kv_buffer(self, index, load_tensor_dict): + self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + self.indexer_k_buffer[:, index].copy_(load_tensor_dict["indexer_k_buffer"]) def get_prefill_kv_cache_and_remap_indices( self, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=32,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=32,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..eaef18b7e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=32,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,98 @@ +{ + "1": { + "BLOCK_SEQ": 16, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 4, + "num_warps": 2 + }, + "100": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 1, + "num_warps": 2 + }, + "1024": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 4, + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 2, + "num_warps": 2 + }, + "16": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 3, + "num_warps": 1 + }, + "2048": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 2, + "num_warps": 1 + }, + "256": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 2, + "num_warps": 1 + }, + "32": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 2, + "num_warps": 2 + }, + "4096": { + "BLOCK_SEQ": 4, + "HEAD_PARALLEL_NUM": 4, + "num_stages": 1, + "num_warps": 2 + }, + "5120": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 4, + "num_warps": 1 + }, + "6144": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 2, + "num_warps": 1 + }, + "64": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 4, + "num_warps": 1 + }, + "6400": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 2, + "num_warps": 1 + }, + "7168": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 4, + "num_warps": 1 + }, + "8192": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=64,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=64,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..4d1b132f7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=64,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,98 @@ +{ + "1": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 1, + "num_warps": 2 + }, + "1024": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 4, + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 1 + }, + "256": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 2, + "num_warps": 2 + }, + "32": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 1, + "num_warps": 2 + }, + "5120": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 4, + "num_stages": 2, + "num_warps": 1 + }, + "6144": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 1 + }, + "64": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 1, + "num_warps": 2 + }, + "6400": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 4, + "num_warps": 1 + }, + "7168": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 4, + "num_stages": 2, + "num_warps": 1 + }, + "8": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 1, + "num_warps": 8 + }, + "8192": { + "BLOCK_SEQ": 4, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=12288,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=12288,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..6c3a7ac83 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=12288,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,98 @@ +{ + "1": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_M": 64, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "5120": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "6144": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "6400": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "7168": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2048,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2048,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..e1f78b692 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2048,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,518 @@ +{ + "1": { + "BLOCK_M": 32, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "10368": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "10496": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1152": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "1408": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1536": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1664": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1792": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "17920": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "18048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "18176": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "18304": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "18688": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "18816": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1920": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2176": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "2688": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2816": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2944": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "3072": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "33920": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "3456": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "34560": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "34816": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "34944": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "35200": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "35328": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "3584": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "3712": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "3840": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "3968": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "42112": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "4224": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "42624": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "43008": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "43264": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "43776": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "49664": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "50944": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "5120": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "51200": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "51328": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "51456": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "52352": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "52480": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "52992": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "53120": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "53248": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "53632": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "53888": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "58368": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "58880": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "59008": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "59136": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "59520": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "59776": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "59904": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "60032": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "6144": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "640": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "6400": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "66176": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "66432": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "67200": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "67328": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "67456": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "67840": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "68096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "68480": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "7168": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "9472": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "9728": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "9984": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index f619b1d88..56cd13bb3 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -21,6 +21,7 @@ from lightllm.models.deepseek2.model import Deepseek2TpPartModel from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel +from lightllm.models.glm5_2.model import Glm5_2TpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, InternVLPhi3TpPartModel, diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index d6eaebe2f..684478cfd 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -120,7 +120,7 @@ def _token_attention_kernel( # 计算 topk mem indices att_state = infer_state.decode_att_state - topk_mem_indices, _ = self.indexer._get_indices( + topk_mem_indices, topk_indices = self.indexer._get_indices( hidden_states=infer_state.get_topk_indices_params["hidden_states"], q_lora=infer_state.get_topk_indices_params["q_lora"], infer_state=infer_state, @@ -135,6 +135,7 @@ def _token_attention_kernel( nsa_decode_dict={ "layer_index": self.layer_num_, "topk_mem_indices": topk_mem_indices, + "topk_indices": topk_indices, "softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, "qk_rope_head_dim": self.qk_rope_head_dim, @@ -160,13 +161,15 @@ def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int): self.qk_rope_head_dim = network_config["qk_rope_head_dim"] self.index_head_dim = network_config["index_head_dim"] self.eps = network_config["rms_norm_eps"] - self.block_size = network_config["quantization_config"]["weight_block_size"][0] - self.scale_fmt = network_config["quantization_config"]["scale_fmt"] + quantization_config = network_config.get("quantization_config") or {} + self.block_size = quantization_config.get("weight_block_size", [128, 128])[0] + self.scale_fmt = quantization_config.get("scale_fmt", "ue8m0") self.softmax_scale = (self.index_head_dim) ** (-0.5) self.index_n_heads = network_config["index_n_heads"] - self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale + self.index_n_heads_scale = (self.index_n_heads**-0.5) * self.softmax_scale self.tp_world_size_ = tp_world_size self.tp_index_n_heads = self.index_n_heads // self.tp_world_size_ + self.indexer_rope_interleave = network_config.get("indexer_rope_interleave", False) def _get_indices( self, @@ -176,12 +179,11 @@ def _get_indices( att_state: Any, layer_weight: Deepseek3_2TransformerLayerWeight, ): - q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) if self.tp_world_size_ > 1: q_merge = torch.empty( - size=(self.tp_world_size_ * q.numel()), + size=(self.tp_world_size_ * q.numel(),), dtype=q.dtype, device=q.device, ) @@ -262,7 +264,7 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: hidden_size = x.size(-1) assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform." - return hadamard_transform(x, scale=hidden_size ** -0.5) + return hadamard_transform(x, scale=hidden_size**-0.5) def _get_q_k_bf16( self, @@ -276,8 +278,11 @@ def _get_q_k_bf16( k = layer_weight.k_norm_(k, eps=self.eps) - # 为什么 indexer 和主模型用的q k 的 rotary的排布方式不一样,这不是脱裤子放屁麻。 - from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd + if self.indexer_rope_interleave: + from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd + else: + # DeepSeek-V3.2 indexer RoPE uses the non-interleaved layout. + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd rotary_emb_fwd( q[:, :, : self.qk_rope_head_dim], diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py index c7c71c827..86f67f6aa 100644 --- a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -19,7 +19,6 @@ def _init_weight(self): self._init_indexer_weight() def _init_indexer_weight(self): - prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" assert self.index_n_heads % self.tp_world_size_ == 0 @@ -28,7 +27,7 @@ def _init_indexer_weight(self): out_dims=[self.index_n_heads * self.index_head_dim], weight_names=f"{prefix}.wq_b.weight", data_type=self.data_type_, - quant_method=None, + quant_method=self.get_quant_method("indexer_wq_b"), tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_, ) @@ -37,7 +36,7 @@ def _init_indexer_weight(self): out_dims=[self.index_head_dim], weight_names=f"{prefix}.wk.weight", data_type=self.data_type_, - quant_method=None, + quant_method=self.get_quant_method("indexer_wk"), tp_rank=0, tp_world_size=1, ) diff --git a/lightllm/models/glm5_2/__init__.py b/lightllm/models/glm5_2/__init__.py new file mode 100644 index 000000000..bf9546188 --- /dev/null +++ b/lightllm/models/glm5_2/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm5_2.model import Glm5_2TpPartModel + +__all__ = ["Glm5_2TpPartModel"] diff --git a/lightllm/models/glm5_2/indexshare.py b/lightllm/models/glm5_2/indexshare.py new file mode 100644 index 000000000..617e1f650 --- /dev/null +++ b/lightllm/models/glm5_2/indexshare.py @@ -0,0 +1,12 @@ +def owns_indexer_layer(layer_id: int, config: dict) -> bool: + num_hidden_layers = config.get("num_hidden_layers") + if num_hidden_layers is not None and layer_id >= num_hidden_layers: + return True + + pattern = config.get("index_topk_pattern") + if pattern is not None and 0 <= layer_id < len(pattern): + return pattern[layer_id] != "S" + + freq = config.get("index_topk_freq", 1) + offset = config.get("index_skip_topk_offset", 2) + return max(layer_id - offset + 1, 0) % freq == 0 diff --git a/lightllm/models/glm5_2/layer_infer/__init__.py b/lightllm/models/glm5_2/layer_infer/__init__.py new file mode 100644 index 000000000..1fa3327f6 --- /dev/null +++ b/lightllm/models/glm5_2/layer_infer/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm5_2.layer_infer.transformer_layer_infer import Glm5_2TransformerLayerInfer + +__all__ = ["Glm5_2TransformerLayerInfer"] diff --git a/lightllm/models/glm5_2/layer_infer/transformer_layer_infer.py b/lightllm/models/glm5_2/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..00cbad6c1 --- /dev/null +++ b/lightllm/models/glm5_2/layer_infer/transformer_layer_infer.py @@ -0,0 +1,101 @@ +import torch + +from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer +from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer, NsaInfer +from lightllm.models.glm5_2.indexshare import owns_indexer_layer + + +class Glm5_2TransformerLayerInfer(Deepseek3_2TransformerLayerInfer): + def __init__(self, layer_num, network_config): + self.has_indexer = owns_indexer_layer(layer_num, network_config) + Deepseek2TransformerLayerInfer.__init__(self, layer_num, network_config) + self.indexer = ( + NsaInfer(layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_) + if self.has_indexer + else None + ) + + def _get_or_reuse_topk_indices(self, infer_state, att_state, layer_weight): + if getattr(infer_state, "glm5_2_reuse_mtp_topk_indices", False): + model_input = getattr(infer_state, "glm5_2_model_input", None) + cached_topk = getattr(model_input, "glm5_2_mtp_topk_cache", None) + if cached_topk is not None: + infer_state.glm5_2_indexshare_topk_cache = cached_topk + return cached_topk + + if self.indexer is not None: + topk_mem_indices, topk_indices = self.indexer._get_indices( + hidden_states=infer_state.get_topk_indices_params["hidden_states"], + q_lora=infer_state.get_topk_indices_params["q_lora"], + infer_state=infer_state, + att_state=att_state, + layer_weight=layer_weight, + ) + infer_state.glm5_2_indexshare_topk_cache = (topk_mem_indices, topk_indices) + if getattr(infer_state, "glm5_2_reuse_mtp_topk_indices", False): + model_input = getattr(infer_state, "glm5_2_model_input", None) + if model_input is not None: + model_input.glm5_2_mtp_topk_cache = infer_state.glm5_2_indexshare_topk_cache + return topk_mem_indices, topk_indices + + if not hasattr(infer_state, "glm5_2_indexshare_topk_cache"): + raise RuntimeError( + f"GLM-5.2 layer {self.layer_num_} needs cached IndexShare top-k indices, " + "but no previous indexer layer has produced them." + ) + return infer_state.glm5_2_indexshare_topk_cache + + def _context_attention_kernel(self, q, kv, infer_state, layer_weight, out=None): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + q_all = q_nope if self.qk_rope_head_dim == 0 else torch.cat([q_nope, q_rope], dim=-1) + + att_state = infer_state.prefill_att_state + topk_mem_indices, topk_indices = self._get_or_reuse_topk_indices(infer_state, att_state, layer_weight) + del infer_state.get_topk_indices_params + + att_control = AttControl( + nsa_prefill=True, + nsa_prefill_dict={ + "topk_mem_indices": topk_mem_indices, + "topk_indices": topk_indices, + "prefill_cache_kv": kv, + "softmax_scale": self.softmax_scale, + "kv_lora_rank": self.kv_lora_rank, + }, + ) + + return infer_state.prefill_att_state.prefill_att( + q=q_all, + k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_), + v=None, + att_control=att_control, + ) + + def _token_attention_kernel(self, q, infer_state, layer_weight, out=None): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + + att_state = infer_state.decode_att_state + topk_mem_indices, topk_indices = self._get_or_reuse_topk_indices(infer_state, att_state, layer_weight) + del infer_state.get_topk_indices_params + + att_control = AttControl( + nsa_decode=True, + nsa_decode_dict={ + "layer_index": self.layer_num_, + "topk_mem_indices": topk_mem_indices, + "topk_indices": topk_indices, + "softmax_scale": self.softmax_scale, + "kv_lora_rank": self.kv_lora_rank, + "qk_rope_head_dim": self.qk_rope_head_dim, + }, + ) + + return infer_state.decode_att_state.decode_att( + q=(q_nope, q_rope), + k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_), + v=None, + att_control=att_control, + ) diff --git a/lightllm/models/glm5_2/layer_weights/__init__.py b/lightllm/models/glm5_2/layer_weights/__init__.py new file mode 100644 index 000000000..01013d565 --- /dev/null +++ b/lightllm/models/glm5_2/layer_weights/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm5_2.layer_weights.transformer_layer_weight import Glm5_2TransformerLayerWeight + +__all__ = ["Glm5_2TransformerLayerWeight"] diff --git a/lightllm/models/glm5_2/layer_weights/transformer_layer_weight.py b/lightllm/models/glm5_2/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..066c551b1 --- /dev/null +++ b/lightllm/models/glm5_2/layer_weights/transformer_layer_weight.py @@ -0,0 +1,46 @@ +import torch + +from lightllm.common.basemodel.layer_weights.meta_weights import FusedMoeWeight, ROWMMWeight +from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.glm5_2.indexshare import owns_indexer_layer + + +class Glm5_2TransformerLayerWeight(Deepseek3_2TransformerLayerWeight): + def _parse_config(self): + super()._parse_config() + self.has_indexer = owns_indexer_layer(self.layer_num_, self.network_config_) + + def _init_weight(self): + Deepseek2TransformerLayerWeight._init_weight(self) + if self.has_indexer: + self._init_indexer_weight() + + def _init_moe(self): + if self.num_fused_shared_experts == 0: + self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) + + self.moe_gate = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_routed_experts], + weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=torch.float32, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name=self.e_score_correction_bias_name, + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.n_embed, + moe_intermediate_size=self.network_config_["moe_intermediate_size"], + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + num_fused_shared_experts=self.num_fused_shared_experts, + layer_num=self.layer_num_, + network_config=self.network_config_, + ) diff --git a/lightllm/models/glm5_2/model.py b/lightllm/models/glm5_2/model.py new file mode 100644 index 000000000..6ac26c5e3 --- /dev/null +++ b/lightllm/models/glm5_2/model.py @@ -0,0 +1,54 @@ +import torch + +from lightllm.distributed.communication_op import dist_group_manager +from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel +from lightllm.models.glm5_2.layer_infer.transformer_layer_infer import Glm5_2TransformerLayerInfer +from lightllm.models.glm5_2.layer_weights.transformer_layer_weight import Glm5_2TransformerLayerWeight +from lightllm.models.registry import ModelRegistry + + +@ModelRegistry("glm_moe_dsa") +class Glm5_2TpPartModel(Deepseek3_2TpPartModel): + transformer_weight_class = Glm5_2TransformerLayerWeight + transformer_layer_infer_class = Glm5_2TransformerLayerInfer + + def _init_config(self): + super()._init_config() + if "scoring_func" not in self.config: + self.config["scoring_func"] = "sigmoid" + if self.config.get("rope_theta") is None: + self.config["rope_theta"] = self.config.get("rope_parameters", {}).get("rope_theta", 1000000.0) + + def _init_custom(self): + self._init_glm5_2_rotary() + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) + + def _create_inferstate(self, model_input, microbatch_index: int = 0): + infer_state = super()._create_inferstate(model_input, microbatch_index=microbatch_index) + infer_state.glm5_2_model_input = model_input + infer_state.glm5_2_reuse_mtp_topk_indices = ( + getattr(self, "is_mtp_draft_model", False) + and self.config.get("index_share_for_mtp_iteration", False) + and not model_input.is_prefill + ) + return infer_state + + def _init_glm5_2_rotary(self): + rope_theta = self.config.get("rope_theta", 8000000.0) + qk_rope_head_dim = self.config.get("qk_rope_head_dim", 64) + max_position_embeddings = self.config.get("max_position_embeddings", 1048576) + + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, qk_rope_head_dim, 2, device="cpu", dtype=torch.float32) / qk_rope_head_dim) + ) + max_seq_len = max(max_position_embeddings, self.max_seq_length) + t = torch.arange(max_seq_len, device="cpu", dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() diff --git a/lightllm/models/glm5_2_mtp/__init__.py b/lightllm/models/glm5_2_mtp/__init__.py new file mode 100644 index 000000000..8798b32f3 --- /dev/null +++ b/lightllm/models/glm5_2_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm5_2_mtp.model import Glm5_2MTPModel + +__all__ = ["Glm5_2MTPModel"] diff --git a/lightllm/models/glm5_2_mtp/model.py b/lightllm/models/glm5_2_mtp/model.py new file mode 100644 index 000000000..e1ea4a383 --- /dev/null +++ b/lightllm/models/glm5_2_mtp/model.py @@ -0,0 +1,98 @@ +from typing import List + +from lightllm.common.basemodel import TpPartBaseModel +from lightllm.common.basemodel.basemodel import load_hf_weights +from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer +from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( + Glm4MoeLiteMTPPreAndPostLayerWeight, +) +from lightllm.models.glm5_2.model import Glm5_2TpPartModel + + +class Glm5_2MTPModel(Glm5_2TpPartModel): + is_mtp_draft_model = True + + pre_and_post_weight_class = Glm4MoeLiteMTPPreAndPostLayerWeight + pre_layer_infer_class = Deepseek3MTPPreLayerInfer + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + + mtp_layer_start = self.config["num_hidden_layers"] + num_mtp_layers = self.config.get("num_nextn_predict_layers", 1) + + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(mtp_layer_start, mtp_layer_start + num_mtp_layers) + ] + + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + + def _load_hf_weights(self): + # GLM-5.2 MTP only loads the nextn layer in _init_weights(); avoid the + # base-class second pass over the same tensors, which creates large + # temporary CUDA buffers for FP8 dequantization during startup. + return + + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + + total_pre_layers_num = len(self.main_model.layers_infer) + total_pre_layers_num += sum( + [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] + ) + + num_mtp_layers = self.config.get("num_nextn_predict_layers", 1) + self.layers_infer = [ + self.transformer_layer_infer_class(i, network_config=self.config) + for i in range(total_pre_layers_num, total_pre_layers_num + num_mtp_layers) + ] + + def _init_some_value(self): + super()._init_some_value() + self.layers_num = self.config.get("num_nextn_predict_layers", 1) + + def autotune_layers(self): + return self.config.get("num_nextn_predict_layers", 1) diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 54d22a0d0..16304a5b7 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -10,6 +10,13 @@ tokenizer = None +def _set_chat_template(chat_template_str: str) -> None: + if hasattr(tokenizer, "tokenizer"): + tokenizer.tokenizer.chat_template = chat_template_str + else: + tokenizer.chat_template = chat_template_str + + def init_tokenizer(args): global tokenizer @@ -17,30 +24,29 @@ def init_tokenizer(args): chat_path = args.chat_template if chat_path is not None: with open(chat_path, "r", encoding="utf-8") as f: - chat_template_str = f.read() - if hasattr(tokenizer, "tokenizer"): - tokenizer.tokenizer.chat_template = chat_template_str - else: - tokenizer.chat_template = chat_template_str + _set_chat_template(f.read()) return + default_jinja_chat_template_path = os.path.join(args.model_dir, "chat_template.jinja") + if os.path.exists(default_jinja_chat_template_path): + try: + with open(default_jinja_chat_template_path, "r", encoding="utf-8") as f: + _set_chat_template(f.read()) + logger.info(f"Loaded chat_template.jinja from {default_jinja_chat_template_path}") + return + except Exception as e: + logger.warning(f"Failed to load chat_template.jinja from {default_jinja_chat_template_path}: {e}") + # 如果 tokenizer 目录下存在chat_template.json, 同时不存在 chat_template.jinja, # 则加载其并赋值给tokenizer 的 chat_template 对象。 - if not os.path.exists(os.path.join(args.model_dir, "chat_template.jinja")) and os.path.exists( - os.path.join(args.model_dir, "chat_template.json") - ): + if os.path.exists(os.path.join(args.model_dir, "chat_template.json")): default_chat_template_path = os.path.join(args.model_dir, "chat_template.json") try: with open(default_chat_template_path, "r", encoding="utf-8") as f: template_data = json.load(f) if "chat_template" in template_data: # Set it directly on the tokenizer object so apply_chat_template can use it - if hasattr(tokenizer, "tokenizer"): - # 多模态 tokenizer - tokenizer.tokenizer.chat_template = template_data["chat_template"] - else: - tokenizer.chat_template = template_data["chat_template"] - + _set_chat_template(template_data["chat_template"]) logger.info(f"Loaded chat_template.json from {default_chat_template_path}") except Exception as e: logger.warning(f"Failed to load chat_template.json from {default_chat_template_path}: {e}") diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a65dfb1bb..f931f87f1 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -45,6 +45,7 @@ from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel from lightllm.models.mistral_mtp.model import MistralMTPModel from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel +from lightllm.models.glm5_2_mtp.model import Glm5_2MTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import PDChunckedTransTaskRet @@ -342,6 +343,9 @@ def init_mtp_draft_model(self, main_kvargs: dict): elif mtp_model_cfg["model_type"] == "glm4_moe_lite": assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) + elif mtp_model_cfg["model_type"] == "glm_moe_dsa": + assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + self.draft_models.append(Glm5_2MTPModel(mtp_model_kvargs)) else: raise ValueError(f"Unsupported MTP model type: {model_type}") @@ -584,7 +588,6 @@ def _get_classed_reqs( can_alloc_token_num = g_infer_context.get_can_alloc_token_num() for req_obj in ready_reqs: - if req_obj.filter_mark: finished_reqs.append(req_obj) continue @@ -787,7 +790,6 @@ def _sample_and_scatter_token( b_prefill_has_output_cpu: torch.Tensor = None, mask_func: Optional[Callable] = None, ): - if mask_func is not None: assert len(run_reqs) == logits.shape[0] mask_func(run_reqs, logits) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e1a4e421d..7ddf985fb 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -16,6 +16,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import os from typing import List, Tuple, Union from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -39,6 +41,41 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" +def _load_tokenizers_backend_tokenizer( + tokenizer_name: str, + error: ValueError, + *args, + **kwargs, +) -> Union[PreTrainedTokenizerFast, None]: + if "Tokenizer class TokenizersBackend does not exist or is not currently imported" not in str(error): + return None + if not os.path.isdir(tokenizer_name): + return None + + tokenizer_file = os.path.join(tokenizer_name, "tokenizer.json") + tokenizer_config_file = os.path.join(tokenizer_name, "tokenizer_config.json") + if not os.path.exists(tokenizer_file) or not os.path.exists(tokenizer_config_file): + return None + + with open(tokenizer_config_file, "r", encoding="utf-8") as fp: + tokenizer_config = json.load(fp) + if tokenizer_config.get("tokenizer_class") != "TokenizersBackend": + return None + + special_token_kwargs = { + name: tokenizer_config[name] + for name in ("bos_token", "eos_token", "unk_token", "sep_token", "pad_token", "cls_token", "mask_token") + if tokenizer_config.get(name) is not None + } + if tokenizer_config.get("extra_special_tokens"): + special_token_kwargs["additional_special_tokens"] = tokenizer_config["extra_special_tokens"] + if tokenizer_config.get("model_max_length") is not None: + special_token_kwargs["model_max_length"] = tokenizer_config["model_max_length"] + + logger.info("Loading TokenizersBackend tokenizer through tokenizer.json fast-tokenizer fallback.") + return PreTrainedTokenizerFast(tokenizer_file=tokenizer_file, *args, **special_token_kwargs, **kwargs) + + def get_tokenizer( tokenizer_name: str, tokenizer_mode: str = "auto", @@ -71,6 +108,10 @@ def get_tokenizer( logger.warning(f"load fast tokenizer fail: {str(e)}") kwargs["use_fast"] = False tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs) + except ValueError as e: + tokenizer = _load_tokenizers_backend_tokenizer(tokenizer_name, e, *args, **kwargs) + if tokenizer is None: + raise if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.info( diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index c8d7373d5..ab96803ce 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -444,6 +444,10 @@ def get_tool_call_parser_for_model(model_path: str) -> Optional[str]: if model_type is None: return None + # GLM-5 / GLM-5.2 DSA models use the GLM-4.7 tool-call format. + if model_type == "glm_moe_dsa": + return "glm47" + # Qwen3.5 series if model_type in ["qwen3_5", "qwen3_5_moe", "qwen3_5_text", "qwen3_5_moe_text"]: return "qwen3_coder" @@ -473,6 +477,10 @@ def get_reasoning_parser_for_model(model_path: str) -> Optional[str]: if model_type is None: return None + # GLM-5 / GLM-5.2 DSA models share the GLM-4.5-style thinking tags. + if model_type == "glm_moe_dsa": + return "glm45" + # Qwen3.5 and Qwen3 series if model_type in [ "qwen3", From 2654a70f4632da326682db1a87ce8b5f501ad7a1 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 24 Jun 2026 19:34:31 +0800 Subject: [PATCH 2/8] fix(nsa): correct fp8 DSA FlashMLA prefill/decode kernel contract The fp8kv_dsa NSA path (--llm_kv_type fp8kv_dsa) addressed the FlashMLA sparse kernels with the wrong index space. It was never caught because the default/benchmarked config uses --llm_kv_type None (bf16 NSA path). Decode (sgl_kernel.flash_mla.flash_mla_with_kvcache): - pass indices=topk_mem_indices (absolute KV-pool slots) instead of the raw sequence-space / global-ragged topk_indices - pass an empty (bs, 0) block_table; the kernel ignores block_table for sparse indices, and the previously-computed paged block_table was both wrong (the flat allocator is not 64-page aligned) and dead - pad query heads up to the supported FlashMLA decode variants (64/128) - drop the per-layer cache_seqlens.max().item() host sync (only needed for the removed block_table; also unblocks cuda graph capture) Prefill: the no-prefix branch indexed the local prefill_cache_kv buffer with mem-pool slots; use the local topk_indices (b_topk_index) instead. Mem manager: pad the kv buffer token dim to a multiple of 64 so the decode 64-token page view keeps every valid slot addressable. Contract verified against the SGLang reference (dsa_backend.py::_forward_flashmla_kv); still needs GPU validation on the fp8kv_dsa path (batch>1 decode). --- .../attention/nsa/fp8_flashmla_sparse.py | 30 +++++++++---------- ...oken_group_quant_deepseek3_2mem_manager.py | 6 ++-- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py index 1ce896051..ac42d991f 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -91,7 +91,7 @@ def _nsa_prefill_att( ) else: kv = prefill_cache_kv - topk_indices = topk_mem_indices + topk_indices = nsa_dict["topk_indices"] if topk_indices.ndim == 2: topk_indices = topk_indices.unsqueeze(1) @@ -118,7 +118,6 @@ class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState): ke: torch.Tensor = None lengths: torch.Tensor = None ragged_mem_index: torch.Tensor = None - flashmla_sched_meta: object = None def init_state(self): self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend @@ -172,29 +171,28 @@ def _nsa_decode_att( nsa_dict = att_control.nsa_decode_dict topk_mem_indices = nsa_dict["topk_mem_indices"] - topk_indices = nsa_dict.get("topk_indices", topk_mem_indices) softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] if topk_mem_indices.ndim == 2: topk_mem_indices = topk_mem_indices.unsqueeze(1) - if topk_indices.ndim == 2: - topk_indices = topk_indices.unsqueeze(1) assert topk_mem_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" - assert topk_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" q_nope, q_rope = q q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous() + + real_head_num = q_all.shape[2] + if real_head_num <= 64: + padded_head_num = 64 + elif real_head_num <= 128: + padded_head_num = 128 + else: + padded_head_num = real_head_num + if padded_head_num != real_head_num: + q_all = F.pad(q_all, (0, 0, 0, padded_head_num - real_head_num)) + cache_seqlens = self.infer_state.b_seq_len.to(dtype=torch.int32) page_block_size = 64 - max_seq_len = int(cache_seqlens.max().item()) - max_block_num = (max_seq_len + page_block_size - 1) // page_block_size - block_table = ( - self.infer_state.req_manager.req_to_token_indexs[ - self.infer_state.b_req_idx, : max_block_num * page_block_size : page_block_size - ] - // page_block_size - ).to(dtype=torch.int32) num_heads_k = 1 num_heads_q = q_all.shape[2] tile_scheduler_metadata, num_splits = get_mla_metadata( @@ -215,6 +213,7 @@ def _nsa_decode_att( packed_kv.stride(-1), ), ) + block_table = torch.empty((cache_seqlens.shape[0], 0), dtype=torch.int32, device=q_all.device) o_tensor, _ = flash_mla_with_kvcache( q=q_all, @@ -227,6 +226,7 @@ def _nsa_decode_att( softmax_scale=softmax_scale, causal=False, is_fp8_kvcache=True, - indices=topk_indices.to(dtype=torch.int32), + indices=topk_mem_indices.to(dtype=torch.int32), ) + o_tensor = o_tensor[:, :, :real_head_num, :] return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d] diff --git a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py index fd7234805..60e94de6a 100644 --- a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py @@ -36,13 +36,15 @@ def get_cell_size(self): return self.layer_num * (self.flashmla_bytes_per_token + self.indexer_bytes_per_token) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + page_block_size = 64 + token_num = ((size + 1 + page_block_size - 1) // page_block_size) * page_block_size self.kv_buffer = torch.empty( - (layer_num, size + 1, head_num, self.flashmla_bytes_per_token), + (layer_num, token_num, head_num, self.flashmla_bytes_per_token), dtype=dtype, device="cuda", ) self.indexer_k_buffer = torch.empty( - (layer_num, size + 1, head_num, self.indexer_bytes_per_token), + (layer_num, token_num, head_num, self.indexer_bytes_per_token), dtype=dtype, device="cuda", ) From b0ce5b37ab16ee74faebbe85fdf06c17308b08a6 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 24 Jun 2026 19:39:52 +0800 Subject: [PATCH 3/8] cleanup(mtp): drop dead class-name list in is_mtp_draft_model predicate Every MTP draft model (Deepseek3MTP, Qwen3MOEMTP, Mistral, Glm4MoeLite, Glm5_2) already sets the class attribute is_mtp_draft_model = True, so the getattr term alone covers all of them; the " in str(__class__)" chain is unreachable (and was never extended for the GLM models). Reduce the predicate to the getattr. --- lightllm/common/basemodel/basemodel.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 3c500fa7c..814ccfbbe 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1171,12 +1171,7 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} - is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False) or ( - "Deepseek3MTPModel" in str(self.__class__) - or "Qwen3MOEMTPModel" in str(self.__class__) - or "MistralMTPModel" in str(self.__class__) - or "Glm4MoeLiteMTPModel" in str(self.__class__) - ) + is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" From 115ca55859fb404c87156f2cb4e59c70937a3609 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 24 Jun 2026 19:39:53 +0800 Subject: [PATCH 4/8] cleanup(glm5_2): fix contradictory rope_theta default _init_config already backfills config["rope_theta"] with 1e6 when absent, so _init_glm5_2_rotary's 8e6 fallback is unreachable and disagrees with the value actually used. Align the default to 1e6. --- lightllm/models/glm5_2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/glm5_2/model.py b/lightllm/models/glm5_2/model.py index 6ac26c5e3..10c1bb3f7 100644 --- a/lightllm/models/glm5_2/model.py +++ b/lightllm/models/glm5_2/model.py @@ -39,7 +39,7 @@ def _create_inferstate(self, model_input, microbatch_index: int = 0): return infer_state def _init_glm5_2_rotary(self): - rope_theta = self.config.get("rope_theta", 8000000.0) + rope_theta = self.config.get("rope_theta", 1000000.0) qk_rope_head_dim = self.config.get("qk_rope_head_dim", 64) max_position_embeddings = self.config.get("max_position_embeddings", 1048576) From 7e3dde2c7c09e126f51eadbac3947a4c044b7239 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 25 Jun 2026 20:26:50 +0800 Subject: [PATCH 5/8] perf(glm5_2): add fused_add_rmsnorm (residual-add + RMSNorm in one kernel) Folds the residual elementwise-add into the RMSNorm pass, removing a separate tiny add kernel per residual junction at decode. Variance is taken from the bf16-rounded sum, so the result is bit-identical to `add_` then rmsnorm (test/kernel/test_fused_add_rmsnorm.py: out_max_abs=0). Exposed as RMSNormWeight.fused_add_forward; wired into the decode path in a follow-up. --- .../layer_weights/meta_weights/norm_weight.py | 16 +++ .../triton_kernel/norm/fused_add_rmsnorm.py | 98 +++++++++++++++++++ test/kernel/test_fused_add_rmsnorm.py | 80 +++++++++++++++ 3 files changed, 194 insertions(+) create mode 100644 lightllm/common/basemodel/triton_kernel/norm/fused_add_rmsnorm.py create mode 100644 test/kernel/test_fused_add_rmsnorm.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index ee9d1923c..e01287fac 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -3,6 +3,7 @@ from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.fused_add_rmsnorm import fused_add_rmsnorm_forward from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward @@ -71,6 +72,21 @@ def __call__( ) -> torch.Tensor: return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def fused_add_forward( + self, + residual: torch.Tensor, + x: torch.Tensor, + eps: float, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + ) -> torch.Tensor: + """Fused residual-add + RMSNorm: ``residual <- residual + x`` (in place) and return + ``rmsnorm(residual) * weight``. Bit-identical to a plain ``residual.add_(x)`` followed + by ``__call__`` but in a single Triton launch. CUDA/MUSA (Triton) only.""" + if out is None: + out = alloc_func(residual.shape, dtype=residual.dtype, device=residual.device) + return fused_add_rmsnorm_forward(residual=residual, x=x, weight=self.weight, eps=eps, out=out) + class GatedRMSNormWeight(RMSNormWeight): def _triton_forward( diff --git a/lightllm/common/basemodel/triton_kernel/norm/fused_add_rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/fused_add_rmsnorm.py new file mode 100644 index 000000000..bf84f498c --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/norm/fused_add_rmsnorm.py @@ -0,0 +1,98 @@ +import torch + +import triton +import triton.language as tl +import os + +fused_add_rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8")) + + +@triton.jit +def _fused_add_rmsnorm_fwd( + X, # pointer to the addend (e.g. attention / ffn output) + RESIDUAL, # pointer to the residual, updated in place to (residual + x) + Y, # pointer to the normalized output + W, # pointer to the weights + x_stride0, + residual_stride0, + y_stride0, + N, # number of columns + eps, # epsilon to avoid division by zero + HAS_WEIGHT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + X += row * x_stride0 + RESIDUAL += row * residual_stride0 + Y += row * y_stride0 + # pass 1: residual = residual + x (in place), accumulate variance of the updated residual + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + r = tl.load(RESIDUAL + cols, mask=mask, other=0.0).to(tl.float32) + # round the updated residual to the storage dtype first, then accumulate variance + # from the rounded value so this matches the unfused (store; reload; rmsnorm) path + # bit-for-bit instead of using the higher-precision fp32 sum. + s = (r + x).to(RESIDUAL.dtype.element_ty) + tl.store(RESIDUAL + cols, s, mask=mask) + s = s.to(tl.float32) + _var += s * s + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # pass 2: normalize the (rounded) updated residual and optionally apply weight + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + s = tl.load(RESIDUAL + cols, mask=mask, other=0.0).to(tl.float32) + y = s * rstd + if HAS_WEIGHT: + y = y * w + tl.store(Y + cols * 1, y.to(Y.dtype.element_ty), mask=mask) + + +def fused_add_rmsnorm_forward( + residual: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, eps: float, out: torch.Tensor +) -> torch.Tensor: + """Fused residual-add + RMSNorm. + + Computes ``residual <- residual + x`` (in place) and ``out <- rmsnorm(residual) * weight`` + in a single kernel, eliminating the separate elementwise-add launch and the extra HBM + round-trip of the hidden state. ``residual`` and ``x`` share shape; ``residual`` is + updated in place (so the running residual stream is preserved for the next layer). + """ + assert residual.shape == x.shape + r_arg = residual.view(-1, residual.shape[-1]) + x_arg = x.view(-1, x.shape[-1]) + y_arg = out.view(-1, out.shape[-1]) + assert r_arg.shape == x_arg.shape == y_arg.shape + if weight is not None: + assert r_arg.shape[-1] == weight.shape[0] + # contiguous last dim (the kernel assumes unit stride within a row) + assert r_arg.stride(1) == 1 and x_arg.stride(1) == 1 and y_arg.stride(1) == 1 + assert out.data_ptr() == y_arg.data_ptr() + M, N = r_arg.shape + MAX_FUSED_SIZE = 65536 // residual.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("fused_add_rmsnorm doesn't support feature dim >= 64KB.") + if BLOCK_SIZE > 16384: + BLOCK_SIZE = 16384 + _fused_add_rmsnorm_fwd[(M,)]( + x_arg, + r_arg, + y_arg, + weight, + x_arg.stride(0), + r_arg.stride(0), + y_arg.stride(0), + N, + eps, + HAS_WEIGHT=weight is not None, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=fused_add_rmsnorm_num_warps, + ) + return out diff --git a/test/kernel/test_fused_add_rmsnorm.py b/test/kernel/test_fused_add_rmsnorm.py new file mode 100644 index 000000000..a779a2850 --- /dev/null +++ b/test/kernel/test_fused_add_rmsnorm.py @@ -0,0 +1,80 @@ +"""Correctness + microbench for fused_add_rmsnorm vs the production (add_ ; rmsnorm) sequence.""" +import torch +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.fused_add_rmsnorm import fused_add_rmsnorm_forward + + +def ref(residual, x, weight, eps): + # exactly what token_forward does today: in-place residual add, then a separate rmsnorm + res = residual.clone() + res.add_(x) + out = rmsnorm_forward(res, weight, eps) + return out, res + + +def check(M, N, dtype, eps=1e-5): + torch.manual_seed(0) + residual = (-2.3 + 0.5 * torch.randn(M, N, dtype=dtype, device="cuda")).contiguous() + x = (0.7 * torch.randn(M, N, dtype=dtype, device="cuda")).contiguous() + weight = torch.rand(N, dtype=dtype, device="cuda") + + res_ref0 = residual.clone() + out_ref, res_ref = ref(res_ref0, x, weight, eps) + + res_fused = residual.clone() + out_fused = torch.empty_like(res_fused) + fused_add_rmsnorm_forward(res_fused, x, weight, eps, out=out_fused) + + # residual update must be bit-identical to a plain bf16 add_ + res_match = torch.equal(res_fused, res_ref) + out_abs = (out_fused.float() - out_ref.float()).abs() + rel = out_abs / (out_ref.float().abs() + 1e-6) + cos = torch.nn.functional.cosine_similarity(out_fused.float().flatten(), out_ref.float().flatten(), dim=0).item() + print( + f"M={M:<4} N={N:<6} {str(dtype):<14} residual_bitmatch={res_match} " + f"out_max_abs={out_abs.max().item():.3e} out_max_rel={rel.max().item():.3e} cos={cos:.7f}" + ) + assert res_match, "residual (residual+x) must bit-match a plain add_" + # variance is now taken from the bf16-rounded sum, matching the unfused path bit-for-bit + assert torch.equal(out_fused, out_ref), "normalized output must bit-match the unfused path" + + +def bench(M, N, dtype=torch.bfloat16, eps=1e-5, iters=200): + residual = torch.randn(M, N, dtype=dtype, device="cuda") + x = torch.randn(M, N, dtype=dtype, device="cuda") + weight = torch.rand(N, dtype=dtype, device="cuda") + out = torch.empty_like(residual) + + def run_unfused(): + residual.add_(x) + rmsnorm_forward(residual, weight, eps, out=out) + + def run_fused(): + fused_add_rmsnorm_forward(residual, x, weight, eps, out=out) + + for fn, name in [(run_unfused, "add_+rmsnorm"), (run_fused, "fused_add_rmsnorm")]: + for _ in range(20): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + print(f" {name:<20} {start.elapsed_time(end) / iters * 1e3:.2f} us/call") + + +if __name__ == "__main__": + print("=== correctness ===") + for M in [1, 2, 4, 8, 16, 32]: + check(M, 6144, torch.bfloat16) # GLM-5.2 hidden + check(1, 7168, torch.bfloat16) # DeepSeek-V3 hidden + check(1, 4096, torch.float16) + check(13, 6144, torch.bfloat16) + print("=== microbench (decode-shaped, bs in 1..32 @ N=6144) ===") + for M in [1, 4, 16, 32]: + print(f"M={M}:") + bench(M, 6144) + print("OK") From 94c8b20ed42ee611427542452d2b9292d148dca0 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 25 Jun 2026 20:26:50 +0800 Subject: [PATCH 6/8] perf(glm5_2): fold attn-output all-reduce into the residual + ffn RMSNorm The decode attention-output junction now uses flashinfer kARResidualRMSNorm (SGLang #22390): the all-reduce, residual add, and ffn RMSNorm fuse into one oneshot-lamport kernel (fp32_acc) instead of three. LightLLM already launched flashinfer's allreduce_fusion in kAllReduce mode; this adds the fused pattern via all_reduce_residual_rmsnorm, with a Triton fused_add_rmsnorm fallback when the flashinfer AR fast path is inactive (large messages / SP mode). Keeps o un-reduced (Deepseek2._get_o reduce=False) and inlines the decode attention. Cuda-graph safe; GSM8K unchanged. Gated by LIGHTLLM_FUSED_{ADD,AR}_RMSNORM=1. --- lightllm/distributed/communication_op.py | 28 ++++++ lightllm/distributed/flashinfer_all_reduce.py | 23 +++++ .../layer_infer/transformer_layer_infer.py | 86 ++++++++++++++++++- 3 files changed, 135 insertions(+), 2 deletions(-) diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index f15badde2..54af0a891 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -100,6 +100,14 @@ def all_reduce(self, input_: torch.Tensor) -> None: return return dist.all_reduce(input_, group=self.device_group) + def all_reduce_residual_rmsnorm(self, inp, residual, rms_weight, eps, alloc_func): + # Fused AR + residual-add + RMSNorm via flashinfer when the message is small enough + # for the oneshot-lamport fast path; otherwise return None so the caller falls back to + # a plain all_reduce + a separate (fused_add) rmsnorm. + if self.flashinfer_reduce is not None and self.flashinfer_reduce.should_use(inp): + return self.flashinfer_reduce.allreduce_residual_rmsnorm(inp, residual, rms_weight, eps, alloc_func) + return None + def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, async_op: bool = False) -> None: return dist.all_gather_into_tensor(output_, input_, group=self.device_group, async_op=async_op) @@ -235,6 +243,26 @@ def all_reduce( return dist.all_reduce(input_, op, group, async_op) +def all_reduce_residual_rmsnorm( + inp: torch.Tensor, + residual: torch.Tensor, + rms_weight: torch.Tensor, + eps: float, + group: Optional[Union[ProcessGroup, CustomProcessGroup]], + alloc_func, +): + """Fused all-reduce + residual-add + RMSNorm (SGLang #22390). + + Returns ``(norm_out, residual_out)`` when a fused fast path (flashinfer) is available, + otherwise ``None`` so the caller can fall back to ``all_reduce`` + a separate + (fused-add) RMSNorm. ``inp`` is the un-reduced tensor; ``residual`` is added after the + reduction. + """ + if isinstance(group, CustomProcessGroup): + return group.all_reduce_residual_rmsnorm(inp, residual, rms_weight, eps, alloc_func) + return None + + def all_gather_into_tensor( output_: torch.Tensor, input_: torch.Tensor, diff --git a/lightllm/distributed/flashinfer_all_reduce.py b/lightllm/distributed/flashinfer_all_reduce.py index 27856d9ac..0373eff5d 100644 --- a/lightllm/distributed/flashinfer_all_reduce.py +++ b/lightllm/distributed/flashinfer_all_reduce.py @@ -133,3 +133,26 @@ def all_reduce(self, inp: torch.Tensor) -> torch.Tensor: workspace=self._workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce, ) + + def allreduce_residual_rmsnorm(self, inp, residual, rms_weight, eps, alloc_func): + """Fused all-reduce + residual-add + RMSNorm (flashinfer kARResidualRMSNorm). + + Computes ``residual_out = residual + allreduce(inp)`` and + ``norm_out = rmsnorm(residual_out) * rms_weight`` in one kernel — the SGLang + #22390 fusion. Returns ``(norm_out, residual_out)``; both are freshly allocated + (the kernel is out-of-place). ``inp`` must already satisfy ``should_use``. + """ + norm_out = alloc_func(inp.shape, dtype=inp.dtype, device=inp.device) + residual_out = alloc_func(residual.shape, dtype=residual.dtype, device=residual.device) + flashinfer_comm.allreduce_fusion( + input=inp, + workspace=self._workspace, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_weight, + rms_eps=eps, + fp32_acc=True, + ) + return norm_out, residual_out diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index be819c94a..c7ffb31cc 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -10,6 +10,7 @@ from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale +from lightllm.distributed.communication_op import all_reduce_residual_rmsnorm from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.log_utils import init_logger @@ -50,6 +51,13 @@ def __init__(self, layer_num, network_config): mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"] + # Fuse the post-attention residual add into the ffn RMSNorm (one Triton launch instead + # of a separate add_ + rmsnorm). Bit-identical; gate exists only for A/B measurement. + self.enable_fused_add_norm = os.environ.get("LIGHTLLM_FUSED_ADD_RMSNORM", "1") == "1" + # Additionally fold the attention-output all-reduce into that residual-add + RMSNorm via + # flashinfer kARResidualRMSNorm (SGLang #22390). Only fires when flashinfer AR is the + # active backend (small messages / low concurrency); falls back otherwise. + self.enable_fused_ar_norm = os.environ.get("LIGHTLLM_FUSED_AR_RMSNORM", "1") == "1" super().__init__(layer_num, network_config) self.num_heads = network_config["num_attention_heads"] self.num_kv_heads = network_config["num_key_value_heads"] @@ -200,7 +208,11 @@ def _get_qkv( return q, cache_kv def _get_o( - self, input: torch.Tensor, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight + self, + input: torch.Tensor, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + reduce: bool = True, ) -> torch.Tensor: if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) @@ -208,7 +220,10 @@ def _get_o( if input.shape[2] == self.kv_lora_rank: input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim)) - o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) + # reduce=False leaves o un-reduced so the caller can fuse the all-reduce into the + # following residual-add + RMSNorm (flashinfer kARResidualRMSNorm). + if reduce: + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor def _moe_ffn_tp( @@ -288,6 +303,73 @@ def _ffn_ep_impl( return ffn2_out + def _fused_add_ffn_norm(self, input_embdings: torch.Tensor, o: torch.Tensor, infer_state, layer_weight): + # Fuse the post-attention residual add (input_embdings += o) into the following ffn + # RMSNorm in a single Triton launch — eliminates one tiny elementwise-add kernel per + # layer. Bit-identical to `input_embdings.add_(o); self._ffn_norm(input_embdings)`. + if self.enable_fused_add_norm: + return layer_weight.ffn_norm_weight_.fused_add_forward( + residual=input_embdings.view(-1, self.embed_dim_), + x=o.view(-1, self.embed_dim_), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) + input_embdings.add_(o.view(-1, self.embed_dim_)) + return self._ffn_norm(input_embdings, infer_state, layer_weight) + + def _attn_out_add_ffn_norm(self, o_attn, input_embdings, infer_state, layer_weight): + """Combine the attention-output all-reduce, the residual add, and the ffn RMSNorm. + + Fast path (flashinfer kARResidualRMSNorm): all three fold into one kernel; ``o`` is kept + un-reduced and the reduction happens inside the fused op. Returns the new (normed_input, + residual). Falls back to the standard all-reduce + (fused-add) RMSNorm when flashinfer + AR is not the active backend (large messages / SP mode / disabled). + """ + if self.enable_fused_ar_norm and self.tp_world_size_ > 1 and not get_env_start_args().enable_tpsp_mix_mode: + o = self._get_o(o_attn, infer_state, layer_weight, reduce=False).view(-1, self.embed_dim_) + fused = all_reduce_residual_rmsnorm( + inp=o, + residual=input_embdings.view(-1, self.embed_dim_), + rms_weight=layer_weight.ffn_norm_weight_.weight, + eps=self.eps_, + group=infer_state.dist_group, + alloc_func=self.alloc_tensor, + ) + if fused is not None: + norm_out, residual_out = fused + return norm_out, residual_out + # flashinfer not applicable for this message size: finish the all-reduce normally. + o = self._tpsp_reduce(input=o, infer_state=infer_state) + return self._fused_add_ffn_norm(input_embdings, o, infer_state, layer_weight), input_embdings + + o = self._get_o(o_attn, infer_state, layer_weight, reduce=True) + return self._fused_add_ffn_norm(input_embdings, o, infer_state, layer_weight), input_embdings + + def context_forward(self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.context_attention_forward(input1, infer_state, layer_weight) + input1 = self._fused_add_ffn_norm(input_embdings, o, infer_state, layer_weight) + o = None + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def token_forward(self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + # Inline the decode attention so the output projection stays un-reduced and its + # all-reduce can fold into the residual-add + ffn RMSNorm (see _attn_out_add_ffn_norm). + q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._token_attention_kernel(q, infer_state, layer_weight) + input1 = None + input1, input_embdings = self._attn_out_add_ffn_norm(o, input_embdings, infer_state, layer_weight) + o = None + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + def overlap_tpsp_token_forward( self, input_embdings: torch.Tensor, From d46cbddc77ed58ea0d7227d3ea9f0a5cd12e0787 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 25 Jun 2026 20:26:51 +0800 Subject: [PATCH 7/8] perf(moe): fuse silu_and_mul with the down-proj per-token-group fp8 quant The block-wise-fp8 down projection quantized its activation in a separate per_token_group_quant launch per MoE layer. silu_and_mul_group_quant_fwd now emits the silu output directly as fp8 + row-major group scales (byte-matching per_token_group_quant_fp8: scales bit-identical, 99.9% fp8-exact), and grouped_matmul skips its internal quant when the input is already fp8. Gated by LIGHTLLM_FUSED_SILU_QUANT=1; GSM8K unchanged. --- .../fused_moe/grouped_fused_moe.py | 57 +++++++-- .../fused_moe/moe_silu_and_mul_group_quant.py | 118 ++++++++++++++++++ test/kernel/test_silu_and_mul_group_quant.py | 50 ++++++++ 3 files changed, 215 insertions(+), 10 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_group_quant.py create mode 100644 test/kernel/test_silu_and_mul_group_quant.py diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 76acea25a..616376aac 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import torch import triton import triton.language as tl @@ -26,6 +27,7 @@ from lightllm.utils.vllm_utils import vllm_ops from lightllm.utils.device_utils import triton_support_tensor_descriptor from .moe_silu_and_mul import silu_and_mul_fwd +from .moe_silu_and_mul_group_quant import silu_and_mul_group_quant_fwd from .moe_sum_reduce import moe_sum_reduce from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 from lightllm.utils.torch_ops_utils import direct_register_custom_op @@ -35,6 +37,11 @@ logger = init_logger(__name__) +# Fuse the down-projection's per-token-group fp8 quant into the silu_and_mul that produces its +# input (eliminates a separate per_token_group_quant launch per MoE layer). Numerically matches +# the unfused path; gate exists for A/B and rollback. +_ENABLE_FUSED_SILU_QUANT = os.environ.get("LIGHTLLM_FUSED_SILU_QUANT", "1") == "1" + @triton.jit def moe_align_kernel( @@ -762,8 +769,12 @@ def grouped_matmul( assert BLOCK_SIZE_K == triton.next_power_of_2(BLOCK_SIZE_K) if use_fp8_w8a8: + if token_inputs.dtype == expert_weights.dtype: + # token_inputs were already fp8-quantized by a fused producer (e.g. the fused + # silu+mul+group-quant on the down projection); token_input_scale is its scale. + assert token_input_scale is not None, "pre-quantized fp8 input requires its scale" # 当权重使用 block wise 量化时,激活也使用 per token, group size 量化 - if block_size_k == 0: + elif block_size_k == 0: # input 使用 per token 量化 token_inputs, token_input_scale = vllm_ops.scaled_fp8_quant( token_inputs, token_input_scale, use_per_token_if_dynamic=True @@ -984,18 +995,44 @@ def fused_experts_impl( bias=w1_bias, ) - silu_and_mul_fwd( - intermediate_cache1.view(-1, N), - intermediate_cache2.view(-1, N // 2), - limit=limit, - alpha=alpha, - layout=layout, - ) + # Fuse the down-projection's per-token-group fp8 quant into silu_and_mul when the down + # weight is block-wise fp8 quantized: the silu output is emitted directly as fp8 + scales, + # so grouped_matmul below skips its internal per_token_group_quant launch. + use_fused_silu_quant = _ENABLE_FUSED_SILU_QUANT and use_fp8_w8a8 and w2_scale is not None and w2_scale.ndim == 3 + if use_fused_silu_quant: + down_token_num = curr_topk_ids.numel() + down_k = N // 2 + down_group_size = down_k // w2_scale.shape[2] + down_inputs = alloc_tensor_func( + (down_token_num, down_k), device=hidden_states.device, dtype=torch.float8_e4m3fn + ) + down_input_scale = alloc_tensor_func( + (down_token_num, down_k // down_group_size), device=hidden_states.device, dtype=torch.float32 + ) + silu_and_mul_group_quant_fwd( + intermediate_cache1.view(-1, N), + down_inputs, + down_input_scale, + down_group_size, + layout=layout, + limit=limit, + alpha=alpha, + ) + else: + silu_and_mul_fwd( + intermediate_cache1.view(-1, N), + intermediate_cache2.view(-1, N // 2), + limit=limit, + alpha=alpha, + layout=layout, + ) + down_inputs = intermediate_cache2.view(-1, N // 2) + down_input_scale = a2_scale grouped_matmul( curr_topk_ids.numel(), - intermediate_cache2.view(-1, N // 2), - a2_scale, + down_inputs, + down_input_scale, expert_to_token_num, expert_to_tokens, expert_to_weights=expert_to_weights, diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_group_quant.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_group_quant.py new file mode 100644 index 000000000..b28987dc6 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_group_quant.py @@ -0,0 +1,118 @@ +import torch + +import triton +import triton.language as tl +from lightllm.utils.config_utils import ffn_use_tanh_approximate_gelu + + +@triton.jit +def _silu_and_mul_group_quant_kernel( + input_ptr, + stride_input_m, + out_q_ptr, + stride_out_q_m, + out_s_ptr, + stride_out_s_m, + size_n, # width of the silu output (= down-proj K), a multiple of GROUP_SIZE + limit, + alpha, + fp8_min, + fp8_max, + eps, + GROUP_SIZE: tl.constexpr, + layout: tl.constexpr, # "blocked" or "interleaved" + USE_LIMIT_AND_ALPHA: tl.constexpr, + USE_TANH_APPROXIMATE_GELU: tl.constexpr, +): + # One program per (token row, quant group of GROUP_SIZE columns). Computes silu(gate)*up for + # the group, then a per-group fp8 quant — byte-identical layout to + # silu_and_mul_fwd followed by per_token_group_quant_fp8 (row-major scales), in one launch. + m_index = tl.program_id(0).to(tl.int64) + group_index = tl.program_id(1) + cols = group_index * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + if layout == "interleaved": + # [gate0, up0, gate1, up1, ...] + gate_off = m_index * stride_input_m + cols * 2 + up_off = gate_off + 1 + else: + # [gate0, gate1, ..., up0, up1, ...] + gate_off = m_index * stride_input_m + cols + up_off = m_index * stride_input_m + cols + size_n + gate = tl.load(input_ptr + gate_off).to(tl.float32) + up = tl.load(input_ptr + up_off) + + if USE_LIMIT_AND_ALPHA: + gate = tl.minimum(gate, limit) + up = tl.minimum(tl.maximum(up, -limit), limit) + gate = 1 / (1 + tl.exp(-gate * alpha)) * gate + gate = gate.to(input_ptr.dtype.element_ty) + gate_up = (up + 1) * gate + else: + if USE_TANH_APPROXIMATE_GELU: + gate_cubed = gate * gate * gate + tanh_arg = 0.7978845608028654 * (gate + 0.044715 * gate_cubed) + tanh_val = 2.0 / (1.0 + tl.exp(-2.0 * tanh_arg)) - 1.0 + gate = 0.5 * gate * (1.0 + tanh_val) + else: + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + gate_up = up * gate + + # quantize the (bf16-rounded) silu output per group, matching per_token_group_quant_fp8. + gate_up_f = gate_up.to(tl.float32) + _absmax = tl.maximum(tl.max(tl.abs(gate_up_f)), eps) + out_s = _absmax / fp8_max + out_q = tl.clamp(gate_up_f / out_s, fp8_min, fp8_max).to(out_q_ptr.dtype.element_ty) + tl.store(out_q_ptr + m_index * stride_out_q_m + cols, out_q) + tl.store(out_s_ptr + m_index * stride_out_s_m + group_index, out_s) + + +def silu_and_mul_group_quant_fwd( + input: torch.Tensor, + output_q: torch.Tensor, + output_s: torch.Tensor, + group_size: int, + layout: str = "blocked", + limit=None, + alpha=None, +): + """Fused silu_and_mul + per-token-group fp8 quant. + + ``input`` [M, 2*N] (gate|up) -> ``output_q`` [M, N] fp8 + ``output_s`` [M, N//group_size] + float32 row-major. Equivalent to ``silu_and_mul_fwd(input, tmp); per_token_group_quant_fp8(tmp)`` + but in one kernel, so the down-projection's activation quant disappears as a separate launch. + """ + assert input.is_contiguous() + assert output_q.is_contiguous() and output_q.dtype == torch.float8_e4m3fn + assert (limit is None and alpha is None) or (limit is not None and alpha is not None) + size_m = input.shape[0] + size_n = input.shape[-1] // 2 + assert size_n % group_size == 0, f"silu output width {size_n} not divisible by group_size {group_size}" + assert output_q.shape[0] == size_m and output_q.shape[1] == size_n + assert output_s.shape[0] == size_m and output_s.shape[1] == size_n // group_size + + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + fp8_min = -fp8_max + # grid: token rows on dim-0 (up to 2**31, may be large for prefill), groups on dim-1 (small). + grid = (size_m, size_n // group_size) + _silu_and_mul_group_quant_kernel[grid]( + input, + input.stride(0), + output_q, + output_q.stride(0), + output_s, + output_s.stride(0), + size_n, + limit if limit is not None else 0.0, + alpha if alpha is not None else 0.0, + fp8_min, + fp8_max, + 1e-10, + GROUP_SIZE=group_size, + layout=layout, + USE_LIMIT_AND_ALPHA=limit is not None and alpha is not None, + USE_TANH_APPROXIMATE_GELU=ffn_use_tanh_approximate_gelu(), + num_warps=1, + ) + return diff --git a/test/kernel/test_silu_and_mul_group_quant.py b/test/kernel/test_silu_and_mul_group_quant.py new file mode 100644 index 000000000..5d23ae6fa --- /dev/null +++ b/test/kernel/test_silu_and_mul_group_quant.py @@ -0,0 +1,50 @@ +"""Correctness: fused silu_and_mul + group fp8 quant vs unfused silu_and_mul_fwd then per_token_group_quant_fp8.""" +import torch +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_group_quant import silu_and_mul_group_quant_fwd + + +def dequant(q, s, group_size): + # q: [M, N] fp8, s: [M, N//group] fp32 -> [M, N] fp32 + M, N = q.shape + qf = q.to(torch.float32).reshape(M, N // group_size, group_size) + return (qf * s.reshape(M, N // group_size, 1)).reshape(M, N) + + +def check(M, N, group_size=128, layout="blocked", dtype=torch.bfloat16): + torch.manual_seed(0) + inp = torch.randn(M, 2 * N, dtype=dtype, device="cuda").contiguous() + + # unfused reference + tmp = torch.empty(M, N, dtype=dtype, device="cuda") + silu_and_mul_fwd(inp, tmp, layout=layout) + q_ref, s_ref = per_token_group_quant_fp8(tmp, group_size, dtype=torch.float8_e4m3fn) + + # fused + q = torch.empty(M, N, dtype=torch.float8_e4m3fn, device="cuda") + s = torch.empty(M, N // group_size, dtype=torch.float32, device="cuda") + silu_and_mul_group_quant_fwd(inp, q, s, group_size, layout=layout) + + dq_ref = dequant(q_ref, s_ref.reshape(M, N // group_size), group_size) + dq = dequant(q, s, group_size) + cos = torch.nn.functional.cosine_similarity(dq.flatten(), dq_ref.flatten(), dim=0).item() + rel = ((dq - dq_ref).abs() / (dq_ref.abs() + 1e-4)).max().item() + s_match = (s.reshape(-1) - s_ref.reshape(-1)).abs().max().item() + q_exact = (q.to(torch.float32) == q_ref.to(torch.float32)).float().mean().item() + print( + f"M={M:<5} N={N:<6} {layout:<11} cos={cos:.7f} dequant_max_rel={rel:.3e} " + f"scale_max_abs={s_match:.3e} q_exact_frac={q_exact:.4f}" + ) + assert cos > 0.9999, f"cosine too low: {cos}" + assert s_match < 1e-6 or q_exact > 0.99, "scale/quant diverged beyond fp8 rounding" + + +if __name__ == "__main__": + print("=== correctness (fused silu+group-quant vs unfused) ===") + for M in [1, 8, 16, 64, 256]: + check(M, 1536) # glm/deepseek-ish moe intermediate + check(8, 2048) + check(13, 768) + check(8, 1536, layout="interleaved") + print("OK") From d94e6bb955558ff76929d9f0a511a0b8858fae1f Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 25 Jun 2026 23:25:14 +0800 Subject: [PATCH 8/8] docs(glm5_2): add GLM-5.2 to the supported models list BF16/FP8 + MTP, glm_moe_dsa (DeepSeek-V3.2-style DSA MoE). --- docs/CN/source/models/supported_models.rst | 4 +++- docs/EN/source/models/supported_models.rst | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/CN/source/models/supported_models.rst b/docs/CN/source/models/supported_models.rst index 3d66d2e07..eba38bd79 100755 --- a/docs/CN/source/models/supported_models.rst +++ b/docs/CN/source/models/supported_models.rst @@ -56,6 +56,8 @@ lightllm 支持大多数的主流的开源大语言模型以及多模态模型 - * - `Qwen3-Moe `_ - + * - `GLM-5.2 `_ + - 支持 BF16/FP8 和 MTP。 多模态模型 @@ -93,4 +95,4 @@ Reward模型 * - `internLM-reward `_ - :code:`--use_reward_model` * - `Qwen2-Reward `_ - - :code:`--use_reward_model` \ No newline at end of file + - :code:`--use_reward_model` diff --git a/docs/EN/source/models/supported_models.rst b/docs/EN/source/models/supported_models.rst index 1b1d4fcd0..105ced14c 100755 --- a/docs/EN/source/models/supported_models.rst +++ b/docs/EN/source/models/supported_models.rst @@ -56,6 +56,8 @@ Large Language Models - * - `DeepSeek-V3.2 `_ - + * - `GLM-5.2 `_ + - Supports BF16/FP8 and MTP. Multimodal Models ^^^^^^^^^^^^^^^^^ @@ -94,4 +96,3 @@ Reward Models - :code:`--use_reward_model` * - `Qwen2-Reward `_ - :code:`--use_reward_model` -