Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,16 +659,19 @@ def prefill_func(input_tensors, infer_state):
dist_group_manager.clear_deepep_buffer()
return model_output

def _token_forward_layers(self, input_embs: torch.Tensor, infer_state: InferStateInfo):
    """Feed the embeddings through every transformer layer's ``token_forward``.

    Layers ``0 .. self.layers_num - 1`` are visited in order; each layer gets
    the previous layer's output, the shared ``infer_state`` and its own entry
    of ``self.trans_layers_weight``.

    Args:
        input_embs: activations entering the first layer.
        infer_state: inference state forwarded unchanged to every layer.

    Returns:
        Whatever the last layer's ``token_forward`` returned, or ``input_embs``
        itself when ``self.layers_num`` is 0.
    """
    hidden = input_embs
    for layer_idx in range(self.layers_num):
        # Resolve the bound method first, then the matching weight entry.
        run_layer = self.layers_infer[layer_idx].token_forward
        hidden = run_layer(hidden, infer_state, self.trans_layers_weight[layer_idx])
    return hidden

@final
def _token_forward(self, infer_state: InferStateInfo):
input_ids = infer_state.input_ids
cuda_input_ids = input_ids
input_embs = self.pre_infer.token_forward(cuda_input_ids, infer_state, self.pre_post_weight)
input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state)

for i in range(self.layers_num):
layer = self.layers_infer[i]
input_embs: torch.Tensor = layer.token_forward(input_embs, infer_state, self.trans_layers_weight[i])
input_embs = self._token_forward_layers(input_embs, infer_state)

last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
predict_logits: torch.Tensor = self.post_infer.token_forward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,41 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def token_forward_with_next_att_norm(
    self,
    input_embdings,
    infer_state: InferStateInfo,
    layer_weight,
    att_normed_input=None,
    next_layer=None,
    next_layer_weight=None,
):
    """Decode-step layer forward that also prepares the next layer's attention norm.

    The attention output is folded into the hidden state through
    ``ffn_norm_weight_.add_rmsnorm`` and, when a following layer is supplied,
    the FFN output is folded in through that layer's
    ``att_norm_weight_.add_rmsnorm``.
    NOTE(review): this relies on ``add_rmsnorm`` accumulating ``residual`` into
    ``input`` in place (the hidden state is returned without a separate add on
    that path) -- confirm against the norm-weight implementation.

    Args:
        input_embdings: hidden state entering this layer.
        infer_state: inference state passed to attention and FFN.
        layer_weight: weights of this layer.
        att_normed_input: already attention-normed input for this layer; when
            ``None`` it is computed here with ``self._att_norm``.
        next_layer: the following layer (its ``eps_`` and ``alloc_tensor`` are
            used), or ``None`` for the last layer.
        next_layer_weight: weights of the following layer, or ``None``.

    Returns:
        ``(input_embdings, next_att_normed)`` where ``next_att_normed`` is the
        value returned by the next layer's ``add_rmsnorm`` call, or ``None``
        when ``next_layer`` is ``None``.
    """
    if att_normed_input is None:
        att_in = self._att_norm(input_embdings, infer_state, layer_weight)
    else:
        att_in = att_normed_input

    attn_out = self.token_attention_forward(att_in, infer_state, layer_weight)
    ffn_in = layer_weight.ffn_norm_weight_.add_rmsnorm(
        input=input_embdings,
        residual=attn_out.view(-1, self.embed_dim_),
        eps=self.eps_,
        alloc_func=self.alloc_tensor,
    )
    # Drop the local references to the attention input/output before the FFN runs.
    del att_in, attn_out

    ffn_out = self._ffn(ffn_in, infer_state, layer_weight).view(-1, self.embed_dim_)

    if next_layer is None:
        # Last layer: nothing to pre-normalise, just add the FFN residual.
        input_embdings.add_(ffn_out)
        return input_embdings, None

    next_att_normed = next_layer_weight.att_norm_weight_.add_rmsnorm(
        input=input_embdings,
        residual=ffn_out,
        eps=next_layer.eps_,
        alloc_func=next_layer.alloc_tensor,
    )
    return input_embdings, next_att_normed

