Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def _normal_decode_att(
causal=True,
window_size=window_size,
softcap=0.0,
num_splits=32,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=False,
Expand Down
39 changes: 28 additions & 11 deletions lightllm/common/basemodel/attention/flashinfer/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def __init__(self, model):
model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id()
),
]
self.kv_starts_host_buffer = [
torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"),
torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"),
]
self.q_data_type = model.data_type
self.kv_data_type = model.data_type

Expand Down Expand Up @@ -124,11 +128,10 @@ class FlashInferDecodeAttState(BaseDecodeAttState):
kv_last_page_len_buffer: torch.Tensor = None
kv_indices: torch.Tensor = None
kv_starts: torch.Tensor = None
kv_starts_host: torch.Tensor = None
decode_wrapper: object = None

def init_state(self):
import flashinfer

self.backend: FlashInferAttBackend = self.backend
device = self.infer_state.input_ids.device
model = self.backend.model
Expand Down Expand Up @@ -156,6 +159,17 @@ def init_state(self):
self.kv_indices,
)
self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int()
if self.infer_state.b_seq_len_cpu is not None:
self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][
: self.infer_state.batch_size + 1
]
self.kv_starts_host[0] = 0
torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:])
if self.infer_state.skip_decode_att_wrapper_init:
return

import flashinfer

assert self.decode_wrapper is None
self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
self.backend.workspace_buffer,
Expand All @@ -181,18 +195,21 @@ def init_state(self):
return

def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"):
super().copy_for_decode_cuda_graph(new_state)
self.decode_wrapper.plan(
new_state.kv_starts,
new_state.kv_indices,
new_state.kv_last_page_len_buffer,
new_state.backend.tp_q_head_num,
new_state.backend.tp_kv_head_num,
new_state.backend.head_dim,
1,
from flashinfer.decode import fast_decode_plan

fast_decode_plan(
self.decode_wrapper,
indptr=new_state.kv_starts,
indices=new_state.kv_indices,
last_page_len=new_state.kv_last_page_len_buffer,
num_qo_heads=new_state.backend.tp_q_head_num,
num_kv_heads=new_state.backend.tp_kv_head_num,
head_dim=new_state.backend.head_dim,
page_size=1,
q_data_type=new_state.backend.q_data_type,
kv_data_type=new_state.backend.kv_data_type,
non_blocking=True,
global_override_indptr_cpu=new_state.kv_starts_host,
)

def decode_att(
Expand Down
63 changes: 52 additions & 11 deletions lightllm/common/basemodel/attention/flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,33 @@
from .env_utils import set_flashinfer_envs


def _fast_plan_mla_decode(
decode_wrapper,
qo_indptr_cpu,
kv_indptr_cpu,
kv_len_arr_cpu,
num_heads,
head_dim_ckv,
page_size,
causal,
sm_scale,
):
decode_wrapper._causal = causal
decode_wrapper._page_size = page_size
decode_wrapper._sm_scale = sm_scale
decode_wrapper._plan_info = decode_wrapper._cached_module.plan(
decode_wrapper._float_workspace_buffer,
decode_wrapper._int_workspace_buffer,
decode_wrapper._pin_memory_int_workspace_buffer,
qo_indptr_cpu,
kv_indptr_cpu,
kv_len_arr_cpu,
num_heads,
head_dim_ckv,
causal,
)


class MlaFlashInferAttBackend(BaseAttBackend):
def __init__(self, model):
set_flashinfer_envs()
Expand All @@ -30,6 +57,10 @@ def __init__(self, model):
model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id()
),
]
self.kv_starts_host_buffer = [
torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"),
torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"),
]

from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale

Expand Down Expand Up @@ -113,11 +144,11 @@ def _mla_prefill_att(
class MlaFlashInferDecodeAttState(BaseDecodeAttState):
kv_indices: torch.Tensor = None
kv_starts: torch.Tensor = None
q_indptr_host: torch.Tensor = None
kv_starts_host: torch.Tensor = None
decode_wrapper: object = None

def init_state(self):
import flashinfer

self.backend: MlaFlashInferAttBackend = self.backend
model = self.backend.model
device = self.infer_state.input_ids.device
Expand All @@ -126,6 +157,7 @@ def init_state(self):
self.kv_starts = self.infer_state.b1_cu_kv_seq_len

self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")
self.q_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32, device="cpu")
if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch:
self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][
: batch_size * self.backend.max_seq_length
Expand All @@ -145,6 +177,17 @@ def init_state(self):
self.infer_state.max_kv_seq_len,
self.kv_indices,
)
if self.infer_state.b_seq_len_cpu is not None:
self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][
: batch_size + 1
]
self.kv_starts_host[0] = 0
torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:])
if self.infer_state.skip_decode_att_wrapper_init:
return

import flashinfer

assert self.decode_wrapper is None

self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
Expand Down Expand Up @@ -172,20 +215,18 @@ def init_state(self):
return

def copy_for_decode_cuda_graph(self, new_state: "MlaFlashInferDecodeAttState"):
super().copy_for_decode_cuda_graph(new_state)
self.decode_wrapper.plan(
new_state.q_indptr,
new_state.kv_starts,
new_state.kv_indices,
new_state.infer_state.b_seq_len,
assert new_state.kv_starts_host is not None
assert new_state.infer_state.b_seq_len_cpu is not None
_fast_plan_mla_decode(
self.decode_wrapper,
new_state.q_indptr_host,
new_state.kv_starts_host,
new_state.infer_state.b_seq_len_cpu,
new_state.backend.tp_q_head_num,
new_state.backend.kv_lora_rank,
new_state.backend.qk_rope_head_dim,
1,
False, # causal
new_state.backend.softmax_scale,
new_state.backend.q_data_type,
new_state.backend.kv_data_type,
)

def decode_att(
Expand Down
9 changes: 8 additions & 1 deletion lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0]
infer_state.b_req_idx = model_input.b_req_idx
infer_state.b_seq_len = model_input.b_seq_len
infer_state.b_seq_len_cpu = model_input.b_seq_len_cpu
infer_state.b_mtp_index = model_input.b_mtp_index
if model_input.is_prefill:
if model_input.b_ready_cache_len is not None:
Expand Down Expand Up @@ -371,6 +372,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0
)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2)
if new_model_input.b_seq_len_cpu is not None:
new_model_input.b_seq_len_cpu = F.pad(
new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2
)
new_model_input.mem_indexes = F.pad(
new_model_input.mem_indexes,
(0, padded_batch_size),
Expand Down Expand Up @@ -562,6 +567,8 @@ def _decode(
model_input=model_input, new_batch_size=infer_batch_size
)
infer_state = self._create_inferstate(model_input)
need_capture = self.graph.need_capture(infer_batch_size)
infer_state.skip_decode_att_wrapper_init = not need_capture
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
Expand All @@ -571,7 +578,7 @@ def _decode(
infer_state.init_some_extra_state(self)
infer_state.init_att_state()

if self.graph.need_capture(infer_batch_size):
if need_capture:
infer_state.is_cuda_graph = True
model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state)
else:
Expand Down
3 changes: 3 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class ModelInput:
multimodal_params: list = None
# cpu 变量
mem_indexes_cpu: torch.Tensor = None
b_seq_len_cpu: torch.Tensor = None
# prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理
# 的一些变量
b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出
Expand All @@ -64,6 +65,8 @@ def to_cuda(self):
assert self.is_prefill

self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
if not self.b_seq_len.is_cuda:
self.b_seq_len_cpu = self.b_seq_len
self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
Comment on lines +68 to 70

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

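If `self.b_seq_len` is already on CUDA when `to_cuda` is called, `self.b_seq_len_cpu` is never assigned and keeps its previous value (`None` by default). This prevents the fast host-based planning path from being used, and the MLA `copy_for_decode_cuda_graph` asserts that `b_seq_len_cpu` is not `None`. We can copy the tensor to CPU if it is already on CUDA to ensure `self.b_seq_len_cpu` is always populated; note that `.cpu()` on a CUDA tensor is a blocking device-to-host copy.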

        if not self.b_seq_len.is_cuda:
            self.b_seq_len_cpu = self.b_seq_len
        else:
            self.b_seq_len_cpu = self.b_seq_len.cpu()
        self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)

self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
if self.b_ready_cache_len is not None:
Expand Down
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/infer_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self):
self.b_mtp_index: torch.Tensor = None

self.b_seq_len: torch.Tensor = None
self.b_seq_len_cpu: torch.Tensor = None
# max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度
self.max_cache_len: int = None
# prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度
Expand All @@ -56,6 +57,7 @@ def __init__(self):
self.return_all_prompt_logics: bool = False
self.multimodal_params: dict = None
self.is_cuda_graph: bool = False # 标记是否是cuda graph的捕获推理
self.skip_decode_att_wrapper_init: bool = False
self.dist_group: CustomProcessGroup = None

# 在microbatch overlap的运行模式下,用于标记当前 microbatch 的 index 序号
Expand Down
2 changes: 0 additions & 2 deletions lightllm/common/basemodel/triton_kernel/repack_kv_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def _fwd_kernel_repack_kv_index(
@torch.no_grad()
def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index):
batch_size = req_index.shape[0]
# flashinfer requires out_kv_index to be zeroed before use
out_kv_index.zero_()
BLOCK = 64
grid = (
batch_size,
Expand Down
1 change: 1 addition & 0 deletions lightllm/distributed/flashinfer_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,5 @@ def all_reduce(self, inp: torch.Tensor) -> torch.Tensor:
input=inp,
workspace=self._workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce,
launch_with_pdl=True,
)
Loading