
Commit 8cc1619

[feature] support pcp + mtp in full graph
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
1 parent 134e011

File tree: 5 files changed, 236 additions & 39 deletions

vllm_ascend/attention/utils.py

Lines changed: 13 additions & 2 deletions
@@ -42,6 +42,10 @@ class AscendPrefillContextParallelMetadata:
 
     pcp_prefill_mask: torch.Tensor = None
 
+    query_lens_pcp_full_cpu: torch.Tensor = None
+
+    max_query_len_pcp_full: int = None
+
 
 @dataclass
 class AscendCommonAttentionMetadata:
@@ -135,10 +139,15 @@ def filter_chunked_req_indices(
 def split_decodes_and_prefills(
     common_attn_metadata: AscendCommonAttentionMetadata,
     decode_threshold: int = 1,
+    query_lens: torch.Tensor = None,
+    max_query_len: int = None,
 ) -> tuple[int, int, int, int]:
     """
     Assuming a reordered batch, finds the boundary between prefill and decode
     requests.
+    When pcp > 1, query_lens is split across pcp ranks; in this case the
+    caller passes in the original query_lens and max_query_len to
+    distinguish prefills from decodes.
 
     Args:
         common_attn_metadata: AscendCommonAttentionMetadata object containing the
@@ -151,15 +160,17 @@ def split_decodes_and_prefills(
         num_decode_tokens: The number of tokens in the decode requests.
         num_prefill_tokens: The number of tokens in the prefill requests.
     """
-    max_query_len = common_attn_metadata.max_query_len
+    max_query_len = common_attn_metadata.max_query_len \
+        if max_query_len is None else max_query_len
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu
 
     if max_query_len <= decode_threshold:
         return num_reqs, 0, num_tokens, 0
 
-    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
+        if query_lens is None else query_lens
     is_prefill = query_lens > decode_threshold
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
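
For context on the utils.py change: a minimal standalone sketch of why the caller must pass the original (unsplit) query lengths when pcp > 1. The helper below is illustrative only (it reuses the diff's query_start_loc / decode_threshold / query_lens names but is not the vllm-ascend function); under pcp each rank holds only a shard of every request's query, so per-rank lengths would misclassify sharded prefills as decodes.

import torch

def split_on_full_lens(query_start_loc: torch.Tensor,
                       decode_threshold: int = 1,
                       query_lens: torch.Tensor = None):
    # Per-rank lengths derived from the (possibly pcp-sharded) start locations.
    local_lens = query_start_loc[1:] - query_start_loc[:-1]
    # When pcp > 1 the caller passes the original (unsplit) lengths instead, so a
    # prefill whose local shard is <= decode_threshold tokens is still a prefill.
    lens = local_lens if query_lens is None else query_lens
    is_prefill = lens > decode_threshold
    return int((~is_prefill).sum()), int(is_prefill.sum())

# Two requests, each originally an 8-token prefill, sharded to 2 tokens per rank by pcp=4.
qsl = torch.tensor([0, 2, 4])
print(split_on_full_lens(qsl, decode_threshold=2))
# (2, 0): judged on local lengths, both look like decodes
print(split_on_full_lens(qsl, decode_threshold=2, query_lens=torch.tensor([8, 8])))
# (0, 2): with the original lengths both are classified as prefills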

vllm_ascend/compilation/acl_graph.py

Lines changed: 4 additions & 1 deletion
@@ -369,7 +369,10 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
 
 def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
                                    runtime_shape):
-    graph_params = get_graph_params()
+    if forward_context.is_mtp_model:
+        graph_params = get_mtp_graph_params()
+    else:
+        graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 119 additions & 18 deletions
@@ -33,6 +33,7 @@
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_mtp_graph_params,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -99,6 +100,7 @@ def __init__(
         self.pcp_size = self.runner.pcp_size
         self.dcp_size = self.runner.dcp_size
         self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank
 
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
@@ -272,6 +274,13 @@ def dummy_run(self,
             cos=self.runner.cos,
             sin=self.runner.sin,
         )
+        if self.pcp_size * self.dcp_size > 1:
+            # update long_seq related params and flatten block_table
+            common_attn_metadata.prefill_context_parallel_metadata = \
+                self.runner.long_seq_metadata
+            common_attn_metadata.block_table_tensor = \
+                self.runner.input_batch.block_table[0].get_device_tensor()[
+                    :num_reqs * self.decode_threshold]
 
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_mtp = builder.build_for_graph_capture(
@@ -317,9 +326,15 @@
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                 not forward_context.capturing:
             if self.vllm_config.model_config.use_mla:
-                update_mla_attn_params(
-                    self.update_stream, forward_context, num_tokens,
-                    self.vllm_config.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    update_mla_attn_dcp_pcp_params(
+                        self.update_stream, forward_context,
+                        num_tokens)
+                else:
+                    update_mla_attn_params(
+                        self.update_stream, forward_context,
+                        num_tokens,
+                        self.vllm_config.speculative_config)
             if self.enable_shared_expert_dp:
                 positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                     positions, True)
@@ -373,7 +388,7 @@ def generate_token_ids(self,
             valid_sampled_tokens_count)
 
         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = self.runner.long_seq_metadata
             input_ids_pcp_full = self.runner.input_ids_pcp_full
             query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
@@ -409,7 +424,6 @@
             common_attn_metadata.query_start_loc = \
                 query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
-            # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices = \
                 self._prepare_inputs(
@@ -642,28 +656,36 @@ def _propose(
         self.input_ids[last_token_indices] = next_token_ids
 
         # update pcp related params
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             assert long_seq_metadata is not None
             common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            ori_last_token_indices = last_token_indices.clone()
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+        if self.pcp_size > 1:
             # 1. preprocess decode/prefill input_ids & target_hidden_states
             # decode input_ids: keep unchanged
             # decode target_hidden_states: remove padding
             # prefill input_ids: add padding and pcp split
             # prefill target_hidden_states: pcp split
-            num_tokens_d = num_decode_reqs * self.decode_threshold
+            num_tokens_d = query_lens_d.sum().item()
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
             target_hidden_states_d_padded = \
                 target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                target_hidden_states_d = target_hidden_states_d_padded.reshape(
-                    [
-                        num_decode_reqs, self.decode_threshold * self.pcp_size,
-                        -1
-                    ])[:, :self.decode_threshold, :].reshape(
-                        [num_tokens_d, -1])
+                mask_start_loc = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
+                ])
+                mask_len = query_lens_d
+                mask = []
+                for req_id in range(num_decode_reqs):
+                    mask += list(
+                        range(mask_start_loc[req_id],
+                              mask_start_loc[req_id] + mask_len[req_id]))
+                target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
             target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
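
A toy illustration of the padding-removal mask built above: after the pcp all-gather each decode request occupies query_len * pcp_size rows of hidden states, of which only the first query_len are real tokens, and the mask gathers exactly those rows. All sizes below are made up.

import torch

pcp_size = 2
query_lens_d = torch.tensor([2, 1])      # decode query lens per request (toy values)
num_decode_reqs = query_lens_d.numel()
hidden_size = 4

# Padded decode hidden states: query_len * pcp_size rows per request.
num_rows = int((query_lens_d * pcp_size).sum())
target_hidden_states_d_padded = torch.randn(num_rows, hidden_size)

# Start row of each request's padded block, then keep only the first
# query_len rows of every block (the real tokens).
mask_start_loc = torch.cat([
    torch.tensor([0]),
    torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]
])
mask = []
for req_id in range(num_decode_reqs):
    mask += list(range(int(mask_start_loc[req_id]),
                       int(mask_start_loc[req_id]) + int(query_lens_d[req_id])))

target_hidden_states_d = target_hidden_states_d_padded[mask]
print(mask)                          # [0, 1, 4]
print(target_hidden_states_d.shape)  # torch.Size([3, 4])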
@@ -798,10 +820,15 @@ def _propose(
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                 if self.vllm_config.model_config.use_mla:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context,
-                        num_input_tokens,
-                        self.vllm_config.speculative_config)
+                    if self.pcp_size * self.dcp_size > 1:
+                        update_mla_attn_dcp_pcp_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens)
+                    else:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens,
+                            self.vllm_config.speculative_config)
 
             if self.enable_shared_expert_dp:
                 hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -820,6 +847,8 @@
                     (0, max_num_reqs_across_dp - num_indices))
 
             if self.pcp_size > 1:
+                # remove graph padding before all_gather
+                hidden_states = hidden_states[:num_tokens]
                 hidden_states = get_pcp_group().all_gather(hidden_states, 0)
                 hidden_states = torch.index_select(
                     hidden_states, 0, self.runner.
@@ -860,6 +889,51 @@
             last_token_indices = self.arange[:batch_size]
             if getattr(attn_metadata_i, "num_decode_tokens", 0):
                 attn_metadata_i.num_decode_tokens = batch_size
+            if self.pcp_size * self.dcp_size > 1:
+                positions = target_positions[ori_last_token_indices]
+                # For pcp/dcp, tokens are split across different cp ranks,
+                # so we cannot simply update slot_mapping by += 1.
+                # Instead, we pre-allocate mtp slot_mapping in model_runner
+                # (_generate_pcp_mtp_input), and use updated slot_indices
+                # to get the corresponding slot_mapping in each step.
+                num_reject_tokens = torch.tensor(
+                    self.runner.cu_num_tokens_pcp_full,
+                    dtype=torch.int32).to(
+                        self.device) - ori_last_token_indices - 1
+                num_accept_tokens = \
+                    query_lens_d.to(self.device) - num_reject_tokens
+                ori_seq_len = attn_metadata_i.seq_lens
+                mtp_slot_mapping = self.runner.mtp_slot_pad
+
+                # slot_mapping index base offset:
+                # scheduled tokens + pre-allocated mtp tokens + accepted tokens
+                slot_idx_base = (
+                    torch.cat([
+                        torch.tensor(
+                            [0], dtype=torch.int32, device=self.device),
+                        (torch.cumsum(query_lens_d, dim=0)[:-1] *
+                         self.pcp_size).to(self.device)
+                    ]) +
+                    torch.arange(num_decode_reqs, device=self.device) *
+                    (self.num_speculative_tokens - 1) * self.pcp_size +
+                    (num_accept_tokens - 1) * self.pcp_size)
+                slot_indices_list = []
+                for req_id in range(num_decode_reqs):
+                    slot_indices_list.append(
+                        torch.arange(slot_idx_base[req_id],
+                                     slot_idx_base[req_id] + self.pcp_size,
+                                     device=self.device))
+                slot_indices = torch.cat(slot_indices_list, dim=0)
+
+                # fold block_table (restore it to its original size before flattening)
+                block_indices = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d, dim=0)[:-1]
+                ])
+                attn_metadata_i.decode.block_table[:batch_size] = \
+                    attn_metadata_i.decode.block_table[block_indices]
+                attn_metadata_i.decode.block_table = \
+                    attn_metadata_i.decode.block_table[:batch_size]
 
             input_ids = draft_token_ids_list[-1].int()
             positions += 1
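
A worked toy example of the slot-index arithmetic above, under assumed sizes (pcp_size, num_speculative_tokens, per-request query lengths and accepted-token counts are all invented). It reproduces only the index computation, not the runner's pre-allocated mtp_slot_pad buffer, and the per-request layout stated in the comment is inferred from the diff's base-offset comment.

import torch

# Toy numbers, purely illustrative (not taken from the runner):
pcp_size = 2
num_speculative_tokens = 3                 # draft depth; (k - 1) extra mtp slots per request
num_decode_reqs = 2
query_lens_d = torch.tensor([3, 3])        # scheduled tokens per decode request
num_accept_tokens = torch.tensor([2, 1])   # accepted tokens per request this step

# Assumed layout per request (each slot replicated pcp_size times):
#   [scheduled-token slots | (num_speculative_tokens - 1) mtp slots]
# Base offset of the slot group belonging to the last accepted token:
slot_idx_base = (
    torch.cat([torch.tensor([0]),
               torch.cumsum(query_lens_d, dim=0)[:-1] * pcp_size]) +
    torch.arange(num_decode_reqs) * (num_speculative_tokens - 1) * pcp_size +
    (num_accept_tokens - 1) * pcp_size)

# Each request contributes pcp_size consecutive slot indices for this step.
slot_indices = torch.cat([
    torch.arange(int(slot_idx_base[r]), int(slot_idx_base[r]) + pcp_size)
    for r in range(num_decode_reqs)
])
print(slot_idx_base.tolist())   # [2, 10]
print(slot_indices.tolist())    # [2, 3, 10, 11]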
@@ -906,13 +980,40 @@ def _propose(
             # Otherwise, the KV cache will be inadvertently updated with the
             # padding tokens.
             slot_mapping += 1
+            if self.pcp_size > 1:
+                exceeds_max_model_len = exceeds_max_model_len.repeat_interleave(
+                    slot_mapping.size(0) // exceeds_max_model_len.size(0))
             slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
 
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:hidden_states.shape[0]] = hidden_states
-            attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
+            if self.pcp_size * self.dcp_size > 1:
+                # update local seq_len and batch_seq_mask
+                num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
+                    ori_seq_len + step + 1,
+                    self.pcp_size,
+                    self.dcp_size,
+                    self.runner.parallel_config.cp_kv_cache_interleave_size,
+                )
+                cp_seq_len = \
+                    num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
+                batch_seq_mask = (cp_seq_len == 0)
+                builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                    batch_seq_mask, non_blocking=True)
+                batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.
+                                                            shape[0]]
+                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+                attn_metadata_i.decode.cp_seq_len = cp_seq_len
+                attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
+                # update slot_mapping
+                slot_indices += self.pcp_size
+                slot_mapping = mtp_slot_mapping[slot_indices]
+                attn_metadata_i.slot_mapping[:batch_size *
+                                             self.pcp_size] = slot_mapping
+            else:
+                attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
             if self.speculative_config.disable_padded_drafter_batch:
                 self.positions[batch_size:num_input_tokens] = 0
                 self.input_ids[batch_size:num_input_tokens] = 0
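
The per-step update above calls the runner's _get_cp_local_seq_lens helper. The stand-in below is only a guess at its semantics for illustration (tokens distributed round-robin across pcp * dcp ranks in blocks of cp_kv_cache_interleave_size); it shows how cp_seq_len and batch_seq_mask would then be derived for one rank.

import torch

def toy_cp_local_seq_lens(seq_lens: torch.Tensor, pcp_size: int, dcp_size: int,
                          interleave_size: int) -> torch.Tensor:
    """Toy stand-in (assumption, not the runner's implementation): distribute each
    sequence's tokens over pcp_size * dcp_size ranks in blocks of interleave_size,
    returning [num_reqs, pcp_size, dcp_size] local lengths."""
    num_ranks = pcp_size * dcp_size
    out = torch.zeros(seq_lens.numel(), num_ranks, dtype=torch.long)
    for i, s in enumerate(seq_lens.tolist()):
        full, rem = divmod(s, interleave_size * num_ranks)
        out[i] = full * interleave_size
        for r in range(num_ranks):
            out[i, r] += min(max(rem - r * interleave_size, 0), interleave_size)
    return out.view(seq_lens.numel(), pcp_size, dcp_size)

seq_lens = torch.tensor([5, 0])          # second request holds no tokens this step
local = toy_cp_local_seq_lens(seq_lens, pcp_size=2, dcp_size=2, interleave_size=1)
cp_seq_len = local[:, 0, 0]              # this rank's view (pcp_rank=0, dcp_rank=0)
batch_seq_mask = (cp_seq_len == 0)       # requests with no KV on this rank
cp_seq_len = torch.where(cp_seq_len == 0, torch.ones_like(cp_seq_len), cp_seq_len)
print(local[0].tolist(), cp_seq_len.tolist(), batch_seq_mask.tolist())
# [[2, 1], [1, 1]] [2, 1] [False, True]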

vllm_ascend/worker/block_table.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def __init__(self,
         logical_table_size = max_num_blocks_per_req
 
         duplicate_size = 1
-        if self.pcp_world_size > 1:
+        if self.pcp_world_size * self.dcp_world_size > 1:
             duplicate_size += num_speculative_tokens
         self.block_table = torch.zeros(
             (max_num_reqs * duplicate_size, logical_table_size),
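
A minimal sketch of the resulting block-table shape after the sizing change above, with made-up sizes; whenever the combined cp degree exceeds 1, each request gets num_speculative_tokens extra logical rows.

import torch

max_num_reqs = 4
max_num_blocks_per_req = 16
num_speculative_tokens = 2
pcp_world_size, dcp_world_size = 2, 1     # any cp degree > 1 now triggers duplication

duplicate_size = 1
if pcp_world_size * dcp_world_size > 1:
    duplicate_size += num_speculative_tokens

block_table = torch.zeros((max_num_reqs * duplicate_size, max_num_blocks_per_req),
                          dtype=torch.int32)
print(block_table.shape)   # torch.Size([12, 16])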