def _context_attention_wrapper_run(
self, q: torch.Tensor, cache_kv: torch.Tensor, infer_state: InferStateInfo, layer_weight
) -> torch.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, Dict
from .base_weight import BaseWeightTpl
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size
from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward
from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import add_rmsnorm_forward, rmsnorm_forward
from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward
from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward
from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward
Expand Down Expand Up @@ -71,6 +71,21 @@ def __call__(
) -> torch.Tensor:
return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func)

def add_rmsnorm(
    self,
    input: torch.Tensor,
    residual: torch.Tensor,
    eps: float,
    out: Optional[torch.Tensor] = None,
    alloc_func=torch.empty,
) -> torch.Tensor:
    """Run the fused residual-add + RMSNorm with this object's weight.

    Delegates to ``add_rmsnorm_forward`` with ``x=input``; the output buffer is
    created with ``alloc_func`` (same shape, dtype and device as ``input``)
    when the caller does not provide ``out``.

    Args:
        input: 2-D or 3-D tensor.
        residual: 2-D or 3-D tensor to be added.
        eps: epsilon forwarded to the norm.
        out: optional pre-allocated output tensor.
        alloc_func: allocator called as ``alloc_func(shape, dtype=..., device=...)``.

    Returns:
        The value returned by ``add_rmsnorm_forward``.
    """
    assert (
        input.ndim in (2, 3) and residual.ndim in (2, 3) and self.weight.ndim == 1
    ), f"input.ndim: {input.ndim}, residual.ndim: {residual.ndim}, weight.ndim: {self.weight.ndim}"
    result_buf = out
    if result_buf is None:
        result_buf = alloc_func(input.shape, dtype=input.dtype, device=input.device)
    return add_rmsnorm_forward(
        x=input,
        residual=residual,
        weight=self.weight,
        eps=eps,
        out=result_buf,
    )


