
Commit 7cc6d7b

zhangsicheng5 committed:
[feature] support pcp + mtp in full graph
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
1 parent: 2b819bb

File tree

5 files changed: +220 -49 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 17 additions & 16 deletions
@@ -423,17 +423,9 @@ def build(
                     common_attn_metadata.block_table_tensor[:graph_pad_size])
             else:
                 block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-                # NOTE: Currently, MTP-fullgraph is incompatibility pcp
-                if self.pcp_size > 1:
-                    num_decodes_flatten = num_decodes * self.decode_threshold
-                    block_table = common_attn_metadata.block_table_tensor[:
-                                                                          num_decodes_flatten
-                                                                          +
-                                                                          num_prefills]
         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens

-        # NOTE: Currently, MTP-fullgraph is incompatibility pcp
         slot_mapping = common_attn_metadata.slot_mapping[:
                                                          num_actual_tokens_pcp_padded]
         input_positions = common_attn_metadata.positions[:
@@ -456,6 +448,13 @@ def build(
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (seq_lens - query_lens)

+        if self.pcp_size * self.dcp_size > 1:
+            num_decodes_flatten = query_lens[:num_decodes].sum().item()
+            block_table = common_attn_metadata.block_table_tensor[:
+                                                                  num_decodes_flatten
+                                                                  +
+                                                                  num_prefills]
+
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:
@@ -520,8 +519,9 @@ def build(
             if self.dcp_size * self.pcp_size > 1:
                 if num_computed_tokens_of_pcp_dcp is not None:
                     local_context_lens_allranks = torch.tensor(
-                        num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
-                    ).reshape(-1, self.dcp_size * self.pcp_size)
+                        num_computed_tokens_of_pcp_dcp[
+                            num_decodes_flatten:]).reshape(
+                                -1, self.dcp_size * self.pcp_size)
                 # Note(qcs): The max local context lengths
                 # padded to `cp_local_block_size`.
                 padded_local_context_lens_cpu = (cdiv(
@@ -615,7 +615,7 @@ def build(
                 cos=cos,
                 pcp_metadata=pcp_metadata,
             )
-            if self.pcp_size > 1:
+            if self.pcp_size * self.dcp_size > 1:
                 prefill_metadata.block_table = block_table[
                     num_decodes_flatten:, ...]

@@ -629,13 +629,12 @@ def build(
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            if self.pcp_size > 1:
+            if self.pcp_size * self.dcp_size > 1:
                 # For pcp + spec decode, we flatten seq_lens and block_table
                 # to avoid irregular spec_attn_mask shape
                 block_table = block_table[:num_decodes_flatten, ...]
             else:
                 block_table = block_table[:num_decodes, ...]
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
             if graph_pad_size > num_decodes and \
                 self.speculative_config.disable_padded_drafter_batch:
@@ -645,8 +644,7 @@ def build(
             if num_computed_tokens_of_pcp_dcp is not None:
                 # [bs, pcp_size, dcp_size]
                 num_computed_tokens_of_cp_dcp_array = np.array(
-                    num_computed_tokens_of_pcp_dcp)[:num_decodes *
-                                                    self.decode_threshold]
+                    num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]

                 cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
                                                                  self.pcp_rank,
@@ -1902,8 +1900,11 @@ def _forward_decode_pcp_dcp(
             "return_lse": True,
             "calc_type": "calc_type_ring",
         }
-        graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
+        if forward_context.is_mtp_model:
+            graph_params = get_mtp_graph_params()
+        else:
+            graph_params = get_graph_params()
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
             event = torch.npu.ExternalEvent()
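
The key sizing change above is that the flattened decode block table is no longer assumed to hold exactly decode_threshold tokens per decode request; it now sums the actual per-request query lengths, and the branch also fires for dcp-only setups. A minimal sketch of the difference, with made-up numbers (query_lens, num_decodes, num_prefills and decode_threshold below are illustrative, not taken from a real batch):

import torch

# Hedged sketch: why the flattened decode size moves from a fixed
# num_decodes * decode_threshold to a per-request sum for pcp/dcp + MTP.
query_lens = torch.tensor([2, 1, 3, 5])  # hypothetical per-request query lengths
num_decodes = 3                          # first 3 requests are decodes
num_prefills = 1
decode_threshold = 3                     # assumed max spec tokens + 1

# Old sizing: assumes every decode request scheduled exactly decode_threshold tokens.
old_flatten = num_decodes * decode_threshold                 # 9, over-counts requests 0 and 1

# New sizing: counts the tokens each decode request actually scheduled,
# so block_table rows line up with the real flattened token layout.
num_decodes_flatten = query_lens[:num_decodes].sum().item()  # 2 + 1 + 3 = 6
block_table_rows = num_decodes_flatten + num_prefills        # rows kept from block_table_tensor
print(old_flatten, num_decodes_flatten, block_table_rows)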

vllm_ascend/compilation/acl_graph.py

Lines changed: 4 additions & 1 deletion
@@ -369,7 +369,10 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):

 def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
                                    runtime_shape):
-    graph_params = get_graph_params()
+    if forward_context.is_mtp_model:
+        graph_params = get_mtp_graph_params()
+    else:
+        graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):
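
The dispatch above keeps the main model's and the MTP draft model's captured graphs from sharing one parameter store. A small illustrative sketch of the pattern; the two module-level dicts and the selector function are stand-ins, not the repo's actual implementation of get_graph_params / get_mtp_graph_params:

# Hedged sketch of the selection pattern: the target model and the MTP draft
# model each capture their own full graph, so attention-update parameters are
# kept in separate stores and chosen by forward_context.is_mtp_model.
_MAIN_GRAPH_PARAMS: dict = {}   # illustrative stand-in for the main-model store
_MTP_GRAPH_PARAMS: dict = {}    # illustrative stand-in for the MTP store


def get_graph_params() -> dict:
    return _MAIN_GRAPH_PARAMS


def get_mtp_graph_params() -> dict:
    return _MTP_GRAPH_PARAMS


def select_graph_params(is_mtp_model: bool) -> dict:
    # Mirrors the branch added in update_mla_attn_dcp_pcp_params() and
    # _forward_decode_pcp_dcp(): the draft model's graph updates must not
    # overwrite the main model's captured attention parameters.
    return get_mtp_graph_params() if is_mtp_model else get_graph_params()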

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 119 additions & 18 deletions
@@ -32,6 +32,7 @@
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_mtp_graph_params,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -98,6 +99,7 @@ def __init__(
         self.pcp_size = self.runner.pcp_size
         self.dcp_size = self.runner.dcp_size
         self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank

         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
@@ -268,6 +270,13 @@ def dummy_run(self,
                 cos=self.runner.cos,
                 sin=self.runner.sin,
             )
+            if self.pcp_size * self.dcp_size > 1:
+                # update long_seq related params and flatten block_table
+                common_attn_metadata.prefill_context_parallel_metadata=\
+                    self.runner.long_seq_metadata
+                common_attn_metadata.block_table_tensor = \
+                    self.runner.input_batch.block_table[0].get_device_tensor()[
+                        :num_reqs * self.decode_threshold]

             builder = self.runner.attn_groups[0][0].get_metadata_builder()
             attn_metadata_mtp = builder.build_for_graph_capture(
@@ -313,9 +322,15 @@ def dummy_run(self,
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                 not forward_context.capturing:
                 if self.vllm_config.model_config.use_mla:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context, num_tokens,
-                        self.vllm_config.speculative_config)
+                    if self.pcp_size * self.dcp_size > 1:
+                        update_mla_attn_dcp_pcp_params(
+                            self.update_stream, forward_context,
+                            num_tokens)
+                    else:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context,
+                            num_tokens,
+                            self.vllm_config.speculative_config)
             if self.enable_shared_expert_dp:
                 positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                     positions, True)
@@ -369,7 +384,7 @@ def generate_token_ids(self,
                 valid_sampled_tokens_count)

         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = self.runner.long_seq_metadata
             input_ids_pcp_full = self.runner.input_ids_pcp_full
             query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
@@ -405,7 +420,6 @@ def generate_token_ids(self,
             common_attn_metadata.query_start_loc = \
                 query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
-            # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices =\
                 self._prepare_inputs(
@@ -638,28 +652,36 @@ def _propose(
         self.input_ids[last_token_indices] = next_token_ids

         # update pcp related params
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             assert long_seq_metadata is not None
             common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            ori_last_token_indices = last_token_indices.clone()
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+        if self.pcp_size > 1:
             # 1. preprocess decode/prefill input_ids & target_hidden_states
             # decode input_ids: keep unchanged
             # decode target_hidden_states: remove padding
             # prefill input_ids: add padding and pcp split
             # prefill target_hidden_states: pcp split
-            num_tokens_d = num_decode_reqs * self.decode_threshold
+            num_tokens_d = query_lens_d.sum().item()
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
             target_hidden_states_d_padded = \
                 target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                target_hidden_states_d = target_hidden_states_d_padded.reshape(
-                    [
-                        num_decode_reqs, self.decode_threshold * self.pcp_size,
-                        -1
-                    ])[:, :self.decode_threshold, :].reshape(
-                        [num_tokens_d, -1])
+                mask_start_loc = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
+                ])
+                mask_len = query_lens_d
+                mask = []
+                for req_id in range(num_decode_reqs):
+                    mask += list(
+                        range(mask_start_loc[req_id],
+                              mask_start_loc[req_id] + mask_len[req_id]))
+                target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
             target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
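
The mask construction above replaces the old fixed-shape reshape: after the pcp all-gather, each decode request occupies query_len * pcp_size rows of target_hidden_states, of which only the first query_len are real tokens. A standalone sketch with illustrative sizes (pcp_size, query_lens_d and the stand-in hidden tensor are made up):

import torch

# Hedged sketch of the de-padding step: keep the first query_len rows of each
# request's (query_len * pcp_size)-row segment in the all-gathered tensor.
pcp_size = 2
query_lens_d = torch.tensor([2, 1, 3], dtype=torch.int32)  # hypothetical decode query lens
num_decode_reqs = query_lens_d.numel()

num_tokens_d = int(query_lens_d.sum())                     # 6 real decode tokens
padded = torch.arange(num_tokens_d * pcp_size).float().unsqueeze(-1)  # stand-in hidden states

mask_start_loc = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]
])                                                         # segment starts: [0, 4, 6]
mask = []
for req_id in range(num_decode_reqs):
    start = int(mask_start_loc[req_id])
    mask += list(range(start, start + int(query_lens_d[req_id])))

# Selected rows [0, 1, 4, 6, 7, 8]: the unpadded tokens of the three requests.
target_hidden_states_d = padded[mask]
print(mask, target_hidden_states_d.shape)                  # ... torch.Size([6, 1])
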
@@ -794,10 +816,15 @@ def _propose(
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                 if self.vllm_config.model_config.use_mla:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context,
-                        num_input_tokens,
-                        self.vllm_config.speculative_config)
+                    if self.pcp_size * self.dcp_size > 1:
+                        update_mla_attn_dcp_pcp_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens)
+                    else:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens,
+                            self.vllm_config.speculative_config)

             if self.enable_shared_expert_dp:
                 hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -816,6 +843,8 @@ def _propose(
                     (0, max_num_reqs_across_dp - num_indices))

             if self.pcp_size > 1:
+                # remove graph padding before all_gather
+                hidden_states = hidden_states[:num_tokens]
                 hidden_states = get_pcp_group().all_gather(hidden_states, 0)
                 hidden_states = torch.index_select(
                     hidden_states, 0, self.runner.
@@ -856,6 +885,51 @@ def _propose(
                 last_token_indices = self.arange[:batch_size]
                 if getattr(attn_metadata_i, "num_decode_tokens", 0):
                     attn_metadata_i.num_decode_tokens = batch_size
+                if self.pcp_size * self.dcp_size > 1:
+                    positions = target_positions[ori_last_token_indices]
+                    # For pcp/dcp, tokens are split across different cp ranks,
+                    # so we can not simply update slot_mapping by += 1.
+                    # Instead, we pre-allocate mtp slot_mapping in model_runner
+                    # (_generate_pcp_mtp_input), and use updated slot_indices
+                    # to get corresponding slot_mapping in each step.
+                    num_reject_tokens = torch.tensor(
+                        self.runner.cu_num_tokens_pcp_full,
+                        dtype=torch.int32).to(
+                            self.device) - ori_last_token_indices - 1
+                    num_accept_tokens = \
+                        query_lens_d.to(self.device) - num_reject_tokens
+                    ori_seq_len = attn_metadata_i.seq_lens
+                    mtp_slot_mapping = self.runner.mtp_slot_pad
+
+                    # slot_mapping index base offset:
+                    # scheduled tokens + pre-allocated mtp tokens + accepted tokens
+                    slot_idx_base = (
+                        torch.cat([
+                            torch.tensor(
+                                [0], dtype=torch.int32, device=self.device),
+                            (torch.cumsum(query_lens_d, dim=0)[:-1] *
+                             self.pcp_size).to(self.device)
+                        ]) +
+                        torch.arange(num_decode_reqs, device=self.device) *
+                        (self.num_speculative_tokens - 1) * self.pcp_size +
+                        (num_accept_tokens - 1) * self.pcp_size)
+                    slot_indices_list = []
+                    for req_id in range(num_decode_reqs):
+                        slot_indices_list.append(
+                            torch.arange(slot_idx_base[req_id],
+                                         slot_idx_base[req_id] + self.pcp_size,
+                                         device=self.device))
+                    slot_indices = torch.cat(slot_indices_list, dim=0)
+
+                    # fold block_table (restore it to original size before flattened)
+                    block_indices = torch.cat([
+                        torch.tensor([0], dtype=torch.int32),
+                        torch.cumsum(query_lens_d, dim=0)[:-1]
+                    ])
+                    attn_metadata_i.decode.block_table[:batch_size] = \
+                        attn_metadata_i.decode.block_table[block_indices]
+                    attn_metadata_i.decode.block_table = \
+                        attn_metadata_i.decode.block_table[:batch_size]

                 input_ids = draft_token_ids_list[-1].int()
                 positions += 1
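
The slot_idx_base arithmetic above appears to assume a flat MTP slot buffer pre-allocated by the runner and laid out per request as query_len * pcp_size scheduled-token slots followed by (num_speculative_tokens - 1) * pcp_size draft-token slots; the base then lands on the last accepted token's slots, and each draft step advances by pcp_size. A worked sketch with illustrative numbers (the layout and all values below are assumptions for illustration, not read from the runner):

import torch

# Hedged sketch of the slot-index arithmetic for a pre-allocated MTP slot buffer.
pcp_size = 2
num_speculative_tokens = 3
query_lens_d = torch.tensor([2, 1], dtype=torch.int32)       # hypothetical
num_accept_tokens = torch.tensor([1, 2], dtype=torch.int32)  # hypothetical
num_decode_reqs = query_lens_d.numel()

# Start of each request's region: previous requests' scheduled slots ...
req_start = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    torch.cumsum(query_lens_d, dim=0)[:-1] * pcp_size
])
# ... plus previous requests' draft slots, plus the last accepted token's offset.
slot_idx_base = (req_start +
                 torch.arange(num_decode_reqs) *
                 (num_speculative_tokens - 1) * pcp_size +
                 (num_accept_tokens - 1) * pcp_size)          # tensor([0, 10])

# Each draft step reads pcp_size consecutive slots per request; the per-step
# loop then advances by another pcp_size (slot_indices += pcp_size).
slot_indices = torch.cat([
    torch.arange(int(b), int(b) + pcp_size) for b in slot_idx_base
])
print(slot_idx_base.tolist(), slot_indices.tolist())          # [0, 10] [0, 1, 10, 11]
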
@@ -902,13 +976,40 @@ def _propose(
                 # Otherwise, the KV cache will be inadvertently updated with the
                 # padding tokens.
                 slot_mapping += 1
+                if self.pcp_size > 1:
+                    exceeds_max_model_len = exceeds_max_model_len.repeat_interleave(
+                        slot_mapping.size(0) // exceeds_max_model_len.size(0))
                 slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

                 # copy inputs to buffer for cudagraph
                 self.input_ids[:batch_size] = input_ids
                 self.positions[:batch_size] = clamped_positions
                 self.hidden_states[:hidden_states.shape[0]] = hidden_states
-                attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
+                if self.pcp_size * self.dcp_size > 1:
+                    # update local seq_len and batch_seq_mask
+                    num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
+                        ori_seq_len + step + 1,
+                        self.pcp_size,
+                        self.dcp_size,
+                        self.runner.parallel_config.cp_kv_cache_interleave_size,
+                    )
+                    cp_seq_len = \
+                        num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
+                    batch_seq_mask = (cp_seq_len == 0)
+                    builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                        batch_seq_mask, non_blocking=True)
+                    batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.
+                                                                shape[0]]
+                    cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+                    attn_metadata_i.decode.cp_seq_len = cp_seq_len
+                    attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
+                    # update slot_mapping
+                    slot_indices += self.pcp_size
+                    slot_mapping = mtp_slot_mapping[slot_indices]
+                    attn_metadata_i.slot_mapping[:batch_size *
+                                                 self.pcp_size] = slot_mapping
+                else:
+                    attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
                 if self.speculative_config.disable_padded_drafter_batch:
                     self.positions[batch_size:num_input_tokens] = 0
                     self.input_ids[batch_size:num_input_tokens] = 0
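
Inside the per-step loop above, the local KV length seen by this (pcp_rank, dcp_rank) pair is recomputed each step, requests holding no KV on this rank are masked via batch_seq_mask, and their length is clamped to 1, presumably to keep the attention kernel's shapes valid. A self-contained sketch with an illustrative stand-in for runner._get_cp_local_seq_lens (the helper, its interleaved layout, and the numbers below are assumptions):

import torch

# Hedged sketch: map global sequence lengths to per-(pcp, dcp)-rank local KV
# lengths under a round-robin interleaved layout, then mask empty ranks.
def local_seq_lens(seq_lens, pcp_size, dcp_size, interleave=1):
    # Illustrative layout only, not the repo's _get_cp_local_seq_lens().
    world = pcp_size * dcp_size
    ranks = torch.arange(world)
    full = seq_lens // (interleave * world)
    rem = seq_lens % (interleave * world)
    local = full[:, None] * interleave + torch.clamp(
        rem[:, None] - ranks[None, :] * interleave, 0, interleave)
    return local.reshape(-1, pcp_size, dcp_size)   # [bs, pcp_size, dcp_size]


seq_lens = torch.tensor([5, 0, 9])                 # hypothetical decode seq lens
step = 0
local = local_seq_lens(seq_lens + step + 1, pcp_size=2, dcp_size=1)

pcp_rank, dcp_rank = 1, 0                          # this rank's view
cp_seq_len = local[:, pcp_rank, dcp_rank]          # tensor([3, 0, 5])
batch_seq_mask = (cp_seq_len == 0)                 # request 1 holds no KV on this rank
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)  # clamp zeros to 1
print(cp_seq_len.tolist(), batch_seq_mask.tolist())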

vllm_ascend/worker/block_table.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def __init__(self,
             logical_table_size = max_num_blocks_per_req

         duplicate_size = 1
-        if self.pcp_world_size > 1:
+        if self.pcp_world_size * self.dcp_world_size > 1:
             duplicate_size += num_speculative_tokens
         self.block_table = torch.zeros(
             (max_num_reqs * duplicate_size, logical_table_size),
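
The one-line change above widens the condition under which the physical block table is over-allocated: with pcp or dcp enabled and speculative decoding, the table reserves (1 + num_speculative_tokens) row groups per request instead of one. A small sketch of the sizing with illustrative numbers:

import torch

# Hedged sketch of the allocation rule (all sizes here are made up).
max_num_reqs = 4
max_num_blocks_per_req = 8
num_speculative_tokens = 2
pcp_world_size, dcp_world_size = 1, 2   # dcp-only now also triggers duplication

duplicate_size = 1
if pcp_world_size * dcp_world_size > 1:        # previously: pcp_world_size > 1
    duplicate_size += num_speculative_tokens   # 1 + 2 = 3 row groups per request

block_table = torch.zeros(
    (max_num_reqs * duplicate_size, max_num_blocks_per_req), dtype=torch.int32)
print(block_table.shape)                       # torch.Size([12, 8])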
