# Patch summary (reviewer preamble; text before the first "diff --git" header is not part of any hunk).
#
# What the hunks below do:
#   * attention/flashinfer/fp.py and mla.py: each backend __init__ gains `kv_starts_host_buffer`, a list of two
#     CPU int32 tensors of length graph_max_batch_size + 1. In init_state, when infer_state.b_seq_len_cpu is not
#     None, the buffer selected by microbatch_index is sliced to batch_size + 1, element 0 is set to 0 and
#     torch.cumsum(b_seq_len_cpu) is written into [1:]. init_state then returns before building the decode
#     wrapper when infer_state.skip_decode_att_wrapper_init is set; `import flashinfer` moves below that return.
#   * fp.py copy_for_decode_cuda_graph: the super() call and decode_wrapper.plan(...) are replaced by
#     flashinfer.decode.fast_decode_plan(...) with keyword arguments plus
#     global_override_indptr_cpu=new_state.kv_starts_host.
#   * mla.py: adds module-level _fast_plan_mla_decode, which stores _causal/_page_size/_sm_scale on the wrapper
#     and assigns _plan_info from decode_wrapper._cached_module.plan(...). copy_for_decode_cuda_graph now asserts
#     the host tensors exist and calls that helper with q_indptr_host, kv_starts_host and b_seq_len_cpu.
#   * basemodel.py: b_seq_len_cpu is copied into the infer state and padded with value 2 alongside b_seq_len;
#     _decode evaluates graph.need_capture once and sets skip_decode_att_wrapper_init = not need_capture.
#   * batch_objs.py / infer_struct.py: add the b_seq_len_cpu and skip_decode_att_wrapper_init fields; to_cuda
#     keeps a reference to b_seq_len in b_seq_len_cpu when it is not already a CUDA tensor.
#   * repack_kv_index.py: removes out_kv_index.zero_() together with its comment.
#   * flashinfer_all_reduce.py: passes launch_with_pdl=True.
#
# NOTE(review): both copy_for_decode_cuda_graph methods drop super().copy_for_decode_cuda_graph(new_state); the
#   base implementation is not visible in this patch -- confirm nothing it copied is still needed on replay.
# NOTE(review): kv_starts_host_buffer has graph_max_batch_size + 1 elements, but the new slice to batch_size + 1
#   is not covered by the `batch_size <= model.graph_max_batch_size` check that the mla.py context applies to
#   kv_indices. Confirm larger batches cannot reach this code; otherwise the cumsum `out=` view is shorter than
#   its input.
# NOTE(review): in fp.py kv_starts_host stays None when b_seq_len_cpu is None and is still passed as
#   global_override_indptr_cpu, while mla.py asserts on it -- confirm fast_decode_plan accepts None.
# NOTE(review): _fast_plan_mla_decode reads underscore-prefixed attributes of the flashinfer wrapper and takes
#   page_size without forwarding it to plan(); presumably tied to a specific flashinfer version -- TODO confirm.
# NOTE(review): the removed comment said flashinfer requires out_kv_index to be zeroed before use; confirm that
#   requirement no longer applies.
diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2..c71254072 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -25,6 +25,10 @@ def __init__(self, model): model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() ), ] + self.kv_starts_host_buffer = [ + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + ] self.q_data_type = model.data_type self.kv_data_type = model.data_type @@ -124,11 +128,10 @@ class FlashInferDecodeAttState(BaseDecodeAttState): kv_last_page_len_buffer: torch.Tensor = None kv_indices: torch.Tensor = None kv_starts: torch.Tensor = None + kv_starts_host: torch.Tensor = None decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: FlashInferAttBackend = self.backend device = self.infer_state.input_ids.device model = self.backend.model @@ -156,6 +159,17 @@ def init_state(self): self.kv_indices, ) self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.infer_state.b_seq_len_cpu is not None: + self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][ + : self.infer_state.batch_size + 1 + ] + self.kv_starts_host[0] = 0 + torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:]) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.backend.workspace_buffer, @@ -181,18 +195,21 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): - super().copy_for_decode_cuda_graph(new_state) - self.decode_wrapper.plan( - new_state.kv_starts, - 
new_state.kv_indices, - new_state.kv_last_page_len_buffer, - new_state.backend.tp_q_head_num, - new_state.backend.tp_kv_head_num, - new_state.backend.head_dim, - 1, + from flashinfer.decode import fast_decode_plan + + fast_decode_plan( + self.decode_wrapper, + indptr=new_state.kv_starts, + indices=new_state.kv_indices, + last_page_len=new_state.kv_last_page_len_buffer, + num_qo_heads=new_state.backend.tp_q_head_num, + num_kv_heads=new_state.backend.tp_kv_head_num, + head_dim=new_state.backend.head_dim, + page_size=1, q_data_type=new_state.backend.q_data_type, kv_data_type=new_state.backend.kv_data_type, non_blocking=True, + global_override_indptr_cpu=new_state.kv_starts_host, ) def decode_att( diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 84b44dc45..0c8bcd492 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -7,6 +7,33 @@ from .env_utils import set_flashinfer_envs +def _fast_plan_mla_decode( + decode_wrapper, + qo_indptr_cpu, + kv_indptr_cpu, + kv_len_arr_cpu, + num_heads, + head_dim_ckv, + page_size, + causal, + sm_scale, +): + decode_wrapper._causal = causal + decode_wrapper._page_size = page_size + decode_wrapper._sm_scale = sm_scale + decode_wrapper._plan_info = decode_wrapper._cached_module.plan( + decode_wrapper._float_workspace_buffer, + decode_wrapper._int_workspace_buffer, + decode_wrapper._pin_memory_int_workspace_buffer, + qo_indptr_cpu, + kv_indptr_cpu, + kv_len_arr_cpu, + num_heads, + head_dim_ckv, + causal, + ) + + class MlaFlashInferAttBackend(BaseAttBackend): def __init__(self, model): set_flashinfer_envs() @@ -30,6 +57,10 @@ def __init__(self, model): model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() ), ] + self.kv_starts_host_buffer = [ + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + 
torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + ] from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale @@ -113,11 +144,11 @@ def _mla_prefill_att( class MlaFlashInferDecodeAttState(BaseDecodeAttState): kv_indices: torch.Tensor = None kv_starts: torch.Tensor = None + q_indptr_host: torch.Tensor = None + kv_starts_host: torch.Tensor = None decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: MlaFlashInferAttBackend = self.backend model = self.backend.model device = self.infer_state.input_ids.device @@ -126,6 +157,7 @@ def init_state(self): self.kv_starts = self.infer_state.b1_cu_kv_seq_len self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + self.q_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32, device="cpu") if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ : batch_size * self.backend.max_seq_length @@ -145,6 +177,17 @@ def init_state(self): self.infer_state.max_kv_seq_len, self.kv_indices, ) + if self.infer_state.b_seq_len_cpu is not None: + self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][ + : batch_size + 1 + ] + self.kv_starts_host[0] = 0 + torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:]) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( @@ -172,20 +215,18 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "MlaFlashInferDecodeAttState"): - super().copy_for_decode_cuda_graph(new_state) - self.decode_wrapper.plan( - new_state.q_indptr, - new_state.kv_starts, - new_state.kv_indices, - new_state.infer_state.b_seq_len, + assert new_state.kv_starts_host is 
not None + assert new_state.infer_state.b_seq_len_cpu is not None + _fast_plan_mla_decode( + self.decode_wrapper, + new_state.q_indptr_host, + new_state.kv_starts_host, + new_state.infer_state.b_seq_len_cpu, new_state.backend.tp_q_head_num, new_state.backend.kv_lora_rank, - new_state.backend.qk_rope_head_dim, 1, False, # causal new_state.backend.softmax_scale, - new_state.backend.q_data_type, - new_state.backend.kv_data_type, ) def decode_att( diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a..497e7e7a2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -314,6 +314,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0] infer_state.b_req_idx = model_input.b_req_idx infer_state.b_seq_len = model_input.b_seq_len + infer_state.b_seq_len_cpu = model_input.b_seq_len_cpu infer_state.b_mtp_index = model_input.b_mtp_index if model_input.is_prefill: if model_input.b_ready_cache_len is not None: @@ -371,6 +372,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 ) new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) + if new_model_input.b_seq_len_cpu is not None: + new_model_input.b_seq_len_cpu = F.pad( + new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2 + ) new_model_input.mem_indexes = F.pad( new_model_input.mem_indexes, (0, padded_batch_size), @@ -562,6 +567,8 @@ def _decode( model_input=model_input, new_batch_size=infer_batch_size ) infer_state = self._create_inferstate(model_input) + need_capture = self.graph.need_capture(infer_batch_size) + infer_state.skip_decode_att_wrapper_init = not need_capture copy_kv_index_to_req( self.req_manager.req_to_token_indexs, 
infer_state.b_req_idx, @@ -571,7 +578,7 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() - if self.graph.need_capture(infer_batch_size): + if need_capture: infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 1795ff9a8..81bf3cfd5 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -42,6 +42,7 @@ class ModelInput: multimodal_params: list = None # cpu 变量 mem_indexes_cpu: torch.Tensor = None + b_seq_len_cpu: torch.Tensor = None # prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理 # 的一些变量 b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出 @@ -64,6 +65,8 @@ def to_cuda(self): assert self.is_prefill self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) + if not self.b_seq_len.is_cuda: + self.b_seq_len_cpu = self.b_seq_len self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) if self.b_ready_cache_len is not None: diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 711484c83..575f1ee25 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -40,6 +40,7 @@ def __init__(self): self.b_mtp_index: torch.Tensor = None self.b_seq_len: torch.Tensor = None + self.b_seq_len_cpu: torch.Tensor = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None # prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度 @@ -56,6 +57,7 @@ def __init__(self): self.return_all_prompt_logics: bool = False self.multimodal_params: dict = None self.is_cuda_graph: bool = False # 标记是否是cuda graph的捕获推理 + self.skip_decode_att_wrapper_init: bool = False self.dist_group: CustomProcessGroup = None # 在microbatch overlap的运行模式下,用于标记当前 microbatch 的 index 序号 
diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py index e86d2e819..c218d15e0 100644 --- a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -36,8 +36,6 @@ def _fwd_kernel_repack_kv_index( @torch.no_grad() def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): batch_size = req_index.shape[0] - # flashinfer requires out_kv_index to be zeroed before use - out_kv_index.zero_() BLOCK = 64 grid = ( batch_size, diff --git a/lightllm/distributed/flashinfer_all_reduce.py b/lightllm/distributed/flashinfer_all_reduce.py index 27856d9ac..f2dba1272 100644 --- a/lightllm/distributed/flashinfer_all_reduce.py +++ b/lightllm/distributed/flashinfer_all_reduce.py @@ -132,4 +132,5 @@ def all_reduce(self, inp: torch.Tensor) -> torch.Tensor: input=inp, workspace=self._workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce, + launch_with_pdl=True, )