From 57523f95007a1044beaab2eaa9d6f26c84d6f0a8 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Wed, 24 Jun 2026 14:37:17 +0800 Subject: [PATCH] feat: fuse add and rmsnorm --- lightllm/common/basemodel/basemodel.py | 11 +- .../transformer_layer_infer_template.py | 35 +++++ .../layer_weights/meta_weights/norm_weight.py | 17 ++- .../basemodel/triton_kernel/norm/rmsnorm.py | 141 +++++++++++++++++- lightllm/models/qwen3next/model.py | 16 ++ 5 files changed, 214 insertions(+), 6 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a..983b27020 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -659,16 +659,19 @@ def prefill_func(input_tensors, infer_state): dist_group_manager.clear_deepep_buffer() return model_output + def _token_forward_layers(self, input_embs: torch.Tensor, infer_state: InferStateInfo): + for i in range(self.layers_num): + layer = self.layers_infer[i] + input_embs: torch.Tensor = layer.token_forward(input_embs, infer_state, self.trans_layers_weight[i]) + return input_embs + @final def _token_forward(self, infer_state: InferStateInfo): input_ids = infer_state.input_ids cuda_input_ids = input_ids input_embs = self.pre_infer.token_forward(cuda_input_ids, infer_state, self.pre_post_weight) input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) - - for i in range(self.layers_num): - layer = self.layers_infer[i] - input_embs: torch.Tensor = layer.token_forward(input_embs, infer_state, self.trans_layers_weight[i]) + input_embs = self._token_forward_layers(input_embs, infer_state) last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) predict_logits: torch.Tensor = self.post_infer.token_forward( diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index f0cc129c0..a43405de5 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -98,6 +98,41 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings + def token_forward_with_next_att_norm( + self, + input_embdings, + infer_state: InferStateInfo, + layer_weight, + att_normed_input=None, + next_layer=None, + next_layer_weight=None, + ): + input1 = att_normed_input + if input1 is None: + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.token_attention_forward(input1, infer_state, layer_weight) + input1 = layer_weight.ffn_norm_weight_.add_rmsnorm( + input=input_embdings, + residual=o.view(-1, self.embed_dim_), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) + o = None + + ffn_out = self._ffn(input1, infer_state, layer_weight) + ffn_out = ffn_out.view(-1, self.embed_dim_) + + if next_layer is not None: + return input_embdings, next_layer_weight.att_norm_weight_.add_rmsnorm( + input=input_embdings, + residual=ffn_out, + eps=next_layer.eps_, + alloc_func=next_layer.alloc_tensor, + ) + + input_embdings.add_(ffn_out) + return input_embdings, None + def _context_attention_wrapper_run( self, q: torch.Tensor, cache_kv: torch.Tensor, infer_state: InferStateInfo, layer_weight ) -> torch.Tensor: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index ee9d1923c..33b59ac5a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -2,7 +2,7 @@ from typing import Optional, Dict from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size -from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import add_rmsnorm_forward, rmsnorm_forward from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward @@ -71,6 +71,21 @@ def __call__( ) -> torch.Tensor: return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def add_rmsnorm( + self, + input: torch.Tensor, + residual: torch.Tensor, + eps: float, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + input.ndim in [2, 3] and residual.ndim in [2, 3] and self.weight.ndim == 1 + ), f"input.ndim: {input.ndim}, residual.ndim: {residual.ndim}, weight.ndim: {self.weight.ndim}" + if out is None: + out = alloc_func(input.shape, dtype=input.dtype, device=input.device) + return add_rmsnorm_forward(x=input, residual=residual, weight=self.weight, eps=eps, out=out) + class GatedRMSNormWeight(RMSNormWeight): def _triton_forward( diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py index 8dc855892..0c52a308f 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py @@ -3,6 +3,7 @@ import triton import triton.language as tl import os +from lightllm.common.triton_utils.autotuner import autotune rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8")) @@ -48,6 +49,69 @@ def _rms_norm_fwd_fused( tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) +@triton.jit +def _add_rms_norm_fwd_fused( + X, + R, + Y, + W, + x_stride0, + x_stride1, + r_stride0, + r_stride1, + y_stride0, + y_stride1, + N, + eps, + HAS_WEIGHT: tl.constexpr, + SINGLE_PASS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + X += row * x_stride0 + R += row * r_stride0 + Y += row * y_stride0 + + if SINGLE_PASS: + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + r = tl.load(R + cols * r_stride1, mask=mask, other=0.0).to(tl.float32) + x = (x + r).to(X.dtype.element_ty) + tl.store(X + cols * x_stride1, x, mask=mask) + + x = x.to(tl.float32) + var = tl.sum(x * x, axis=0) / N + y = x * (1 / tl.sqrt(var + eps)) + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + y *= w + tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) + else: + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + r = tl.load(R + cols * r_stride1, mask=mask, other=0.0).to(tl.float32) + x = (x + r).to(X.dtype.element_ty) + tl.store(X + cols * x_stride1, x, mask=mask) + x = x.to(tl.float32) + _var += x * x + + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + y = x * rstd + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + y *= w + tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) + + def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None): # allocate output y = torch.empty_like(x) if out is None else out @@ -60,7 +124,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) assert y.data_ptr() == y_arg.data_ptr() M, N = x_arg.shape # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() + MAX_FUSED_SIZE = 65536 // x_arg.element_size() BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) # print("BLOCK_SIZE:", BLOCK_SIZE) if N > BLOCK_SIZE: @@ -86,6 +150,81 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) return y +def _get_add_rmsnorm_configs(): + return [{"num_warps": nw} for nw in [4, 8, 16]] + + +def _get_add_rmsnorm_static_key( + x_arg: torch.Tensor, residual_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor +): + return { + "x_dtype": str(x_arg.dtype), + "residual_dtype": str(residual_arg.dtype), + "out_dtype": str(y_arg.dtype), + "weight_dtype": "none" if weight is None else str(weight.dtype), + "N": x_arg.shape[1], + "has_weight": weight is not None, + } + + +@autotune( + kernel_name="add_rmsnorm_forward:v1", + configs_gen_func=_get_add_rmsnorm_configs, + static_key_func=_get_add_rmsnorm_static_key, + run_key_func=lambda x_arg: x_arg.shape[0], + mutates_args=["x_arg", "y_arg"], +) +def _add_rmsnorm_forward( + x_arg: torch.Tensor, + residual_arg: torch.Tensor, + y_arg: torch.Tensor, + weight: torch.Tensor, + eps: float, + run_config: dict = None, +): + M, N = x_arg.shape + MAX_FUSED_SIZE = 65536 // x_arg.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + if BLOCK_SIZE > 16384: + BLOCK_SIZE = 16384 + if not run_config: + run_config = {"num_warps": rmsnorm_num_warps} + _add_rms_norm_fwd_fused[(M,)]( + x_arg, + residual_arg, + y_arg, + weight, + x_arg.stride(0), + x_arg.stride(1), + residual_arg.stride(0), + residual_arg.stride(1), + y_arg.stride(0), + y_arg.stride(1), + N, + eps, + HAS_WEIGHT=weight is not None, + SINGLE_PASS=N <= BLOCK_SIZE, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=run_config["num_warps"], + ) + return y_arg + + +def add_rmsnorm_forward(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float, out=None): + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, x.shape[-1]) + residual_arg = residual.view(-1, x.shape[-1]) + y_arg = y.view(-1, x.shape[-1]) + assert x_arg.shape == residual_arg.shape == y_arg.shape + if weight is not None: + assert x_arg.shape[-1] == weight.shape[0] + assert y.data_ptr() == y_arg.data_ptr() + _add_rmsnorm_forward(x_arg, residual_arg, y_arg, weight, eps) + return y + + def torch_rms_norm(x, weight, eps): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 9b5e9b7a5..5a4ef3e3c 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -102,3 +102,19 @@ def _init_req_manager(self): self.max_req_num, create_max_seq_len, None, linear_config=LinearAttCacheConfig.load_from_args() ) return + + def _token_forward_layers(self, input_embs: torch.Tensor, infer_state: Qwen3NextInferStateInfo): + next_att_normed = None + for i in range(self.layers_num): + layer: Qwen3NextTransformerLayerInfer = self.layers_infer[i] + layer_weight: Qwen3NextTransformerLayerWeight = self.trans_layers_weight[i] + has_next_layer = i + 1 < self.layers_num + input_embs, next_att_normed = layer.token_forward_with_next_att_norm( + input_embs, + infer_state, + layer_weight, + att_normed_input=next_att_normed, + next_layer=self.layers_infer[i + 1] if has_next_layer else None, + next_layer_weight=self.trans_layers_weight[i + 1] if has_next_layer else None, + ) + return input_embs