diff --git a/.gitignore b/.gitignore index 9b69e2eb4..67a0db0b4 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,7 @@ dist .vscode tmp/ requirements-musa.txt -logs/ \ No newline at end of file +logs/ + +/benchmark/ +artifacts/ diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index acbb1315f..99fd864bf 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -45,9 +45,12 @@ def init_state(self): torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len ) # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) def prefill_att( self, @@ -116,20 +119,19 @@ def init_state(self): super().init_state() self.backend: Fp8Fa3AttBackend = self.backend - args_mtp_step = get_env_start_args().mtp_step - att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) - assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 - - device = self.infer_state.input_ids.device - batch_size = att_batch_size + batch_size = self.b_att_seq_len.shape[0] mem_manager = self.backend.model.mem_manager offline_scales: torch.Tensor = mem_manager.scales head_num = mem_manager.head_num # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) return @@ -180,11 +182,11 @@ def _fp8_decode_att( k_cache=cache_k, v_cache=cache_v, page_table=self.page_table, - cache_seqlens=self.infer_state.b_seq_len, + cache_seqlens=self.b_att_seq_len, cu_seqlens_q=self.cu_seqlens_q, cu_seqlens_k_new=self.cu_seqlens_k, max_seqlen_q=self.decode_max_q_seq_len, - causal=False, + causal=True, window_size=(-1, -1), softcap=0.0, q_descale=q_scale.view(self.infer_state.batch_size, k_head_num), diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 5437c2436..dd681559e 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -4,6 +4,7 @@ import gc import copy import json +import math import torch import torch.nn.functional as F import triton @@ -337,6 +338,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.b_req_idx = model_input.b_req_idx infer_state.b_seq_len = model_input.b_seq_len infer_state.b_mtp_index = model_input.b_mtp_index + infer_state.b_num_accepted_tokens = model_input.b_num_accepted_tokens if model_input.is_prefill: if model_input.b_ready_cache_len is not None: infer_state.b_ready_cache_len = model_input.b_ready_cache_len @@ -379,6 +381,16 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) return infer_state + def _get_decode_padding_unit(self, model_input: ModelInput) -> int: + padding_unit = self.tp_world_size_ if self.args.enable_tpsp_mix_mode else 1 + if (not model_input.is_prefill) and self.args.mtp_step > 0: + padding_unit = math.lcm(padding_unit, self.args.mtp_step + 1) + return padding_unit + + def _get_decode_infer_batch_size(self, model_input: ModelInput) -> int: + padding_unit = self._get_decode_padding_unit(model_input) + return triton.cdiv(model_input.batch_size, padding_unit) * padding_unit + def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int): if model_input.batch_size == new_batch_size: return model_input @@ -388,8 +400,27 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s padded_batch_size = new_batch_size - model_input.batch_size new_model_input = copy.copy(model_input) new_model_input.batch_size = new_batch_size - new_model_input.total_token_num += padded_batch_size * 2 - new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) + + is_mtp_grouped_decode = (not model_input.is_prefill) and self.args.mtp_step > 0 + if is_mtp_grouped_decode: + mtp_size = self.args.mtp_step + 1 + assert padded_batch_size % mtp_size == 0 + padded_req_num = padded_batch_size // mtp_size + new_model_input.total_token_num += padded_req_num * (mtp_size * (mtp_size + 3) // 2) + new_model_input.max_kv_seq_len = max(mtp_size + 1, model_input.max_kv_seq_len) + pad_seq_len = torch.arange( + 2, mtp_size + 2, dtype=new_model_input.b_seq_len.dtype, device=new_model_input.b_seq_len.device + ).repeat(padded_req_num) + new_model_input.b_seq_len = torch.cat((new_model_input.b_seq_len, pad_seq_len), dim=0) + # b_num_accepted_tokens 不再随 model_input 流转/补齐:它在 GDN 的 init_mtp_verify_extra_state + # 里按 req_first 从 req_to_accept_len gather,padding 组 req_first=HOLD(槽恒为 1)自然得 1。 + else: + new_model_input.total_token_num += padded_batch_size * 2 + new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) + new_model_input.b_seq_len = F.pad( + new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2 + ) + new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) new_model_input.b_req_idx = F.pad( new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID @@ -397,7 +428,6 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input.b_mtp_index = F.pad( new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 ) - new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) new_model_input.mem_indexes = F.pad( new_model_input.mem_indexes, (0, padded_batch_size), @@ -576,10 +606,7 @@ def _decode( ) origin_batch_size = model_input.batch_size - if self.args.enable_tpsp_mix_mode: - infer_batch_size = triton.cdiv(model_input.batch_size, self.tp_world_size_) * self.tp_world_size_ - else: - infer_batch_size = model_input.batch_size + infer_batch_size = self._get_decode_infer_batch_size(model_input) if self.graph is not None and self.graph.can_run( batch_size=infer_batch_size, max_len_in_batch=model_input.max_kv_seq_len @@ -831,7 +858,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode origin_batch_size = model_input0.batch_size max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) - infer_batch_size = triton.cdiv(origin_batch_size, self.tp_world_size_) * self.tp_world_size_ + infer_batch_size = self._get_decode_infer_batch_size(model_input0) if self.graph is not None and self.graph.can_run(infer_batch_size, max_len_in_batch): infer_batch_size = self.graph.find_closest_graph_batch_size(infer_batch_size) @@ -1202,12 +1229,7 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} - is_mtp_draft_model = ( - "Deepseek3MTPModel" in str(self.__class__) - or "Qwen3MOEMTPModel" in str(self.__class__) - or "MistralMTPModel" in str(self.__class__) - or "Glm4MoeLiteMTPModel" in str(self.__class__) - ) + is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 1795ff9a8..03cb36d28 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -53,6 +53,8 @@ class ModelInput: # 的 draft 模型的输入 mtp_draft_input_hiddens: Optional[torch.Tensor] = None + b_num_accepted_tokens: Optional[torch.Tensor] = None + def to_cuda(self): if self.input_ids is not None: self.input_ids = self.input_ids.cuda(non_blocking=True) @@ -66,6 +68,8 @@ def to_cuda(self): self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) + if self.b_num_accepted_tokens is not None: + self.b_num_accepted_tokens = self.b_num_accepted_tokens.cuda(non_blocking=True) if self.b_ready_cache_len is not None: self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True) if self.b_prefill_start_loc is not None: diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 5e8036ee8..53daf88be 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -2,6 +2,7 @@ import torch import copy import bisect +import math import triton from typing import Optional from lightllm.utils.log_utils import init_logger @@ -32,33 +33,43 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) - # gen cuda graph batch_sizes - # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] - # and [graph_split_batch_size + graph_grow_step_size, - # if the mtp_step is not 0, then the batch_sizes will be multiply of (mtp_step + 1) + # With MTP enabled, both the main-model verify forward and the draft (MTP) forward run over + # the (mtp_step+1)-expanded decode layout, so every decode batch size is a multiple of + # (mtp_step+1) and there is a single decode layout — the graph is keyed by batch size alone. + batch_size_multiple = self.mtp_step + 1 if self.mtp_step > 0 else 1 + self.cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes(batch_size_multiple=batch_size_multiple) + logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") - graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1) - graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1) + def _build_cuda_graph_batch_sizes(self, batch_size_multiple: int): + graph_split_batch_size = self.args.graph_split_batch_size * batch_size_multiple + graph_grow_step_size = self.args.graph_grow_step_size * batch_size_multiple - batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)] - for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size): + batch_sizes = [i * batch_size_multiple for i in range(1, self.args.graph_split_batch_size + 1)] + for _batch_size in range( + graph_split_batch_size + graph_grow_step_size, + self.max_batch_size, + graph_grow_step_size, + ): batch_sizes.append(_batch_size) - batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size])) - batch_sizes.append(max_batch_size) + batch_sizes = list(set([e for e in batch_sizes if e < self.max_batch_size])) + batch_sizes.append(self.max_batch_size) batch_sizes.sort() if self.args.enable_tpsp_mix_mode: - batch_sizes = [triton.cdiv(e, self.tp_world_size) * self.tp_world_size for e in batch_sizes] + padding_unit = math.lcm(self.tp_world_size, batch_size_multiple) + batch_sizes = [triton.cdiv(e, padding_unit) * padding_unit for e in batch_sizes] batch_sizes = list(set(batch_sizes)) batch_sizes.sort() - self.cuda_graph_batch_sizes = batch_sizes assert batch_sizes[-1] == self.max_batch_size - logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") + return batch_sizes def can_run(self, batch_size, max_len_in_batch): return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch + def _decode_graph_key(self, infer_state: InferStateInfo): + return infer_state.input_ids.shape[0] + def need_capture(self, batch_size): find_batch_size = self.find_closest_graph_batch_size(batch_size) if find_batch_size is not None: @@ -74,6 +85,54 @@ def find_closest_graph_batch_size(self, batch_size): else: return None + def _build_warmup_decode_model_input( + self, + model, + batch_size: int, + device: str = "cuda", + ) -> ModelInput: + mtp_size = self.mtp_step + 1 + input_ids = torch.ones(batch_size, dtype=torch.int32, device=device) + mem_indexes = model.mem_manager.alloc(batch_size).to(device) + b_req_idx = torch.full( + (batch_size,), + fill_value=model.req_manager.HOLD_REQUEST_ID, + dtype=torch.int32, + device=device, + ) + + b_num_accepted_tokens = None + if self.mtp_step > 0: + assert batch_size % mtp_size == 0, "MTP decode CUDA graph batch size must be a multiple of mtp_step + 1" + real_batch_size = batch_size // mtp_size + b_mtp_index = torch.arange(mtp_size, dtype=torch.int32, device=device).repeat(real_batch_size) + b_seq_len = torch.arange(2, mtp_size + 2, dtype=torch.int32, device=device).repeat(real_batch_size) + # b_num_accepted_tokens 不再随 model_input 传入:GDN 的 init_mtp_verify_extra_state 会按 + # req_first(全 HOLD,槽恒为 1) gather,warmup/capture 自然得到全 1,等价旧的 torch.ones。 + total_token_num = real_batch_size * (mtp_size * (mtp_size + 3) // 2) + else: + seq_len = 2 + total_token_num = batch_size * seq_len + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device=device) + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device=device) + b_seq_len.fill_(seq_len) + + return ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_q_seq_len=1, + max_kv_seq_len=self.graph_max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_mtp_index=b_mtp_index, + b_num_accepted_tokens=b_num_accepted_tokens, + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], + **model._gen_special_model_input(batch_size), + ) + def _capture_decode(self, decode_func, infer_state: InferStateInfo): graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids @@ -101,7 +160,11 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) - self.graph[batch_size] = (graph_obj, infer_state, model_output) + self.graph[self._decode_graph_key(infer_state)] = ( + graph_obj, + infer_state, + model_output, + ) graph_obj.replay() return model_output @@ -135,7 +198,7 @@ def _capture_decode_overlap( with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) - self.graph[batch_size] = ( + self.graph[self._decode_graph_key(infer_state)] = ( graph_obj, infer_state, infer_state1, @@ -162,8 +225,7 @@ def capture_decode( return self._capture_decode(decode_func, infer_state) def _replay(self, infer_state: InferStateInfo): - batch_size = infer_state.input_ids.shape[0] - graph_obj, graph_infer_state, graph_output = self.graph[batch_size] + graph_obj, graph_infer_state, graph_output = self.graph[self._decode_graph_key(infer_state)] graph_infer_state.copy_for_cuda_graph(infer_state) graph_obj.replay() return graph_output @@ -173,14 +235,13 @@ def _replay_overlap( infer_state: InferStateInfo, infer_state1: InferStateInfo, ): - batch_size = infer_state.input_ids.shape[0] ( graph_obj, graph_infer_state, graph_infer_state1, graph_model_output, graph_model_output1, - ) = self.graph[batch_size] + ) = self.graph[self._decode_graph_key(infer_state)] graph_infer_state.copy_for_cuda_graph(infer_state) graph_infer_state1.copy_for_cuda_graph(infer_state1) graph_obj.replay() @@ -203,38 +264,9 @@ def warmup(self, model): # decode cuda graph init for batch_size in self.cuda_graph_batch_sizes[::-1]: - seq_len = 2 - total_token_num = batch_size * seq_len - max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") - b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - - model_input = ModelInput( - batch_size=batch_size, - total_token_num=total_token_num, - max_q_seq_len=1, - max_kv_seq_len=max_len_in_batch, - input_ids=input_ids, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - b_mtp_index=b_mtp_index, - is_prefill=False, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], - **model._gen_special_model_input(batch_size), - ) + model_input = self._build_warmup_decode_model_input(model, batch_size) model_output: ModelOutput = model.forward(model_input) del model_output - del input_ids - del mem_indexes - del b_req_idx - del b_seq_len model.mem_manager.free_all() model.req_manager.free_all() @@ -261,32 +293,7 @@ def warmup_overlap(self, model): decode_batches = [] for micro_batch_index in [0, 1]: # dummy decoding, capture the cudagraph - seq_len = 2 - total_token_num = batch_size * seq_len - max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") - b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - - micro_batch = ModelInput( - is_prefill=False, - batch_size=batch_size, - total_token_num=total_token_num, - max_q_seq_len=1, - max_kv_seq_len=max_len_in_batch, - input_ids=input_ids, - b_mtp_index=b_mtp_index, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], - **model._gen_special_model_input(batch_size), - ) + micro_batch = self._build_warmup_decode_model_input(model, batch_size) decode_batches.append(micro_batch) del micro_batch diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index e09452b5a..654880185 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -39,6 +39,8 @@ def __init__(self): self.b_mtp_index: torch.Tensor = None + self.b_num_accepted_tokens: torch.Tensor = None + self.b_seq_len: torch.Tensor = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None diff --git a/lightllm/common/basemodel/mtp_verify_extra_state.py b/lightllm/common/basemodel/mtp_verify_extra_state.py new file mode 100644 index 000000000..95bfce938 --- /dev/null +++ b/lightllm/common/basemodel/mtp_verify_extra_state.py @@ -0,0 +1,26 @@ +import torch + +from lightllm.utils.envs_utils import get_env_start_args + + +def init_mtp_verify_extra_state(self, model): + self.b_att_seq_len = self.b_seq_len + mtp_step = get_env_start_args().mtp_step + self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + self.b_conv_buffer_idx = self.b_req_idx + self.is_mtp_verify = (mtp_step > 0) and (not self.is_prefill) and (self.b_mtp_index is not None) + self.b_gdn_verify_cu_seqlens = None + self.b_ssm_index_rows = None + if self.is_mtp_verify: + step = mtp_step + 1 + n_real = self.b_req_idx.shape[0] // step + self.b_gdn_verify_cu_seqlens = torch.arange( + 0, (n_real + 1) * step, step, dtype=torch.int32, device=self.b_req_idx.device + ) + req_first = self.b_req_idx.view(n_real, step)[:, 0] + base = (req_first * step).view(n_real, 1) + self.b_ssm_index_rows = base + torch.arange(step, device=base.device, dtype=base.dtype).view(1, step) + assert self.b_ssm_index_rows.shape == (n_real, step) + self.b_conv_buffer_idx = req_first + self.b_num_accepted_tokens = model.req_manager.req_to_accept_len[req_first] + return diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py index d9f631cbd..fd4c16043 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py @@ -5,12 +5,13 @@ @triton.jit def _copy_linear_att_state_to_kv_buffer( - gpu_conv_ptr, # [linear_layer_num, size_num, xdim] + gpu_conv_ptr, # [linear_layer_num, size_num, conv_dim * gpu_widened_width] (uint8 tail) gpu_ssm_ptr, # [linear_layer_num, size_num, xxdim] - cpu_kv_conv_ptr, # [size, linear_layer_num, xdim] + cpu_kv_conv_ptr, # [size, linear_layer_num, conv_dim * width_narrow] (uint8 tail) cpu_kv_ssm_ptr, # [size, linear_layer_num, xxdim] b_req_idx, # [batch_size,] big_page_buffer_ids, # [batch_size,] + num_accepted_tokens_ptr, # [batch_size,] gpu_conv_stride_l, gpu_conv_stride_s, gpu_conv_stride_d, @@ -24,7 +25,9 @@ def _copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_stride_l, cpu_kv_ssm_stride_d, mtp_step, - gpu_conv_tail_dim, + conv_dim, # number of conv rows (the d dimension) + gpu_conv_row_bytes, # widened per-row byte length: gpu_widened_width * itemsize + conv_narrow_row_bytes, # narrow per-row byte length: width_narrow * itemsize gpu_ssm_tail_dim, BLOCK: tl.constexpr, ): @@ -40,28 +43,26 @@ def _copy_linear_att_state_to_kv_buffer( return cur_req_idx = tl.load(b_req_idx + cur_batch).to(tl.int64) - cur_state_req_idx = (cur_req_idx * (mtp_step + 1)).to(tl.int64) + accept_len = tl.load(num_accepted_tokens_ptr + cur_batch).to(tl.int64) + canonical_off = accept_len - 1 - for i in range(tl.cdiv(gpu_conv_tail_dim, BLOCK)): - gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) - mask = gpu_start_off < gpu_conv_tail_dim - conv_data = tl.load( - gpu_conv_ptr + cur_layer * gpu_conv_stride_l + cur_state_req_idx * gpu_conv_stride_s + gpu_start_off, - mask=mask, - ) - dest_conv_ptr = ( - cpu_kv_conv_ptr - + big_page_buffer_idx * cpu_kv_conv_stride_s - + cur_layer * cpu_kv_conv_stride_l - + gpu_start_off - ) - tl.store(dest_conv_ptr, conv_data, mask=mask) + conv_src_slot = cur_req_idx + conv_off_bytes = canonical_off * gpu_conv_stride_d + gpu_conv_base = gpu_conv_ptr + cur_layer * gpu_conv_stride_l + conv_src_slot * gpu_conv_stride_s + conv_off_bytes + cpu_conv_base = cpu_kv_conv_ptr + big_page_buffer_idx * cpu_kv_conv_stride_s + cur_layer * cpu_kv_conv_stride_l + for d in range(conv_dim): + for i in range(tl.cdiv(conv_narrow_row_bytes, BLOCK)): + off = i * BLOCK + tl.arange(0, BLOCK) + mask = off < conv_narrow_row_bytes + conv_data = tl.load(gpu_conv_base + d * gpu_conv_row_bytes + off, mask=mask) + tl.store(cpu_conv_base + d * cpu_kv_conv_stride_d + off, conv_data, mask=mask) + ssm_src_slot = (cur_req_idx * (mtp_step + 1) + canonical_off).to(tl.int64) for i in range(tl.cdiv(gpu_ssm_tail_dim, BLOCK)): gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_off < gpu_ssm_tail_dim ssm_data = tl.load( - gpu_ssm_ptr + cur_layer * gpu_ssm_stride_l + cur_state_req_idx * gpu_ssm_stride_s + gpu_start_off, + gpu_ssm_ptr + cur_layer * gpu_ssm_stride_l + ssm_src_slot * gpu_ssm_stride_s + gpu_start_off, mask=mask, ) dest_ssm_ptr = ( @@ -75,32 +76,51 @@ def _copy_linear_att_state_to_kv_buffer( def copy_linear_att_state_to_kv_buffer( b_req_idx: torch.Tensor, big_page_buffer_ids: torch.Tensor, - gpu_conv_state: torch.Tensor, # [linear_layer_num, s, ...] - gpu_ssm_state: torch.Tensor, # [linear_layer_num, s, ...] - cpu_kv_conv_state: torch.Tensor, # [s, linear_layer_num, ...] - cpu_kv_ssm_state: torch.Tensor, # [s, linear_layer_num, ...] + gpu_conv_state: torch.Tensor, # [linear_layer_num, s_widened, conv_dim, gpu_widened_width] + gpu_ssm_state: torch.Tensor, # [linear_layer_num, s_block, ...] + cpu_kv_conv_state: torch.Tensor, # [size, linear_layer_num, conv_dim, width_narrow] + cpu_kv_ssm_state: torch.Tensor, # [size, linear_layer_num, ...] mtp_step: int, + b_num_accepted_tokens: torch.Tensor, # [batch_size,] per-req post-accept count (>=1) ): assert len(b_req_idx) == big_page_buffer_ids.shape[0] + assert len(b_req_idx) == b_num_accepted_tokens.shape[0] BLOCK = 4096 - gpu_conv_state = gpu_conv_state.view(gpu_conv_state.shape[0], gpu_conv_state.shape[1], -1).view(dtype=torch.uint8) - gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8) - cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], -1).view( - dtype=torch.uint8 + + assert gpu_conv_state.dim() >= 4, "gpu_conv_state must be [layer, s, conv_dim, widened_width]" + assert cpu_kv_conv_state.dim() >= 4, "cpu_kv_conv_state must be [size, layer, conv_dim, width_narrow]" + # #6: the byte snapshot hardcodes gpu_conv_stride_d=conv_itemsize, which is only valid when the + # widened-width axis is element-contiguous (stride 1). Fail fast instead of snapshotting wrong bytes. + assert gpu_conv_state.stride(3) == 1, ( + "gpu_conv_state widened-width axis must be element-contiguous (stride 1); " + "gpu_conv_stride_d=conv_itemsize assumes it" + ) + # #18: canonical_off = accept_len - 1 indexes into the widened slot; bound it to [0, mtp_step] + # (accept_len in [1, mtp_step+1]) so a stale/oversized accept-count can't slice past the slot. + assert int(b_num_accepted_tokens.min()) >= 1 and int(b_num_accepted_tokens.max()) <= mtp_step + 1, ( + f"b_num_accepted_tokens out of range [1, {mtp_step + 1}]: " + f"min={int(b_num_accepted_tokens.min())} max={int(b_num_accepted_tokens.max())}" ) + conv_itemsize = gpu_conv_state.element_size() + gpu_conv_state = gpu_conv_state.view( + gpu_conv_state.shape[0], gpu_conv_state.shape[1], gpu_conv_state.shape[2], -1 + ).view(dtype=torch.uint8) + cpu_kv_conv_state = cpu_kv_conv_state.view( + cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], cpu_kv_conv_state.shape[2], -1 + ).view(dtype=torch.uint8) + + gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8) cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], cpu_kv_ssm_state.shape[1], -1).view( dtype=torch.uint8 ) - assert gpu_conv_state.shape[-1] == cpu_kv_conv_state.shape[-1] + + assert gpu_conv_state.shape[2] == cpu_kv_conv_state.shape[2], "conv_dim mismatch between gpu and cpu conv buffers" assert gpu_ssm_state.shape[-1] == cpu_kv_ssm_state.shape[-1] - assert ( - gpu_conv_state.stride(-1) - == gpu_ssm_state.stride(-1) - == cpu_kv_conv_state.stride(-1) - == cpu_kv_ssm_state.stride(-1) - ) - gpu_conv_tail_dim = gpu_conv_state.shape[-1] + conv_dim = gpu_conv_state.shape[2] + gpu_conv_row_bytes = gpu_conv_state.shape[-1] # widened per-row byte length + conv_narrow_row_bytes = cpu_kv_conv_state.shape[-1] # narrow per-row byte length + assert conv_narrow_row_bytes <= gpu_conv_row_bytes gpu_ssm_tail_dim = gpu_ssm_state.shape[-1] layer_num = gpu_conv_state.shape[0] @@ -114,9 +134,10 @@ def copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_ptr=cpu_kv_ssm_state, b_req_idx=b_req_idx, big_page_buffer_ids=big_page_buffer_ids, + num_accepted_tokens_ptr=b_num_accepted_tokens, gpu_conv_stride_l=gpu_conv_state.stride(0), gpu_conv_stride_s=gpu_conv_state.stride(1), - gpu_conv_stride_d=gpu_conv_state.stride(2), + gpu_conv_stride_d=conv_itemsize, gpu_ssm_stride_l=gpu_ssm_state.stride(0), gpu_ssm_stride_s=gpu_ssm_state.stride(1), gpu_ssm_stride_d=gpu_ssm_state.stride(2), @@ -127,7 +148,9 @@ def copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_stride_l=cpu_kv_ssm_state.stride(1), cpu_kv_ssm_stride_d=cpu_kv_ssm_state.stride(2), mtp_step=mtp_step, - gpu_conv_tail_dim=gpu_conv_tail_dim, + conv_dim=conv_dim, + gpu_conv_row_bytes=gpu_conv_row_bytes, + conv_narrow_row_bytes=conv_narrow_row_bytes, gpu_ssm_tail_dim=gpu_ssm_tail_dim, BLOCK=BLOCK, ) diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py index 37b27cadb..1251dddc3 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py @@ -193,11 +193,7 @@ def copy_kv_buffer_to_cpu_cache( cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] full_att_layer_num = gpu_kv_full_att_state.shape[-2] - assert ( - full_att_layer_num - == (linear_config.all_layer_num // linear_config.full_attention_interval) - == (linear_config.all_layer_num - linear_config.linear_layer_num) - ) + assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num() assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] @@ -428,6 +424,7 @@ def copy_cpu_cache_to_kv_buffer( cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] full_att_layer_num = gpu_full_att_kv_state.shape[-2] + assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num() assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index 2d70a68c0..bdd59c65e 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -149,35 +149,48 @@ def mtp_scatter_next_token_ids( @triton.jit -def _fwd_kernel_gen_b_req_mtp_start_loc( - b_mtp_index, +def _fwd_kernel_scatter_accept_len( + req_to_accept_len, b_req_mtp_start_loc, - num_reqs: tl.constexpr, - batch_size: tl.constexpr, - BLOCK_SIZE: tl.constexpr, + b_req_idx, + mtp_accept_len, ): - offset = tl.arange(0, BLOCK_SIZE) - cur_mtp_index = tl.load(b_mtp_index + offset, mask=offset < batch_size, other=-1) - non_zero_mask = tl.where(cur_mtp_index == 0, 1, 0) # 1 0 1 0 0 - output_offset = tl.cumsum(non_zero_mask) - 1 - tl.store(b_req_mtp_start_loc + output_offset, offset, mask=non_zero_mask == 1) + cur_index = tl.program_id(0) + req_start_loc = tl.load(b_req_mtp_start_loc + cur_index) + cur_req_idx = tl.load(b_req_idx + req_start_loc) + accept_len = tl.load(mtp_accept_len + cur_index) + tl.store(req_to_accept_len + cur_req_idx, accept_len) return -def gen_b_req_mtp_start_loc(b_mtp_index: torch.Tensor, num_reqs: int): - b_req_mtp_start_loc = torch.empty((num_reqs,), dtype=torch.int32, device=b_mtp_index.device) - BLOCK_SIZE = triton.next_power_of_2(b_mtp_index.shape[0]) - batch_size = b_mtp_index.shape[0] - grid = (1,) - _fwd_kernel_gen_b_req_mtp_start_loc[grid]( - b_mtp_index=b_mtp_index, +def scatter_mtp_accept_len( + req_to_accept_len: torch.Tensor, + b_req_mtp_start_loc: torch.Tensor, + b_req_idx: torch.Tensor, + mtp_accept_len: torch.Tensor, +): + """ + 将本步每个真实请求(组首)的 accept 数量写入 GPU 常驻的 req_to_accept_len[req_idx]。 + 融合 `req_to_accept_len[b_req_idx[b_req_mtp_start_loc]] = mtp_accept_len` 的 gather+scatter + 为单次 launch、无中间张量。每个 program 处理一个真实请求。 + Args: + req_to_accept_len: (max_req_num + 1,) + b_req_mtp_start_loc: (num_reqs,) 每组首行在 batch 中的偏移 + b_req_idx: (batch_size,) grouped 布局的 req_idx(组首即该请求的 req_idx) + mtp_accept_len: (num_reqs,) + """ + num_reqs = mtp_accept_len.shape[0] + if num_reqs == 0: + return + grid = (num_reqs,) + _fwd_kernel_scatter_accept_len[grid]( + req_to_accept_len=req_to_accept_len, b_req_mtp_start_loc=b_req_mtp_start_loc, - num_reqs=num_reqs, - batch_size=batch_size, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=8, + b_req_idx=b_req_idx, + mtp_accept_len=mtp_accept_len, + num_warps=1, + num_stages=1, ) - return b_req_mtp_start_loc def test_mtp_verify(): @@ -201,13 +214,5 @@ def test_mtp_verify(): print(accepted_index) -def test_gen_b_req_mtp_start_loc(): - b_mtp_index = torch.tensor([0, 1, 0, 1, 2], dtype=torch.int32, device="cuda") - gt_output = torch.where(b_mtp_index == 0)[0] - b_req_mtp_start_loc = gen_b_req_mtp_start_loc(b_mtp_index, 2) - print(b_req_mtp_start_loc, gt_output) - - if __name__ == "__main__": test_mtp_verify() - # test_gen_b_req_mtp_start_loc() diff --git a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py index 109e81322..e3ae9493c 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py +++ b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py @@ -24,6 +24,16 @@ def __init__(self, mem_manager): super().__init__(mem_manager) self.linear_config = LinearAttCacheConfig.load_from_args() + @staticmethod + def _get_persisted_full_att_layer_num(mem_manager) -> int: + persisted_full_att = getattr(mem_manager, "persisted_full_att_layer_num", None) + if persisted_full_att is None: + main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0]) + draft_full_att = getattr(mem_manager, "draft_full_att_layers", 0) + persisted_full_att = main_full_att + draft_full_att + assert 0 < persisted_full_att <= mem_manager.kv_buffer.shape[0] + return int(persisted_full_att) + def load_cpu_cache_to_gpu( self, mem_indexes: torch.Tensor, @@ -76,11 +86,14 @@ def load_cpu_cache_to_gpu( copy_cpu_cache_to_kv_buffer, ) + # Restore the persisted full-attn slice: main slots followed by MTP draft slots. + persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager) + copy_cpu_cache_to_kv_buffer( mem_indexes=mem_indexes, big_page_buffer_ids=big_page_buffer_ids_gpu, page_indexes=page_indexes, - gpu_full_att_kv_state=mem_manager.kv_buffer, + gpu_full_att_kv_state=mem_manager.kv_buffer[:persisted_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, @@ -169,12 +182,15 @@ def offload_gpu_kv_to_cpu_cache( copy_kv_buffer_to_cpu_cache, ) + # Persist the full-attn slice used for prefix reuse: main slots followed by MTP draft slots. + persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager) + copy_kv_buffer_to_cpu_cache( mem_indexes=mem_indexes, page_indexes=page_indexes, page_readies=page_readies, big_page_buffer_ids=big_page_buffer_ids_gpu, - gpu_kv_full_att_state=mem_manager.kv_buffer, + gpu_kv_full_att_state=mem_manager.kv_buffer[:persisted_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, diff --git a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py index c7ce9d96b..566ce5ea3 100644 --- a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py @@ -208,9 +208,9 @@ def write_req_to_page( dp_mems: List["Qwen3NextMemManager"], ): conv_page, ssm_page = self.view_page_to_linear_att_state(page_index) - req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1) + conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx) for tp_index, mem in enumerate(dp_mems): - self._write_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page) + self._write_one_rank(mem, tp_index, conv_req_idx, ssm_req_idx, conv_page, ssm_page) return def read_page_to_req( @@ -220,21 +220,27 @@ def read_page_to_req( dp_mems: List["Qwen3NextMemManager"], ): conv_page, ssm_page = self.view_page_to_linear_att_state(page_index) - req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1) + conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx) for tp_index, mem in enumerate(dp_mems): - self._read_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page) + self._read_one_rank(mem, tp_index, conv_req_idx, ssm_req_idx, conv_page, ssm_page) return + def _get_req_state_indexes(self, req_idx: int): + mtp_size = get_env_start_args().mtp_step + 1 + # Conv is one widened slot per request; SSM keeps the historical S+1 block layout. + return req_idx, req_idx * mtp_size + def _write_one_rank( self, mem: "Qwen3NextMemManager", tp_index: int, - req_buffer_idx: int, + conv_req_idx: int, + ssm_req_idx: int, conv_page: torch.Tensor, ssm_page: torch.Tensor, ): - conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...] - ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...] + conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]] + ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...] self._copy_conv_state_to_page(conv_state, conv_page, mem, tp_index) self._copy_ssm_state_to_page(ssm_state, ssm_page, mem, tp_index) return @@ -408,12 +414,13 @@ def _read_one_rank( self, mem: "Qwen3NextMemManager", tp_index: int, - req_buffer_idx: int, + conv_req_idx: int, + ssm_req_idx: int, conv_page: torch.Tensor, ssm_page: torch.Tensor, ): - conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...] - ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...] + conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]] + ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...] self._copy_page_to_conv_state(conv_page, conv_state, mem, tp_index) self._copy_page_to_ssm_state(ssm_page, ssm_state, mem, tp_index) return diff --git a/lightllm/common/linear_att_cache_manager/config_objs.py b/lightllm/common/linear_att_cache_manager/config_objs.py index bc3906706..f48b9865a 100644 --- a/lightllm/common/linear_att_cache_manager/config_objs.py +++ b/lightllm/common/linear_att_cache_manager/config_objs.py @@ -8,6 +8,16 @@ logger = init_logger(__name__) +def get_mtp_draft_full_att_layer_num(args) -> int: + # mtp_mode -> draft model 增加的 full-att KV 层数(与 envs_utils.get_added_mtp_kv_layer_num 同口径)。 + mtp_mode = getattr(args, "mtp_mode", None) + if mtp_mode == "eagle_with_att": + return 1 + if mtp_mode == "vanilla_with_att": + return getattr(args, "mtp_step", 0) + return 0 + + @dataclasses.dataclass class LinearAttCacheConfig: tp_world_size: int @@ -30,6 +40,7 @@ class LinearAttCacheConfig: ssm_state_dtype: torch.dtype full_attention_interval: int all_layer_num: int # 包括 linear att 和 full att 的层加起来的层数 + draft_full_att_layer_num: int = 0 def get_conv_dim(self): # 第一项对应q的参数,第二项对应k的参数,第三项对应v的参数 @@ -41,9 +52,25 @@ def get_conv_dim(self): + self.head_linear_v_dim * self.num_linear_v_heads ) - def get_conv_state_shape(self): + def get_main_full_att_layer_num(self): + main_full_att_layer_num = self.all_layer_num - self.linear_layer_num + assert main_full_att_layer_num == self.all_layer_num // self.full_attention_interval + return main_full_att_layer_num + + def get_persisted_full_att_layer_num(self): + return self.get_main_full_att_layer_num() + self.draft_full_att_layer_num + + def get_persisted_conv_state_shape(self): + # NARROW shape used for the CPU/disk persisted page and ALL byte math. + # Persisted state is always the committed (narrow) sliding window. return (self.get_conv_dim(), self.conv_kernel_size - 1) + def get_gpu_conv_state_shape(self, mtp_step: int): + # WIDENED working shape for the GPU buffer: holds the tentatively + # rolled-in S speculative tokens before acceptance. width-1 + S, where + # S = mtp_step (a verify step has seqlen=S+1 -> width-1+(seqlen-1)). + return (self.get_conv_dim(), (self.conv_kernel_size - 1) + mtp_step) + def get_ssm_state_shape(self): return (self.num_linear_v_heads, self.head_linear_k_dim, self.head_linear_v_dim) @@ -66,7 +93,7 @@ def get_cpu_cache_full_att_bytes(self): ) assert big_page_token_num == get_env_start_args().cpu_cache_token_page_size full_att_bytes = 2 * self.full_att_all_num_kv_heads * self.full_att_head_dim * self.full_att_dtype.itemsize - a = full_att_bytes * (self.all_layer_num - self.linear_layer_num) * big_page_token_num + a = full_att_bytes * self.get_persisted_full_att_layer_num() * big_page_token_num return a def get_cpu_cache_conv_bytes(self): @@ -113,4 +140,5 @@ def load_from_args() -> "LinearAttCacheConfig": ssm_state_dtype=get_torch_dtype(args.linear_att_ssm_data_type), full_attention_interval=llm_config["full_attention_interval"], all_layer_num=n_layer, + draft_full_att_layer_num=get_mtp_draft_full_att_layer_num(args), ) diff --git a/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py index 30dc4d937..2ab4313e3 100644 --- a/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py +++ b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py @@ -24,7 +24,7 @@ def __init__( self.conv_state_cache = LayerCache( size=self.size, dtype=self.linear_config.conv_state_dtype, - shape=self.linear_config.get_conv_state_shape(), + shape=self.linear_config.get_persisted_conv_state_shape(), layer_num=self.linear_config.linear_layer_num, device="cpu", size_first=True, diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad3..0f51ac271 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -19,6 +19,18 @@ logger = init_logger(__name__) +# Width of req_to_next_token_ids: holds the seed token + up to (WIDTH - 1) MTP draft tokens. +REQ_NEXT_TOKEN_IDS_WIDTH = 8 + + +def assert_mtp_step_within_next_token_ids_width(mtp_step: int) -> None: + assert mtp_step <= REQ_NEXT_TOKEN_IDS_WIDTH - 1, ( + f"mtp_step={mtp_step} exceeds {REQ_NEXT_TOKEN_IDS_WIDTH - 1}; " + f"req_to_next_token_ids width is {REQ_NEXT_TOKEN_IDS_WIDTH} " + "(widening it is an explicit follow-up, spec §9)" + ) + + class _ReqNode: def __init__(self, index): self.index = index @@ -75,6 +87,10 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num + # Always allocated (init 1 = no draft tokens accepted), so the linear-att cache-copy + # paths can index it unconditionally even when MTP is disabled (mtp_step == 0). + self.req_to_accept_len = torch.ones((max_request_num + 1,), dtype=torch.int32, device="cuda") + def alloc(self): return self.req_list.alloc() @@ -117,7 +133,7 @@ def __init__(self, max_request_num): self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_next_token_ids = torch.zeros( - (max_request_num + 1, 8), + (max_request_num + 1, REQ_NEXT_TOKEN_IDS_WIDTH), dtype=torch.int64, device="cuda", ) @@ -236,15 +252,13 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con self.big_page_token_num = ( get_env_start_args().linear_att_page_block_num * get_env_start_args().linear_att_hash_page_size ) - assert ( - self.mtp_step == 0 - ), "currently only support mtp_step 0 for simplicity, more mtp_step support will be added in the future" + assert_mtp_step_within_next_token_ids_width(self.mtp_step) self.linear_config = linear_config self.req_to_conv_state = LayerCache( - size=(max_request_num + 1) * (self.mtp_step + 1), + size=(max_request_num + 1), dtype=self.linear_config.conv_state_dtype, - shape=self.linear_config.get_conv_state_shape(), + shape=self.linear_config.get_gpu_conv_state_shape(mtp_step=self.mtp_step), layer_num=self.linear_config.linear_layer_num, device="cuda", ) @@ -258,11 +272,13 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con return def init_linear_att_state(self, req: "InferReq"): - index = req.req_idx * (self.mtp_step + 1) - conv_state = self.req_to_conv_state.buffer[:, index, ...] - ssm_state = self.req_to_ssm_state.buffer[:, index, ...] - conv_state.fill_(0) - ssm_state.fill_(0) + conv_index = req.req_idx + ssm_start = req.req_idx * (self.mtp_step + 1) + self.req_to_conv_state.buffer[:, conv_index, ...].fill_(0) + # #17: zero the FULL (mtp_step + 1)-row SSM block, not just canonical row +0, so a future + # first-step verify reading offset>0 after fresh init never hits a never-written row (NaN). + self.req_to_ssm_state.buffer[:, ssm_start : ssm_start + (self.mtp_step + 1), ...].fill_(0) + self.req_to_accept_len[req.req_idx] = 1 return def get_mamba_cache(self, layer_idx_in_all: int): @@ -281,10 +297,12 @@ def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req big_page_buffers: LinearAttCacheManager = self.mem_manager.linear_att_big_page_buffers conv_state, ssm_state = big_page_buffers.get_state_cache(buffer_idx=big_page_buffer_idx) - dest_req_idx = req.req_idx * (self.mtp_step + 1) - - self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state - self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state + conv_dest = req.req_idx + ssm_dest = req.req_idx * (self.mtp_step + 1) + narrow_w = conv_state.shape[-1] # persisted (narrow) width + self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state + self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state + self.req_to_accept_len[req.req_idx] = 1 return def copy_small_page_buffer_to_linear_att_state( @@ -293,9 +311,12 @@ def copy_small_page_buffer_to_linear_att_state( conv_state, ssm_state = linear_att_small_page_buffers.get_state_cache( buffer_idx=req.shared_kv_node.small_page_buffer_idx ) - dest_req_idx = req.req_idx * (self.mtp_step + 1) + conv_dest = req.req_idx + ssm_dest = req.req_idx * (self.mtp_step + 1) + narrow_w = conv_state.shape[-1] # TODO 下面这个从 cpu cache 拷贝数据的 gpu的操作,是否是阻塞的操作。 # 同时,非连续对象的拷贝,可能存在效率问题。 - self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state - self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state + self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state + self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state + self.req_to_accept_len[req.req_idx] = 1 return diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py index d23475c1c..35bd6f792 100644 --- a/lightllm/models/qwen3_5/infer_struct.py +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -1,8 +1,4 @@ -import torch -from typing import List - from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args class Qwen35InferStateInfo(Qwen2VLInferStateInfo): @@ -12,8 +8,7 @@ def __init__(self): def init_some_extra_state(self, model): super().init_some_extra_state(model) - self.b_att_seq_len = self.b_seq_len - mtp_step = get_env_start_args().mtp_step + from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state - self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + init_mtp_verify_extra_state(self, model) return diff --git a/lightllm/models/qwen3_5_moe_mtp/__init__.py b/lightllm/models/qwen3_5_moe_mtp/__init__.py new file mode 100644 index 000000000..c8885f886 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel + +__all__ = ["Qwen3_5MoeMTPModel"] diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..dcad1087d --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py @@ -0,0 +1,5 @@ +from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import ( + Qwen3_5MoeMTPTransformerLayerWeight, +) + +__all__ = ["Qwen3_5MoeMTPTransformerLayerWeight"] diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..b2700aa0b --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,154 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import ( + COLMMWeight, + FusedMoeWeight, + ROWMMWeight, + QKVROWNMMWeight, +) +from lightllm.models.qwen3_5_moe.layer_weights.transformer_layer_weight import ( + Qwen35MOETransformerLayerWeight, +) +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen3_5MoeMTPTransformerLayerWeight(Qwen35MOETransformerLayerWeight): + # MTP draft-model weights live under the `mtp.layers.*` checkpoint namespace; the + # main-model attention/norm names (`model.layers.*`) are retargeted to it, while the + # MoE expert / shared-expert names are built directly with the mtp prefix below. + + _MAIN_PREFIX = "model.layers." + _MTP_PREFIX = "mtp.layers." + + _ATTN_NORM_NAME_ATTRS = ( + "_q_weight_name", + "_q_norm_name", + "_q_bias_name", + "_k_weight_name", + "_k_norm_name", + "_k_bias_name", + "_v_weight_name", + "_v_bias_name", + "_kv_weight_name", + "_kv_bias_name", + "_o_weight_name", + "_o_bias_name", + "_att_norm_weight_name", + "_att_norm_bias_name", + "_ffn_norm_weight_name", + "_ffn_norm_bias_name", + ) + + def _retarget(self, name): + if name is None: + return None + return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) + + def _retarget_attn_norm_names(self): + for attr in self._ATTN_NORM_NAME_ATTRS: + setattr(self, attr, self._retarget(getattr(self, attr))) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + + def _init_weight_names(self): + super()._init_weight_names() + self._retarget_attn_norm_names() + + def _init_moe(self): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) + self._init_gated_ffn() + + def _init_gated_ffn(self): + hidden_size = self.network_config_["hidden_size"] + if "shared_expert_intermediate_size" not in self.network_config_: + return + + prefix = f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert" + inter_size = self.network_config_["shared_expert_intermediate_size"] + if get_env_start_args().enable_ep_moe: + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + tp_rank=0, + tp_world_size=1, + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + tp_rank=0, + tp_world_size=1, + ) + else: + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + ) + + self.ffn_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/qwen3_5_moe_mtp/model.py b/lightllm/models/qwen3_5_moe_mtp/model.py new file mode 100644 index 000000000..022864f6b --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/model.py @@ -0,0 +1,8 @@ +from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel +from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import ( + Qwen3_5MoeMTPTransformerLayerWeight, +) + + +class Qwen3_5MoeMTPModel(Qwen3_5MTPModel): + transformer_weight_class = Qwen3_5MoeMTPTransformerLayerWeight diff --git a/lightllm/models/qwen3_5_mtp/__init__.py b/lightllm/models/qwen3_5_mtp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_5_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_5_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..906a0ab62 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,40 @@ +import torch + +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_5_mtp.layer_weights.pre_and_post_layer_weight import Qwen3_5MTPPreAndPostLayerWeight +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + + +class Qwen3_5MTPPreLayerInfer(Qwen3VLMultimodalPreLayerInfer): + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_fuse( + self, + input_embdings: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3_5MTPPreAndPostLayerWeight, + ) -> torch.Tensor: + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert ( + input_embdings.shape[0] == tgt_embdings.shape[0] + ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" + + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + return layer_weight.eh_proj_weight_.mm(cat_embdings) + + def context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) + + def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_5_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..25c56a0d7 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,45 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + NoTpGEMMANormWeight, + ROWMMWeight, +) +from lightllm.common.quantization import Quantcfg + + +class Qwen3_5MTPPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, quant_cfg: Quantcfg): + super().__init__(data_type, network_config) + self.quant_cfg: Quantcfg = quant_cfg + hidden_size = network_config["hidden_size"] + + self.eh_proj_weight_ = ROWMMWeight( + in_dim=hidden_size * 2, + out_dims=[hidden_size], + weight_names="mtp.fc.weight", + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(0, "eh_proj"), + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_embedding.weight", + data_type=self.data_type_, + ) + self.hnorm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_hidden.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.norm.weight", + data_type=self.data_type_, + ) + + # Shared with the main Qwen3.5 model, injected by the model class (not loaded here). + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None + return diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..4c76b4bce --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,80 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, QKVROWNMMWeight +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( + Qwen35TransformerLayerWeight, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3_5MTPTransformerLayerWeight(Qwen35TransformerLayerWeight): + # MTP draft-model weights live under the `mtp.layers.*` checkpoint namespace, so every + # main-model layer name (`model.layers.*`) is retargeted to it at load time. + + _MAIN_PREFIX = "model.layers." + _MTP_PREFIX = "mtp.layers." + + _ATTN_NORM_NAME_ATTRS = ( + "_q_weight_name", + "_q_norm_name", + "_q_bias_name", + "_k_weight_name", + "_k_norm_name", + "_k_bias_name", + "_v_weight_name", + "_v_bias_name", + "_kv_weight_name", + "_kv_bias_name", + "_o_weight_name", + "_o_bias_name", + "_att_norm_weight_name", + "_att_norm_bias_name", + "_ffn_norm_weight_name", + "_ffn_norm_bias_name", + ) + + def _retarget(self, name): + if name is None: + return None + return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) + + def _retarget_attn_norm_names(self): + for attr in self._ATTN_NORM_NAME_ATTRS: + setattr(self, attr, self._retarget(getattr(self, attr))) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + + def _init_weight_names(self): + super()._init_weight_names() + # Retarget all main-model layer key names to the mtp.* namespace. + self._retarget_attn_norm_names() + # MLP (dense) projection names retargeted by Qwen35TransformerLayerWeight. + self._gate_weight_name = self._retarget(self._gate_weight_name) + self._gate_bias_name = self._retarget(self._gate_bias_name) + self._up_weight_name = self._retarget(self._up_weight_name) + self._up_bias_name = self._retarget(self._up_bias_name) + self._gate_up_weight_name = self._retarget(self._gate_up_weight_name) + self._gate_up_bias_name = self._retarget(self._gate_up_bias_name) + self._down_weight_name = self._retarget(self._down_weight_name) + self._down_bias_name = self._retarget(self._down_bias_name) diff --git a/lightllm/models/qwen3_5_mtp/model.py b/lightllm/models/qwen3_5_mtp/model.py new file mode 100644 index 000000000..b98524a99 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/model.py @@ -0,0 +1,109 @@ +from typing import List + +from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel +from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import Qwen35TransformerLayerInfer +from lightllm.models.qwen3_5_mtp.layer_weights.pre_and_post_layer_weight import Qwen3_5MTPPreAndPostLayerWeight +from lightllm.models.qwen3_5_mtp.layer_weights.transformer_layer_weight import Qwen3_5MTPTransformerLayerWeight +from lightllm.models.qwen3_5_mtp.layer_infer.pre_layer_infer import Qwen3_5MTPPreLayerInfer +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3_5MTPModel(Qwen3_5TpPartModel): + + pre_and_post_weight_class = Qwen3_5MTPPreAndPostLayerWeight + pre_layer_infer_class = Qwen3_5MTPPreLayerInfer + transformer_weight_class = Qwen3_5MTPTransformerLayerWeight + transformer_layer_infer_class = Qwen35TransformerLayerInfer + + # MTP draft model: reuses the main model's req/mem managers and rope caches, and is + # marked so the decode CUDA-graph / padding paths detect it (is_mtp_draft_model). + is_mtp_draft_model = True + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + + def _init_config(self): + super()._init_config() + self.config["full_attention_interval"] = 1 + self.config["num_hidden_layers"] = 1 + self.config["n_layer"] = 1 + return + + def _init_some_value(self): + super()._init_some_value() + self.layers_num = 1 + return + + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(0, self.config["n_layer"]) + ] + # Shared with the main Qwen3.5 model (mtp_use_dedicated_embeddings: false). + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + return + + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + # Build the single draft layer with layer_num == 0 so that, with + # full_attention_interval == 1, it takes the full-attention (mrope) path. + super()._init_infer_layer(start_layer_index=0) + self._assign_draft_kv_slot() + return + + def _assign_draft_kv_slot(self): + mem_manager = self.main_model.mem_manager + main_full_att = getattr(mem_manager, "main_full_att_layer_num", None) + interval = self.main_model.config["full_attention_interval"] + if main_full_att is None: + # Non-hybrid / unexpected mem_manager: nothing to remap. + return + + draft_idx = len(self.mtp_previous_draft_models) + draft_full_att_layers = getattr(mem_manager, "draft_full_att_layers", None) + if draft_full_att_layers is not None: + assert draft_idx < draft_full_att_layers, ( + f"draft_idx {draft_idx} out of range for draft_full_att_layers " + f"{draft_full_att_layers}; mem_manager not sized for this many MTP draft blocks" + ) + draft_kv_slot = main_full_att + draft_idx + layer_infer = self.layers_infer[0] + layer_infer.layer_num_ = draft_kv_slot * interval + logger.info( + f"Qwen3.5 MTP draft layer assigned dedicated full-attn KV slot {draft_kv_slot} " + f"(layer_num_={layer_infer.layer_num_}, interval={interval}, main_full_att={main_full_att})" + ) + return diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py index 0006a682f..bb1516673 100644 --- a/lightllm/models/qwen3next/infer_struct.py +++ b/lightllm/models/qwen3next/infer_struct.py @@ -1,6 +1,4 @@ -import torch from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args class Qwen3NextInferStateInfo(LlamaInferStateInfo): @@ -10,7 +8,7 @@ def __init__(self): def init_some_extra_state(self, model): super().init_some_extra_state(model) - self.b_att_seq_len = self.b_seq_len - mtp_step = get_env_start_args().mtp_step - self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state + + init_mtp_verify_extra_state(self, model) return diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 60bf0e6b7..66dd9bdd1 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -256,6 +256,18 @@ def gdn_forward( if is_prefill: core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight) + elif getattr(infer_state, "is_mtp_verify", False): + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) + conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) + core_attn_out = self._gdn_verify_kernel( + mixed_qkv, + conv_states, + ssm_states, + a, + b, + infer_state, + layer_weight, + ) else: mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) @@ -376,7 +388,7 @@ def _gdn_prefill_kernel( layer_weight.linear_conv1d.mm_param.weight, bias=layer_weight.linear_conv1d.bias, query_start_loc=infer_state.b1_cu_q_seq_len, - cache_indices=infer_state.b_buffer_idx, + cache_indices=infer_state.b_conv_buffer_idx, has_initial_state=infer_state.b_ready_cache_len > 0, conv_states=conv_states, activation=self.activation, @@ -421,7 +433,7 @@ def _gdn_decode_kernel( layer_weight.linear_conv1d.mm_param.weight, bias=layer_weight.linear_conv1d.bias, activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + conv_state_indices=infer_state.b_conv_buffer_idx, ) # Recurrent processing with fused gating; the kernel reads the @@ -441,3 +453,51 @@ def _gdn_decode_kernel( b_raw=b, ) return core_attn_out + + def _gdn_verify_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ): + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import ( + causal_conv1d_update as causal_conv1d_update_spec, + ) + + mixed_qkv = causal_conv1d_update_spec( + mixed_qkv, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + activation=self.activation, + conv_state_indices=infer_state.b_conv_buffer_idx, + num_accepted_tokens=infer_state.b_num_accepted_tokens, + query_start_loc=infer_state.b_gdn_verify_cu_seqlens, + ) + + query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=False) + assert infer_state.b_ssm_index_rows.dim() == 2, "SSM index rows must be 2D [N, S+1]" + # #8b: b_num_accepted_tokens >= 1 is guaranteed upstream (init sets accept_len=1; the + # offload/snapshot guards bound it to [1, mtp_step+1]). The old per-layer per-step .all() + # D2H sync stalled the GPU on the eager decode hot path; it is redundant here. + core_attn_out, _ = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + initial_state=ssm_states, + inplace_final_state=True, + cu_seqlens=infer_state.b_gdn_verify_cu_seqlens.to(torch.long), + ssm_state_indices=infer_state.b_ssm_index_rows, + ssm_state_write_indices=infer_state.b_ssm_index_rows, + num_accepted_tokens=infer_state.b_num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + A_log=layer_weight.linear_A_log.weight, + dt_bias=layer_weight.linear_dt_bias.weight, + a_raw=a, + b_raw=b, + ) + return core_attn_out diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 9b5e9b7a5..5d60bb28f 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -16,7 +16,10 @@ from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba -from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig +from lightllm.common.linear_att_cache_manager.config_objs import ( + LinearAttCacheConfig, + get_mtp_draft_full_att_layer_num, +) logger = init_logger(__name__) @@ -59,6 +62,7 @@ def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 start_args: StartArgs = get_env_start_args() ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + draft_full_att_layers = get_mtp_draft_full_att_layer_num(start_args) self.linear_config = LinearAttCacheConfig( tp_world_size=self.tp_world_size_, full_att_all_num_kv_heads=self.config["num_key_value_heads"], @@ -78,17 +82,24 @@ def _init_mem_manager(self): ssm_state_dtype=ssm_dtype_dict[start_args.linear_att_ssm_data_type], full_attention_interval=self.config["full_attention_interval"], all_layer_num=self.config["n_layer"], + draft_full_att_layer_num=draft_full_att_layers, ) + main_full_att = self.linear_config.get_main_full_att_layer_num() + persisted_full_att = self.linear_config.get_persisted_full_att_layer_num() + self.mem_manager = Qwen3NextMemManager( size=self.max_total_token_num, dtype=self.data_type, num_kv_heads=self.num_kv_heads, head_dim=self.config["head_dim"], - full_att_layer_num=self.linear_config.all_layer_num - self.linear_config.linear_layer_num, + full_att_layer_num=persisted_full_att, linear_config=self.linear_config, mem_fraction=self.mem_fraction, ) + self.mem_manager.main_full_att_layer_num = main_full_att + self.mem_manager.draft_full_att_layers = draft_full_att_layers + self.mem_manager.persisted_full_att_layer_num = persisted_full_att def _init_req_manager(self): create_max_seq_len = 0 diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py new file mode 100644 index 000000000..2f0e22fa3 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py @@ -0,0 +1,468 @@ +# Vendored from vLLM v0.14.1 +# source: vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# commit: d7de043d55d1dd629554467e23874097e1c48993 +# Adapted for LightLLM: imports point at standard triton; the vLLM-specific +# block-table params (block_idx_last_scheduled_token, initial_state_idx, +# null_block_id) are dropped — LightLLM uses contiguous per-request slots. +# Supports spec-decode: writes per-position conv state to a single widened slot +# per request and reads from offset (num_accepted_tokens-1). +# +# Upstream copyright notice: +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Tri Dao. +# Adapted from +# https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +from typing import Optional + +import torch +import triton +import triton.language as tl + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + # LightLLM uses contiguous per-request slots, so the cache block for both + # the initial-state read and the final write is always conv_state_indices[idx_seq]. + conv_state_init = 0 + current_last_index = 0 + + # cache_idx + conv_states_input_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init).to( + tl.int64 + ) + + if USE_PAD_SLOT: # noqa + if conv_states_input_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_VARLEN: + query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) + query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) + # revise state_len and seqlen + state_len = state_len - (seqlen - (query_end_index - query_start_index)) + seqlen = query_end_index - query_start_index + x_offset = query_start_index * stride_x_token + o_offset = query_start_index * stride_o_token + else: + query_start_index = idx_seq * seqlen + query_end_index = query_start_index + seqlen + x_offset = idx_seq * stride_x_seq + o_offset = idx_seq * stride_o_seq + + if query_start_index == query_end_index: + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = ( + conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 6: + conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] + col4 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # With speculative decoding, the conv_state updates works in a sliding + # window manner, at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + # Write the updated state back. In LightLLM the read and write slots are the + # same contiguous per-request slot (current_last_index == conv_state_init == 0), + # so this resolves to the same conv_state_indices[idx_seq] used for the read. + conv_states_offset = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index).to( + tl.int64 + ) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[ + None, : + ] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok + )[ + :, None + ] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 5: + w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor + w_col4 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 6: + w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor + w_col5 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + elif KERNEL_WIDTH == 5: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x + elif KERNEL_WIDTH == 6: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & (idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, +): + """Spec-decode capable conv1d update. When num_accepted_tokens/query_start_loc + are None it must behave like a single-token decode update. x may be (batch, dim) + single-token or (num_tokens, dim) flattened varlen with query_start_loc grouping + each request's S+1 candidates. conv_state is (num_slots, dim, state_len) with + state_len = (width-1)+S widened. Read offset = num_accepted_tokens-1; writes to + the same slot. + + Args: + x: input tensor of shape ``(batch, dim)`` (single-token decode), + ``(batch, dim, seqlen)`` (single/multi token), or ``(num_tokens, dim)`` + flattened varlen grouped by ``query_start_loc``. + conv_state: ``(num_slots, dim, state_len)`` with ``state_len >= width - 1``. + For spec decode the slot is widened to ``(width - 1) + S`` where ``S`` is + the number of speculative tokens (so ``seqlen == S + 1``). + weight: depthwise filter of shape ``(dim, width)``. + bias: optional ``(dim,)`` bias. + activation: ``None``, ``"silu"`` or ``"swish"``. + cache_seqlens: accepted for call-compatibility with the non-spec wrapper; + unused here. + conv_state_indices: ``(batch,)`` int32 mapping each request to its conv_state + slot. Required when ``query_start_loc`` is given. + num_accepted_tokens: ``(batch,)`` int32. When not None the conv_state read + offset for each request is ``num_accepted_tokens - 1`` (sliding window + spec-decode update). + query_start_loc: ``(batch + 1,)`` int32 varlen cumulative token offsets; when + None the call is a plain single-/multi-token decode update. + pad_slot_id: slot id that marks padded entries to skip. + + Returns: + Output tensor with the same shape as ``x`` (the kernel overwrites ``x`` in + place), one conv output per input token. + """ + if activation is not None: + assert activation in ["silu", "swish"] + + original_x_dtype = x.dtype + x = x.to(conv_state.dtype) + unsqueeze = query_start_loc is None and x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + if query_start_loc is None: + batch, dim, seqlen = x.shape + else: + assert conv_state_indices is not None + batch = conv_state_indices.size(0) + dim = x.size(1) + # The MTP verify layout is uniform (mtp_step+1) tokens per request, so seqlen is + # structurally x.size(0) // batch. Compute it without a D2H sync on query_start_loc on + # BOTH the capture and eager paths (#8a) — the eager .item() ran once per GDN layer per + # decode step. .item() is also illegal during CUDA-graph capture. + assert x.size(0) % batch == 0, "varlen conv update expects a uniform per-request length" + seqlen = x.size(0) // batch + _, width = weight.shape + # conv_state: (num_slots, dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + # adopt the strategy in vLLM that overwrites 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + if query_start_loc is None: + # X (batch, dim, seqlen) + stride_x_seq, stride_x_dim, stride_x_token = x.stride() + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + else: + # X (num_tokens, dim) + stride_x_token, stride_x_dim = x.stride() + stride_x_seq = 0 + stride_o_token, stride_o_dim = out.stride() + stride_o_seq = 0 + + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0 + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + conv_state_indices, + num_accepted_tokens, + query_start_loc, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_VARLEN=query_start_loc is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) + if unsqueeze: + out = out.squeeze(-1) + return out.to(original_x_dtype) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py index b0dc41a3c..5dfbd6e4a 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -214,20 +214,26 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( def _ensure_qkv_token_strided(x: torch.Tensor, inner_numel: int): - """Return q/k/v and token stride, copying only when needed.""" + """Return q/k/v and per-token stride, copying only when needed. + + Supports the decode layout [tokens, 1, head, dim] and the MTP verify / + varlen layout [1, tokens, head, dim]; the token dimension is the non-unit + leading dim. Both are column views of a packed projection output, so the + tail [head, dim] is contiguous and no copy is needed. + """ if x is None: return None, 0 - # Decode layout must be [tokens, 1, head, dim]. - assert x.shape[1] == 1, "q/k/v must use decode layout [tokens, 1, head, dim]" + assert x.shape[0] == 1 or x.shape[1] == 1, "q/k/v must use layout [tokens, 1, head, dim] or [1, tokens, head, dim]" # Packed tail [head, dim] means the last two strides are [dim, 1]. tail_contiguous = x.stride()[-2:] == (x.shape[-1], 1) if not tail_contiguous: x = x.contiguous() return x, inner_numel - else: - return x, x.stride(0) + # Token dim is the non-unit leading dim (dim 0 for decode, dim 1 for verify). + tok_dim = 0 if x.shape[1] == 1 else 1 + return x, x.stride(tok_dim) def _ensure_gate_token_strided(x: torch.Tensor, inner_numel: int): @@ -264,11 +270,10 @@ def fused_recurrent_gated_delta_rule_fwd( ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] HV = v.shape[2] - # In LightLLM's Qwen3Next inference path this fused recurrent kernel is - # used only for decode. Prefill/varlen requests are handled by - # chunk_gated_delta_rule, so keep cu_seqlens out of this strided-view path. - assert cu_seqlens is None, "cu_seqlens is not supported by the decode-only fused recurrent kernel" - N = B + # Decode passes cu_seqlens=None (equal-length one-token sequences); the + # Qwen3Next MTP verify path passes cu_seqlens for variable-length verify + # chunks. Both flow through the per-token strided-view path below. + N = B if cu_seqlens is None else len(cu_seqlens) - 1 q, stride_q_tok = _ensure_qkv_token_strided(q, H * K) k, stride_k_tok = _ensure_qkv_token_strided(k, H * K) v, stride_v_tok = _ensure_qkv_token_strided(v, HV * V) @@ -468,10 +473,10 @@ def fused_recurrent_gated_delta_rule( inplace_final_state: bool: Whether to store the final state in-place to save memory. Default: `True`. - cu_seqlens (torch.LongTensor): - Must be `None`. In LightLLM this fused recurrent kernel is used only - by the Qwen3Next decode path; prefill/varlen requests use - `chunk_gated_delta_rule`. + cu_seqlens (Optional[torch.LongTensor]): + Cumulative sequence lengths of shape `[N+1]` for variable-length + inputs (the Qwen3Next MTP verify path). `None` for plain decode, + where sequences are treated as equal-length (one token each). ssm_state_indices (Optional[torch.Tensor]): Indices to map the input sequences to the initial/final states. num_accepted_tokens (Optional[torch.Tensor]): @@ -500,9 +505,6 @@ def fused_recurrent_gated_delta_rule( initial_state=h0, ) """ - # This wrapper is only used for Qwen3Next decode inference in LightLLM. - # Keep varlen/prefill inputs on chunk_gated_delta_rule instead. - assert cu_seqlens is None, "cu_seqlens is not supported by the decode-only fused recurrent kernel" if scale is None: scale = k.shape[-1] ** -0.5 else: diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 2f79c441b..86cf82af8 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -374,6 +374,11 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L if not self.is_linear_att_mixed_model: return + # 当 dynamic prompt cache 被禁用时 radix_cache 为 None,没有大页/小页缓冲可写, + # 线性层状态仅存于 req_manager 的 GPU buffer 即可,直接跳过跨请求缓存拷贝。 + if self.radix_cache is None: + return + # 大页对应的 linear att 的拷贝 big_page_token_num = self.args.linear_att_hash_page_size * self.args.linear_att_page_block_num big_page_buffer_ids = [] @@ -397,6 +402,12 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer + # accept 数量改由 GPU 常驻的 req_to_accept_len 按 req_idx gather(不再读 req.mtp_accept_len)。 + req_idxs = torch.tensor( + [req.req_idx for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu" + ).cuda(non_blocking=True) + b_num_accepted_tokens = self.req_manager.req_to_accept_len[req_idxs] + copy_linear_att_state_to_kv_buffer( b_req_idx=b_req_idx, big_page_buffer_ids=big_page_buffer_ids, @@ -405,6 +416,7 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L cpu_kv_conv_state=self.radix_cache.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=self.radix_cache.linear_att_big_page_buffers.ssm_state_cache.buffer, mtp_step=self.args.mtp_step, + b_num_accepted_tokens=b_num_accepted_tokens, ) assert not self.args.disable_chunked_prefill, "chunked prefill mode must be enabled for linear att mixed model" @@ -420,9 +432,20 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L self.radix_cache.linear_att_small_page_buffers.alloc_one_state_cache() ) if req.tail_linear_att_small_page_buffer_id is not None: - src_buffer_idx = req.req_idx * (self.args.mtp_step + 1) - gpu_conv_state = self.req_manager.req_to_conv_state.buffer[:, src_buffer_idx, ...] - gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...] + # 冷路径(prefill 跨小页边界):单标量从 GPU buffer 读回做 Python 切片下标。 + accept_len = int(self.req_manager.req_to_accept_len[req.req_idx].item()) + assert 1 <= accept_len <= self.args.mtp_step + 1, ( + f"mtp_accept_len={accept_len} out of range " + f"[1, {self.args.mtp_step + 1}]; would slice past the widened conv slot" + ) + canonical_off = accept_len - 1 + conv_src_idx = req.req_idx + ssm_src_idx = req.req_idx * (self.args.mtp_step + 1) + canonical_off + narrow_w = self.req_manager.linear_config.get_persisted_conv_state_shape()[-1] + gpu_conv_state = self.req_manager.req_to_conv_state.buffer[ + :, conv_src_idx, ..., canonical_off : canonical_off + narrow_w + ] + gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, ssm_src_idx, ...] dst_buffer_idx = req.tail_linear_att_small_page_buffer_id dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache( diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 25b5b7e1a..e890b098b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -16,7 +16,7 @@ from lightllm.common.linear_att_cache_manager import LinearAttCacheManager from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache -from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput +from lightllm.common.basemodel.batch_objs import ModelOutput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name @@ -41,10 +41,6 @@ ) from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack -from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel -from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel -from lightllm.models.mistral_mtp.model import MistralMTPModel -from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import PDChunckedTransTaskRet @@ -328,22 +324,11 @@ def init_mtp_draft_model(self, main_kvargs: dict): "mtp_previous_draft_models": self.draft_models.copy(), } - # Select MTP model class based on model type + # Select MTP model class based on model type (single source of truth: #10). + from lightllm.server.router.model_infer.mode_backend.mtp_model_factory import create_mtp_draft_model + model_type = mtp_model_cfg.get("model_type", "") - if model_type == "deepseek_v3": - assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] - self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) - elif model_type == "qwen3_moe": - assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] - self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) - elif model_type == "mistral": - assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] - self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) - elif mtp_model_cfg["model_type"] == "glm4_moe_lite": - assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] - self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) - else: - raise ValueError(f"Unsupported MTP model type: {model_type}") + self.draft_models.append(create_mtp_draft_model(model_type, self.args.mtp_mode, mtp_model_kvargs)) self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return @@ -773,8 +758,7 @@ def _update_mtp_accept_ratio( def _gen_argmax_token_ids(self, model_output: ModelOutput): logits = model_output.logits - probs = torch.softmax(logits, dim=-1) - draft_next_token_ids_gpu = torch.argmax(probs, dim=-1) + draft_next_token_ids_gpu = torch.argmax(logits, dim=-1) return draft_next_token_ids_gpu def _sample_and_scatter_token( diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 792a10a78..65bb96163 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -1,5 +1,6 @@ import torch import time +import copy from typing import List, Optional, Callable, Dict, Any from queue import Queue from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend @@ -19,6 +20,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.common.basemodel.triton_kernel.mtp_utils import ( mtp_scatter_next_token_ids, + scatter_mtp_accept_len, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id @@ -241,22 +243,20 @@ def decode_mtp( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids - b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] - b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list( - key="b_req_mtp_start_loc", - data=b_req_mtp_start_loc, - dtype=torch.int32, - ).cuda(non_blocking=True) + n_real = model_input.batch_size // (self.mtp_step + 1) + b_req_mtp_start_loc = torch.arange(n_real, dtype=torch.int32, device="cuda") * (self.mtp_step + 1) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, b_req_idx=model_input.b_req_idx, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + scatter_mtp_accept_len( + self.model.req_manager.req_to_accept_len, b_req_mtp_start_loc, model_input.b_req_idx, mtp_accept_len + ) accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="accepted_index", gpu_tensor=accepted_index, diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index e6b9d1c18..9c83a5f35 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -1,3 +1,4 @@ +import copy import torch import time import torch.nn.functional as F @@ -20,7 +21,7 @@ from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager -from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids +from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids, scatter_mtp_accept_len from .control_state import DPControlState @@ -462,6 +463,9 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): b_req_idx=b_req_idx, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + scatter_mtp_accept_len( + self.model.req_manager.req_to_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_accept_len + ) accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="accepted_index", gpu_tensor=accepted_index, @@ -587,7 +591,6 @@ def _draft_decode_eagle( real_req_num = req_num // (self.mtp_step + 1) padded_req_num = model_input.batch_size // (self.mtp_step + 1) - real_req_num - eagle_mem_indexes_cpu = None if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) @@ -742,7 +745,6 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = model_input1.b_mtp_index with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -773,6 +775,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_req_idx=b_req_idx, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + scatter_mtp_accept_len( + self.model.req_manager.req_to_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_accept_len + ) accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="accepted_index", gpu_tensor=accepted_index, @@ -879,7 +884,7 @@ def _draft_decode_vanilla_overlap( draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda") if req_num0 > 0: draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True) - if req_num1 > 1: + if req_num1 > 0: draft_next_token_ids_gpu1[0:req_num1].copy_( next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True ) @@ -937,7 +942,7 @@ def _draft_decode_eagle_overlap( draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda") if req_num0 > 0: draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True) - if req_num1 > 1: + if req_num1 > 0: draft_next_token_ids_gpu1[0:req_num1].copy_( next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True ) diff --git a/lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py b/lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py new file mode 100644 index 000000000..1b4ade1ac --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py @@ -0,0 +1,33 @@ +from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel +from lightllm.models.mistral_mtp.model import MistralMTPModel +from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel + + +def create_mtp_draft_model(model_type: str, mtp_mode: str, mtp_model_kvargs: dict): + """Single source of truth for (model_type, mtp_mode) -> MTP draft model (#10). + Shared by base_backend and the static MTP benchmark.""" + if model_type == "deepseek_v3": + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + return Deepseek3MTPModel(mtp_model_kvargs) + elif model_type == "qwen3_moe": + assert mtp_mode in ["vanilla_no_att", "eagle_no_att"] + return Qwen3MOEMTPModel(mtp_model_kvargs) + elif model_type == "mistral": + assert mtp_mode in ["vanilla_no_att", "eagle_no_att"] + return MistralMTPModel(mtp_model_kvargs) + elif model_type == "glm4_moe_lite": + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + return Glm4MoeLiteMTPModel(mtp_model_kvargs) + elif model_type in ("qwen3_5", "qwen3_5_text"): + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel + + return Qwen3_5MTPModel(mtp_model_kvargs) + elif model_type in ("qwen3_5_moe", "qwen3_5_moe_text"): + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel + + return Qwen3_5MoeMTPModel(mtp_model_kvargs) + else: + raise ValueError(f"Unsupported MTP model type: {model_type}") diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 494908cb1..ff5ad0127 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -120,8 +120,8 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": if args.mtp_mode is not None: # TODO 可能会存在不同mtp模式的精度问题 - assert is_linear_att_mixed_model(args.model_dir) is False, "linear att mixed model does not support mtp mode" - cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() + if not is_linear_att_mixed_model(args.model_dir): + cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() cpu_cache_page_num = int( (args.cpu_cache_storage_size * 1024 * 1024 * 1024) / (cpu_cache_meta.calcu_one_page_size()) diff --git a/unit_tests/models/qwen3next/test_fused_recurrent_strided.py b/unit_tests/models/qwen3next/test_fused_recurrent_strided.py index cf9d06ec9..8464ca210 100644 --- a/unit_tests/models/qwen3next/test_fused_recurrent_strided.py +++ b/unit_tests/models/qwen3next/test_fused_recurrent_strided.py @@ -60,23 +60,11 @@ def run(q_, k_, v_, a_, b_, state): assert torch.equal(state_ref, state_strided) -def test_cu_seqlens_is_not_supported(): - """The fused recurrent kernel is decode-only in LightLLM's Qwen3Next path.""" - H, HV, K, V = 2, 2, 4, 4 - q = torch.randn(1, 2, H, K, device="cuda", dtype=torch.bfloat16) - k = torch.randn(1, 2, H, K, device="cuda", dtype=torch.bfloat16) - v = torch.randn(1, 2, HV, V, device="cuda", dtype=torch.bfloat16) - initial_state = torch.randn(1, HV, K, V, device="cuda", dtype=torch.bfloat16) - cu_seqlens = torch.tensor([0, 2], device="cuda", dtype=torch.long) - - with pytest.raises(AssertionError, match="decode-only fused recurrent kernel"): - fused_recurrent_gated_delta_rule( - q=q, - k=k, - v=v, - initial_state=initial_state, - cu_seqlens=cu_seqlens, - ) +# NOTE: the decode-only `cu_seqlens is None` contract from upstream #1349 was +# intentionally lifted on this branch so the Qwen3Next MTP verify path can drive +# the kernel with variable-length verify chunks (cu_seqlens + 2D SSM index +# rows). That varlen path is exercised end-to-end by the MTP GSM8K accuracy +# check rather than a hand-rolled unit test. if __name__ == "__main__":