Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@ dist
.vscode
tmp/
requirements-musa.txt
logs/
logs/

/benchmark/
artifacts/
28 changes: 15 additions & 13 deletions lightllm/common/basemodel/attention/fa3/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ def init_state(self):
torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
)
# To reduce computation during inference, initialize k_descale and v_descale outside the inference path
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)

self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

def prefill_att(
self,
Expand Down Expand Up @@ -116,20 +119,19 @@ def init_state(self):
super().init_state()
self.backend: Fp8Fa3AttBackend = self.backend

args_mtp_step = get_env_start_args().mtp_step
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0

device = self.infer_state.input_ids.device
batch_size = att_batch_size
batch_size = self.b_att_seq_len.shape[0]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Accessing self.b_att_seq_len directly on self will raise an AttributeError because b_att_seq_len is initialized on self.infer_state (via init_mtp_verify_extra_state), not on the attention layer object itself. It should be accessed via self.infer_state.b_att_seq_len.

Suggested change
batch_size = self.b_att_seq_len.shape[0]
batch_size = self.infer_state.b_att_seq_len.shape[0]

mem_manager = self.backend.model.mem_manager

offline_scales: torch.Tensor = mem_manager.scales
head_num = mem_manager.head_num

# To reduce computation during inference, initialize k_descale and v_descale outside the inference path
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

return

Expand Down Expand Up @@ -180,11 +182,11 @@ def _fp8_decode_att(
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,
cache_seqlens=self.infer_state.b_seq_len,
cache_seqlens=self.b_att_seq_len,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Accessing self.b_att_seq_len directly on self will raise an AttributeError because b_att_seq_len is initialized on self.infer_state, not on the attention layer object itself. It should be accessed via self.infer_state.b_att_seq_len.

Suggested change
cache_seqlens=self.b_att_seq_len,
cache_seqlens=self.infer_state.b_att_seq_len,

cu_seqlens_q=self.cu_seqlens_q,
cu_seqlens_k_new=self.cu_seqlens_k,
max_seqlen_q=self.decode_max_q_seq_len,
causal=False,
causal=True,
window_size=(-1, -1),
softcap=0.0,
q_descale=q_scale.view(self.infer_state.batch_size, k_head_num),
Expand Down
50 changes: 36 additions & 14 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gc
import copy
import json
import math
import torch
import torch.nn.functional as F
import triton
Expand Down Expand Up @@ -337,6 +338,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
infer_state.b_req_idx = model_input.b_req_idx
infer_state.b_seq_len = model_input.b_seq_len
infer_state.b_mtp_index = model_input.b_mtp_index
infer_state.b_num_accepted_tokens = model_input.b_num_accepted_tokens
if model_input.is_prefill:
if model_input.b_ready_cache_len is not None:
infer_state.b_ready_cache_len = model_input.b_ready_cache_len
Expand Down Expand Up @@ -379,6 +381,16 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)

return infer_state

def _get_decode_padding_unit(self, model_input: ModelInput) -> int:
    """Return the multiple that the batch size must be padded up to.

    The base unit is ``self.tp_world_size_`` when
    ``self.args.enable_tpsp_mix_mode`` is set, otherwise 1. For a
    non-prefill input with ``self.args.mtp_step > 0`` the unit is widened
    to the least common multiple of the base unit and ``mtp_step + 1``.

    Args:
        model_input: Batch description; only ``is_prefill`` is read here.

    Returns:
        The padding unit as a positive integer.
    """
    unit = self.tp_world_size_ if self.args.enable_tpsp_mix_mode else 1
    # Prefill batches and runs without MTP only need the base alignment.
    if model_input.is_prefill or self.args.mtp_step <= 0:
        return unit
    # NOTE(review): presumably mtp_step + 1 is the number of rows each request
    # occupies in an MTP decode batch, so the padded size stays a whole number
    # of such groups -- confirm against the padding logic that consumes this.
    return math.lcm(unit, self.args.mtp_step + 1)

def _get_decode_infer_batch_size(self, model_input: ModelInput) -> int:
    """Round ``model_input.batch_size`` up to a multiple of the padding unit.

    The unit comes from ``self._get_decode_padding_unit(model_input)``.

    Args:
        model_input: Batch description; ``batch_size`` is read here and the
            object is forwarded to ``_get_decode_padding_unit``.

    Returns:
        The smallest multiple of the padding unit that is greater than or
        equal to ``model_input.batch_size``.
    """
    unit = self._get_decode_padding_unit(model_input)
    # Ceiling division gives the number of whole units needed to cover the batch.
    num_units = triton.cdiv(model_input.batch_size, unit)
    return num_units * unit

def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int):
if model_input.batch_size == new_batch_size:
return model_input
Expand All @@ -388,16 +400,34 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
padded_batch_size = new_batch_size - model_input.batch_size
new_model_input = copy.copy(model_input)
new_model_input.batch_size = new_batch_size
new_model_input.total_token_num += padded_batch_size * 2
new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len)

