From 52bfacba4a20bcbfe43235d73ed369ca35b5c433 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Mon, 15 Jun 2026 14:04:46 +0800 Subject: [PATCH 1/4] feat: opt fa3 and flashinfer --- lightllm/common/basemodel/attention/fa3/fp.py | 1 + .../basemodel/attention/flashinfer/fp.py | 85 ++++++++++++++++--- .../basemodel/attention/flashinfer/mla.py | 7 +- lightllm/common/basemodel/basemodel.py | 9 +- lightllm/common/basemodel/batch_objs.py | 3 + lightllm/common/basemodel/infer_struct.py | 2 + .../triton_kernel/repack_kv_index.py | 2 - lightllm/distributed/flashinfer_all_reduce.py | 1 + 8 files changed, 95 insertions(+), 15 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d9..fb81006c2 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -235,6 +235,7 @@ def _normal_decode_att( causal=True, window_size=window_size, softcap=0.0, + num_splits=32, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=False, diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2..554d406c1 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -6,6 +6,49 @@ from .env_utils import set_flashinfer_envs +def _fast_plan_tensor_core_decode( + decode_wrapper, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + indptr_host, + kv_lens_arr_host, + max_kv_len, +): + batch_size = len(kv_lens_arr_host) + qo_indptr_host = getattr(decode_wrapper, "_qo_indptr_host", None) + if qo_indptr_host is None or len(qo_indptr_host) != batch_size + 1: + from flashinfer.decode import _get_range_buf + + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + decode_wrapper._qo_indptr_host = qo_indptr_host + + decode_wrapper._max_kv_len = max_kv_len + + args = [ + decode_wrapper._float_workspace_buffer, + decode_wrapper._int_workspace_buffer, + decode_wrapper._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_host, + kv_lens_arr_host, + batch_size, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + decode_wrapper.is_cuda_graph_enabled, + head_dim, + head_dim, + False, + -1, + ] + if decode_wrapper._backend == "fa2": + args.extend([-1, False, 0]) + decode_wrapper._plan_info = decode_wrapper._cached_module.plan(*args) + + class FlashInferAttBackend(BaseAttBackend): def __init__(self, model): set_flashinfer_envs() @@ -25,6 +68,10 @@ def __init__(self, model): model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() ), ] + self.kv_starts_host_buffer = [ + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + ] self.q_data_type = model.data_type self.kv_data_type = model.data_type @@ -124,11 +171,11 @@ class FlashInferDecodeAttState(BaseDecodeAttState): kv_last_page_len_buffer: torch.Tensor = None kv_indices: torch.Tensor = None kv_starts: torch.Tensor = None + kv_starts_host: torch.Tensor = None + kv_seq_lens_host: torch.Tensor = None decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: FlashInferAttBackend = self.backend device = self.infer_state.input_ids.device model = self.backend.model @@ -156,6 +203,18 @@ def init_state(self): self.kv_indices, ) self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.infer_state.b_seq_len_cpu is not None: + self.kv_seq_lens_host = self.infer_state.b_seq_len_cpu + self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][ + : self.infer_state.batch_size + 1 + ] + self.kv_starts_host[0] = 0 + torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:]) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.backend.workspace_buffer, @@ -181,19 +240,25 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): - super().copy_for_decode_cuda_graph(new_state) - self.decode_wrapper.plan( - new_state.kv_starts, - new_state.kv_indices, - new_state.kv_last_page_len_buffer, + if new_state.kv_seq_lens_host is not None: + # FlashInfer tensor-core decode updates its split-kv plan at 128-token + # boundaries for this path. page_size is 1 here, so pages == tokens. + skip_plan_key = tuple((seq_len + 127) // 128 for seq_len in new_state.kv_seq_lens_host.tolist()) + if getattr(self.decode_wrapper, "_skip_plan_key", None) == skip_plan_key: + return + + _fast_plan_tensor_core_decode( + self.decode_wrapper, new_state.backend.tp_q_head_num, new_state.backend.tp_kv_head_num, new_state.backend.head_dim, 1, - q_data_type=new_state.backend.q_data_type, - kv_data_type=new_state.backend.kv_data_type, - non_blocking=True, + new_state.kv_starts_host, + new_state.kv_seq_lens_host, + new_state.infer_state.max_kv_seq_len, ) + if new_state.kv_seq_lens_host is not None: + self.decode_wrapper._skip_plan_key = skip_plan_key def decode_att( self, diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 84b44dc45..4c80f1848 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -116,8 +116,6 @@ class MlaFlashInferDecodeAttState(BaseDecodeAttState): decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: MlaFlashInferAttBackend = self.backend model = self.backend.model device = self.infer_state.input_ids.device @@ -145,6 +143,11 @@ def init_state(self): self.infer_state.max_kv_seq_len, self.kv_indices, ) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a..497e7e7a2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -314,6 +314,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0] infer_state.b_req_idx = model_input.b_req_idx infer_state.b_seq_len = model_input.b_seq_len + infer_state.b_seq_len_cpu = model_input.b_seq_len_cpu infer_state.b_mtp_index = model_input.b_mtp_index if model_input.is_prefill: if model_input.b_ready_cache_len is not None: @@ -371,6 +372,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 ) new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) + if new_model_input.b_seq_len_cpu is not None: + new_model_input.b_seq_len_cpu = F.pad( + new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2 + ) new_model_input.mem_indexes = F.pad( new_model_input.mem_indexes, (0, padded_batch_size), @@ -562,6 +567,8 @@ def _decode( model_input=model_input, new_batch_size=infer_batch_size ) infer_state = self._create_inferstate(model_input) + need_capture = self.graph.need_capture(infer_batch_size) + infer_state.skip_decode_att_wrapper_init = not need_capture copy_kv_index_to_req( self.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -571,7 +578,7 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() - if self.graph.need_capture(infer_batch_size): + if need_capture: infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 1795ff9a8..81bf3cfd5 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -42,6 +42,7 @@ class ModelInput: multimodal_params: list = None # cpu 变量 mem_indexes_cpu: torch.Tensor = None + b_seq_len_cpu: torch.Tensor = None # prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理 # 的一些变量 b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出 @@ -64,6 +65,8 @@ def to_cuda(self): assert self.is_prefill self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) + if not self.b_seq_len.is_cuda: + self.b_seq_len_cpu = self.b_seq_len self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) if self.b_ready_cache_len is not None: diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 711484c83..575f1ee25 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -40,6 +40,7 @@ def __init__(self): self.b_mtp_index: torch.Tensor = None self.b_seq_len: torch.Tensor = None + self.b_seq_len_cpu: torch.Tensor = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None # prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度 @@ -56,6 +57,7 @@ def __init__(self): self.return_all_prompt_logics: bool = False self.multimodal_params: dict = None self.is_cuda_graph: bool = False # 标记是否是cuda graph的捕获推理 + self.skip_decode_att_wrapper_init: bool = False self.dist_group: CustomProcessGroup = None # 在microbatch overlap的运行模式下,用于标记当前 microbatch 的 index 序号 diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py index e86d2e819..c218d15e0 100644 --- a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -36,8 +36,6 @@ def _fwd_kernel_repack_kv_index( @torch.no_grad() def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): batch_size = req_index.shape[0] - # flashinfer requires out_kv_index to be zeroed before use - out_kv_index.zero_() BLOCK = 64 grid = ( batch_size, diff --git a/lightllm/distributed/flashinfer_all_reduce.py b/lightllm/distributed/flashinfer_all_reduce.py index 27856d9ac..f2dba1272 100644 --- a/lightllm/distributed/flashinfer_all_reduce.py +++ b/lightllm/distributed/flashinfer_all_reduce.py @@ -132,4 +132,5 @@ def all_reduce(self, inp: torch.Tensor) -> torch.Tensor: input=inp, workspace=self._workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce, + launch_with_pdl=True, ) From 3377848dd42baa25195ecc1ac6e107caf1b7f7ac Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Wed, 24 Jun 2026 10:07:56 +0800 Subject: [PATCH 2/4] feat: use flashinfer fast_plan --- .../basemodel/attention/flashinfer/fp.py | 75 ++++--------------- 1 file changed, 14 insertions(+), 61 deletions(-) diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 554d406c1..d640dd863 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -6,49 +6,6 @@ from .env_utils import set_flashinfer_envs -def _fast_plan_tensor_core_decode( - decode_wrapper, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - indptr_host, - kv_lens_arr_host, - max_kv_len, -): - batch_size = len(kv_lens_arr_host) - qo_indptr_host = getattr(decode_wrapper, "_qo_indptr_host", None) - if qo_indptr_host is None or len(qo_indptr_host) != batch_size + 1: - from flashinfer.decode import _get_range_buf - - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") - decode_wrapper._qo_indptr_host = qo_indptr_host - - decode_wrapper._max_kv_len = max_kv_len - - args = [ - decode_wrapper._float_workspace_buffer, - decode_wrapper._int_workspace_buffer, - decode_wrapper._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_host, - kv_lens_arr_host, - batch_size, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - decode_wrapper.is_cuda_graph_enabled, - head_dim, - head_dim, - False, - -1, - ] - if decode_wrapper._backend == "fa2": - args.extend([-1, False, 0]) - decode_wrapper._plan_info = decode_wrapper._cached_module.plan(*args) - - class FlashInferAttBackend(BaseAttBackend): def __init__(self, model): set_flashinfer_envs() @@ -172,7 +129,6 @@ class FlashInferDecodeAttState(BaseDecodeAttState): kv_indices: torch.Tensor = None kv_starts: torch.Tensor = None kv_starts_host: torch.Tensor = None - kv_seq_lens_host: torch.Tensor = None decode_wrapper: object = None def init_state(self): @@ -204,7 +160,6 @@ def init_state(self): ) self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() if self.infer_state.b_seq_len_cpu is not None: - self.kv_seq_lens_host = self.infer_state.b_seq_len_cpu self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][ : self.infer_state.batch_size + 1 ] @@ -240,25 +195,23 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): - if new_state.kv_seq_lens_host is not None: - # FlashInfer tensor-core decode updates its split-kv plan at 128-token - # boundaries for this path. page_size is 1 here, so pages == tokens. - skip_plan_key = tuple((seq_len + 127) // 128 for seq_len in new_state.kv_seq_lens_host.tolist()) - if getattr(self.decode_wrapper, "_skip_plan_key", None) == skip_plan_key: - return + from flashinfer.decode import fast_decode_plan - _fast_plan_tensor_core_decode( + fast_decode_plan( self.decode_wrapper, - new_state.backend.tp_q_head_num, - new_state.backend.tp_kv_head_num, - new_state.backend.head_dim, - 1, - new_state.kv_starts_host, - new_state.kv_seq_lens_host, - new_state.infer_state.max_kv_seq_len, + indptr=new_state.kv_starts, + indices=new_state.kv_indices, + last_page_len=new_state.kv_last_page_len_buffer, + num_qo_heads=new_state.backend.tp_q_head_num, + num_kv_heads=new_state.backend.tp_kv_head_num, + head_dim=new_state.backend.head_dim, + page_size=1, + q_data_type=new_state.backend.q_data_type, + kv_data_type=new_state.backend.kv_data_type, + non_blocking=True, + global_override_indptr_cpu=new_state.kv_starts_host, ) - if new_state.kv_seq_lens_host is not None: - self.decode_wrapper._skip_plan_key = skip_plan_key + self.decode_wrapper._max_kv_len = new_state.infer_state.max_kv_seq_len def decode_att( self, From e5246a615dca79bd46b17324277810c26b11a452 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 25 Jun 2026 10:41:55 +0800 Subject: [PATCH 3/4] feat: add fast plan for mla.py --- .../basemodel/attention/flashinfer/fp.py | 1 - .../basemodel/attention/flashinfer/mla.py | 56 ++++++++++++++++--- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index d640dd863..c71254072 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -211,7 +211,6 @@ def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): non_blocking=True, global_override_indptr_cpu=new_state.kv_starts_host, ) - self.decode_wrapper._max_kv_len = new_state.infer_state.max_kv_seq_len def decode_att( self, diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 4c80f1848..0c8bcd492 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -7,6 +7,33 @@ from .env_utils import set_flashinfer_envs +def _fast_plan_mla_decode( + decode_wrapper, + qo_indptr_cpu, + kv_indptr_cpu, + kv_len_arr_cpu, + num_heads, + head_dim_ckv, + page_size, + causal, + sm_scale, +): + decode_wrapper._causal = causal + decode_wrapper._page_size = page_size + decode_wrapper._sm_scale = sm_scale + decode_wrapper._plan_info = decode_wrapper._cached_module.plan( + decode_wrapper._float_workspace_buffer, + decode_wrapper._int_workspace_buffer, + decode_wrapper._pin_memory_int_workspace_buffer, + qo_indptr_cpu, + kv_indptr_cpu, + kv_len_arr_cpu, + num_heads, + head_dim_ckv, + causal, + ) + + class MlaFlashInferAttBackend(BaseAttBackend): def __init__(self, model): set_flashinfer_envs() @@ -30,6 +57,10 @@ def __init__(self, model): model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() ), ] + self.kv_starts_host_buffer = [ + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + ] from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale @@ -113,6 +144,8 @@ def _mla_prefill_att( class MlaFlashInferDecodeAttState(BaseDecodeAttState): kv_indices: torch.Tensor = None kv_starts: torch.Tensor = None + q_indptr_host: torch.Tensor = None + kv_starts_host: torch.Tensor = None decode_wrapper: object = None def init_state(self): @@ -124,6 +157,7 @@ def init_state(self): self.kv_starts = self.infer_state.b1_cu_kv_seq_len self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + self.q_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32, device="cpu") if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ : batch_size * self.backend.max_seq_length @@ -143,6 +177,12 @@ def init_state(self): self.infer_state.max_kv_seq_len, self.kv_indices, ) + if self.infer_state.b_seq_len_cpu is not None: + self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][ + : batch_size + 1 + ] + self.kv_starts_host[0] = 0 + torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:]) if self.infer_state.skip_decode_att_wrapper_init: return @@ -175,20 +215,18 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "MlaFlashInferDecodeAttState"): - super().copy_for_decode_cuda_graph(new_state) - self.decode_wrapper.plan( - new_state.q_indptr, - new_state.kv_starts, - new_state.kv_indices, - new_state.infer_state.b_seq_len, + assert new_state.kv_starts_host is not None + assert new_state.infer_state.b_seq_len_cpu is not None + _fast_plan_mla_decode( + self.decode_wrapper, + new_state.q_indptr_host, + new_state.kv_starts_host, + new_state.infer_state.b_seq_len_cpu, new_state.backend.tp_q_head_num, new_state.backend.kv_lora_rank, - new_state.backend.qk_rope_head_dim, 1, False, # causal new_state.backend.softmax_scale, - new_state.backend.q_data_type, - new_state.backend.kv_data_type, ) def decode_att( From 910b6c06fde3e357ceebae43500bc95a6ffc00b5 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Mon, 29 Jun 2026 10:41:31 +0800 Subject: [PATCH 4/4] feat: remove num_splits=32 --- lightllm/common/basemodel/attention/fa3/fp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index fb81006c2..952bb39d9 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -235,7 +235,6 @@ def _normal_decode_att( causal=True, window_size=window_size, softcap=0.0, - num_splits=32, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=False,