class GatedRMSNormWeight(RMSNormWeight):
def _triton_forward(
Expand Down
141 changes: 140 additions & 1 deletion lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import triton
import triton.language as tl
import os
from lightllm.common.triton_utils.autotuner import autotune

rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8"))

Expand Down Expand Up @@ -48,6 +49,69 @@ def _rms_norm_fwd_fused(
tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)


@triton.jit
def _add_rms_norm_fwd_fused(
    X,
    R,
    Y,
    W,
    x_stride0,
    x_stride1,
    r_stride0,
    r_stride1,
    y_stride0,
    y_stride1,
    N,
    eps,
    HAS_WEIGHT: tl.constexpr,
    SINGLE_PASS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused ``X += R`` followed by RMSNorm of the updated row, one program per row.

    The sum ``X + R`` is computed in float32, cast to X's element type and
    written back to X; the norm is then computed (in float32) from that
    cast value, optionally scaled by ``W``, and stored to Y in Y's element type.

    ``SINGLE_PASS`` handles a row with one block of ``BLOCK_SIZE`` columns;
    otherwise the row is walked twice in ``BLOCK_SIZE`` chunks: first to write
    the sum and accumulate squares, then to re-read X and emit the output.
    """
    pid = tl.program_id(0)
    X += pid * x_stride0
    R += pid * r_stride0
    Y += pid * y_stride0

    if SINGLE_PASS:
        offs = tl.arange(0, BLOCK_SIZE)
        valid = offs < N
        x_val = tl.load(X + offs * x_stride1, mask=valid, other=0.0).to(tl.float32)
        r_val = tl.load(R + offs * r_stride1, mask=valid, other=0.0).to(tl.float32)
        summed = (x_val + r_val).to(X.dtype.element_ty)
        tl.store(X + offs * x_stride1, summed, mask=valid)

        # Normalise the value as stored (after the cast), not the float32 sum.
        summed_f32 = summed.to(tl.float32)
        mean_sq = tl.sum(summed_f32 * summed_f32, axis=0) / N
        normed = summed_f32 * (1 / tl.sqrt(mean_sq + eps))
        if HAS_WEIGHT:
            gamma = tl.load(W + offs, mask=valid, other=0.0).to(tl.float32)
            normed *= gamma
        tl.store(Y + offs * y_stride1, normed.to(Y.dtype.element_ty), mask=valid)
    else:
        # Pass 1: write the residual sum back to X and accumulate squares.
        sq_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for start in range(0, N, BLOCK_SIZE):
            offs = start + tl.arange(0, BLOCK_SIZE)
            valid = offs < N
            x_val = tl.load(X + offs * x_stride1, mask=valid, other=0.0).to(tl.float32)
            r_val = tl.load(R + offs * r_stride1, mask=valid, other=0.0).to(tl.float32)
            summed = (x_val + r_val).to(X.dtype.element_ty)
            tl.store(X + offs * x_stride1, summed, mask=valid)
            summed_f32 = summed.to(tl.float32)
            sq_acc += summed_f32 * summed_f32

        mean_sq = tl.sum(sq_acc, axis=0) / N
        inv_rms = 1 / tl.sqrt(mean_sq + eps)
        # Pass 2: re-read the updated X and emit the normalised output.
        for start in range(0, N, BLOCK_SIZE):
            offs = start + tl.arange(0, BLOCK_SIZE)
            valid = offs < N
            summed_f32 = tl.load(X + offs * x_stride1, mask=valid, other=0.0).to(tl.float32)
            normed = summed_f32 * inv_rms
            if HAS_WEIGHT:
                gamma = tl.load(W + offs, mask=valid, other=0.0).to(tl.float32)
                normed *= gamma
            tl.store(Y + offs * y_stride1, normed.to(Y.dtype.element_ty), mask=valid)


def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None):
# allocate output
y = torch.empty_like(x) if out is None else out
Expand All @@ -60,7 +124,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None)
assert y.data_ptr() == y_arg.data_ptr()
M, N = x_arg.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
MAX_FUSED_SIZE = 65536 // x_arg.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
# print("BLOCK_SIZE:", BLOCK_SIZE)
if N > BLOCK_SIZE:
Expand All @@ -86,6 +150,81 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None)
return y


def _get_add_rmsnorm_configs():
    """Return the candidate launch configs: one dict per ``num_warps`` in 4, 8, 16."""
    warp_choices = (4, 8, 16)
    return [{"num_warps": warps} for warps in warp_choices]


def _get_add_rmsnorm_static_key(
    x_arg: torch.Tensor, residual_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor
):
    """Build the static autotune key for the fused add + RMSNorm launch.

    The key records the dtypes of the three tensors, the weight dtype
    (``"none"`` when ``weight`` is ``None``), the column count ``x_arg.shape[1]``
    and whether a weight is present.

    NOTE(review): this takes only four of ``_add_rmsnorm_forward``'s parameters.
    A reviewer suggested also accepting ``*args, **kwargs``; whether that is
    needed depends on how ``autotune`` invokes ``static_key_func`` -- confirm
    against the autotuner before changing the signature.
    """
    has_weight = weight is not None
    weight_dtype = str(weight.dtype) if has_weight else "none"
    # Key order is kept stable in case the autotuner serialises the dict.
    return dict(
        x_dtype=str(x_arg.dtype),
        residual_dtype=str(residual_arg.dtype),
        out_dtype=str(y_arg.dtype),
        weight_dtype=weight_dtype,
        N=x_arg.shape[1],
        has_weight=has_weight,
    )


# NOTE(review): a reviewer suggested giving run_key_func a ``*args, **kwargs``
# tail; whether that is required depends on how ``autotune`` calls it -- confirm
# against the autotuner before changing the lambda's signature.
@autotune(
    kernel_name="add_rmsnorm_forward:v1",
    configs_gen_func=_get_add_rmsnorm_configs,
    static_key_func=_get_add_rmsnorm_static_key,
    run_key_func=lambda x_arg: x_arg.shape[0],
    mutates_args=["x_arg", "y_arg"],
)
def _add_rmsnorm_forward(
    x_arg: torch.Tensor,
    residual_arg: torch.Tensor,
    y_arg: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    run_config: dict = None,
):
    """Launch ``_add_rms_norm_fwd_fused`` over 2-D tensors, one program per row.

    Args:
        x_arg: 2-D tensor ``(M, N)``; passed to the kernel as ``X``.
        residual_arg: 2-D tensor passed as ``R``.
        y_arg: 2-D tensor passed as ``Y`` (the kernel's output).
        weight: tensor passed as ``W``, or ``None`` to skip the scale.
        eps: epsilon forwarded to the kernel.
        run_config: dict with a ``"num_warps"`` entry; when falsy the
            module-level ``rmsnorm_num_warps`` is used.

    Returns:
        ``y_arg``.

    Raises:
        RuntimeError: when a row does not fit the 64KB-per-row block limit.
    """
    n_rows, n_cols = x_arg.shape
    # At most 64KB worth of elements per row.
    max_fused = 65536 // x_arg.element_size()
    block_size = min(max_fused, triton.next_power_of_2(n_cols))
    if n_cols > block_size:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # Cap the block; rows wider than the cap take the kernel's two-pass path.
    block_size = min(block_size, 16384)

    num_warps = run_config["num_warps"] if run_config else rmsnorm_num_warps
    launch_grid = (n_rows,)
    _add_rms_norm_fwd_fused[launch_grid](
        x_arg,
        residual_arg,
        y_arg,
        weight,
        x_arg.stride(0),
        x_arg.stride(1),
        residual_arg.stride(0),
        residual_arg.stride(1),
        y_arg.stride(0),
        y_arg.stride(1),
        n_cols,
        eps,
        HAS_WEIGHT=weight is not None,
        SINGLE_PASS=n_cols <= block_size,
        BLOCK_SIZE=block_size,
        num_warps=num_warps,
    )
    return y_arg


def add_rmsnorm_forward(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float, out=None):
    """Flatten the inputs to 2-D and run the fused residual-add + RMSNorm.

    ``x``, ``residual`` and the output are each viewed as ``(-1, x.shape[-1])``
    and handed to ``_add_rmsnorm_forward``. Because views are used, whatever
    the launched kernel writes through the 2-D ``x`` view lands in ``x``'s
    own storage.

    Args:
        x: input tensor; its last dimension is the normalised feature size.
        residual: tensor viewable to the same 2-D shape as ``x``.
        weight: 1-D scale whose length equals ``x.shape[-1]``, or ``None``.
        eps: epsilon forwarded to the kernel launch.
        out: optional output tensor; ``torch.empty_like(x)`` is used when ``None``.

    Returns:
        The output tensor (``out`` if given, otherwise the new allocation).
    """
    y = out if out is not None else torch.empty_like(x)
    feature_dim = x.shape[-1]
    x_2d = x.view(-1, feature_dim)
    residual_2d = residual.view(-1, feature_dim)
    y_2d = y.view(-1, feature_dim)
    assert x_2d.shape == residual_2d.shape == y_2d.shape
    if weight is not None:
        assert x_2d.shape[-1] == weight.shape[0]
    # The 2-D output view must alias the tensor that is returned.
    assert y.data_ptr() == y_2d.data_ptr()
    _add_rmsnorm_forward(x_2d, residual_2d, y_2d, weight, eps)
    return y


def torch_rms_norm(x, weight, eps):
    """Pure-PyTorch RMSNorm over the last dimension: ``x * rsqrt(mean(x^2) + eps) * weight``."""
    mean_sq = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    return x * inv_rms * weight

Expand Down
16 changes: 16 additions & 0 deletions lightllm/models/qwen3next/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,19 @@ def _init_req_manager(self):
self.max_req_num, create_max_seq_len, None, linear_config=LinearAttCacheConfig.load_from_args()
)
return

def _token_forward_layers(self, input_embs: torch.Tensor, infer_state: Qwen3NextInferStateInfo):
    """Run all layers, letting each layer pre-compute the next one's attention norm.

    Every layer is invoked through ``token_forward_with_next_att_norm``. It
    receives the attention-normed input produced by the previous layer
    (``None`` for the first layer) together with the following layer and its
    weights (both ``None`` for the last layer).

    Args:
        input_embs: activations entering the first layer.
        infer_state: inference state forwarded to every layer.

    Returns:
        The hidden state returned by the last layer, or ``input_embs`` when
        ``self.layers_num`` is 0.
    """
    hidden = input_embs
    carried_att_norm = None
    for idx in range(self.layers_num):
        current_layer = self.layers_infer[idx]
        current_weight = self.trans_layers_weight[idx]
        following = idx + 1
        if following < self.layers_num:
            upcoming_layer = self.layers_infer[following]
            upcoming_weight = self.trans_layers_weight[following]
        else:
            # Last layer: there is nothing to pre-normalise for.
            upcoming_layer = None
            upcoming_weight = None
        hidden, carried_att_norm = current_layer.token_forward_with_next_att_norm(
            hidden,
            infer_state,
            current_weight,
            att_normed_input=carried_att_norm,
            next_layer=upcoming_layer,
            next_layer_weight=upcoming_weight,
        )
    return hidden
Loading