is_mtp_grouped_decode = (not model_input.is_prefill) and self.args.mtp_step > 0
if is_mtp_grouped_decode:
mtp_size = self.args.mtp_step + 1
assert padded_batch_size % mtp_size == 0
padded_req_num = padded_batch_size // mtp_size
new_model_input.total_token_num += padded_req_num * (mtp_size * (mtp_size + 3) // 2)
new_model_input.max_kv_seq_len = max(mtp_size + 1, model_input.max_kv_seq_len)
pad_seq_len = torch.arange(
2, mtp_size + 2, dtype=new_model_input.b_seq_len.dtype, device=new_model_input.b_seq_len.device
).repeat(padded_req_num)
new_model_input.b_seq_len = torch.cat((new_model_input.b_seq_len, pad_seq_len), dim=0)
# b_num_accepted_tokens is no longer carried along with (or padded in) model_input: GDN's
# init_mtp_verify_extra_state gathers it from req_to_accept_len by req_first, and padding groups have req_first=HOLD (whose slot is always 1), so they naturally get 1.
else:
new_model_input.total_token_num += padded_batch_size * 2
new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len)
new_model_input.b_seq_len = F.pad(
new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2
)

new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1)
new_model_input.b_req_idx = F.pad(
new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID
)
new_model_input.b_mtp_index = F.pad(
new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0
)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2)
new_model_input.mem_indexes = F.pad(
new_model_input.mem_indexes,
(0, padded_batch_size),
Expand Down Expand Up @@ -576,10 +606,7 @@ def _decode(
)

origin_batch_size = model_input.batch_size
if self.args.enable_tpsp_mix_mode:
infer_batch_size = triton.cdiv(model_input.batch_size, self.tp_world_size_) * self.tp_world_size_
else:
infer_batch_size = model_input.batch_size
infer_batch_size = self._get_decode_infer_batch_size(model_input)

if self.graph is not None and self.graph.can_run(
batch_size=infer_batch_size, max_len_in_batch=model_input.max_kv_seq_len
Expand Down Expand Up @@ -831,7 +858,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode

origin_batch_size = model_input0.batch_size
max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len)
infer_batch_size = triton.cdiv(origin_batch_size, self.tp_world_size_) * self.tp_world_size_
infer_batch_size = self._get_decode_infer_batch_size(model_input0)

if self.graph is not None and self.graph.can_run(infer_batch_size, max_len_in_batch):
infer_batch_size = self.graph.find_closest_graph_batch_size(infer_batch_size)
Expand Down Expand Up @@ -1202,12 +1229,7 @@ def _init_padded_req(self):
def _gen_special_model_input(self, token_num: int):
special_model_input = {}

is_mtp_draft_model = (
"Deepseek3MTPModel" in str(self.__class__)
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
)
is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False)
if is_mtp_draft_model:
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
Expand Down
4 changes: 4 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class ModelInput:
# 的 draft 模型的输入
mtp_draft_input_hiddens: Optional[torch.Tensor] = None

b_num_accepted_tokens: Optional[torch.Tensor] = None

def to_cuda(self):
if self.input_ids is not None:
self.input_ids = self.input_ids.cuda(non_blocking=True)
Expand All @@ -66,6 +68,8 @@ def to_cuda(self):
self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
if self.b_num_accepted_tokens is not None:
self.b_num_accepted_tokens = self.b_num_accepted_tokens.cuda(non_blocking=True)
if self.b_ready_cache_len is not None:
self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
if self.b_prefill_start_loc is not None:
Expand Down
Loading