 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_mtp_graph_params,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -98,6 +99,7 @@ def __init__(
         self.pcp_size = self.runner.pcp_size
         self.dcp_size = self.runner.dcp_size
         self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank

         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
@@ -268,6 +270,13 @@ def dummy_run(self,
             cos=self.runner.cos,
             sin=self.runner.sin,
         )
+        if self.pcp_size * self.dcp_size > 1:
+            # update long_seq related params and flatten block_table
+            common_attn_metadata.prefill_context_parallel_metadata = \
+                self.runner.long_seq_metadata
+            common_attn_metadata.block_table_tensor = \
+                self.runner.input_batch.block_table[0].get_device_tensor()[
+                    :num_reqs * self.decode_threshold]

         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_mtp = builder.build_for_graph_capture(
@@ -313,9 +322,15 @@ def dummy_run(self,
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                 not forward_context.capturing:
             if self.vllm_config.model_config.use_mla:
-                update_mla_attn_params(
-                    self.update_stream, forward_context, num_tokens,
-                    self.vllm_config.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    update_mla_attn_dcp_pcp_params(
+                        self.update_stream, forward_context,
+                        num_tokens)
+                else:
+                    update_mla_attn_params(
+                        self.update_stream, forward_context,
+                        num_tokens,
+                        self.vllm_config.speculative_config)
         if self.enable_shared_expert_dp:
             positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 positions, True)
@@ -369,7 +384,7 @@ def generate_token_ids(self,
                 valid_sampled_tokens_count)

         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = self.runner.long_seq_metadata
             input_ids_pcp_full = self.runner.input_ids_pcp_full
             query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
@@ -405,7 +420,6 @@ def generate_token_ids(self,
             common_attn_metadata.query_start_loc = \
                 query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
-            # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices = \
                 self._prepare_inputs(
@@ -638,28 +652,36 @@ def _propose(
         self.input_ids[last_token_indices] = next_token_ids

         # update pcp related params
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             assert long_seq_metadata is not None
             common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            ori_last_token_indices = last_token_indices.clone()
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+        if self.pcp_size > 1:
             # 1. preprocess decode/prefill input_ids & target_hidden_states
             # decode input_ids: keep unchanged
             # decode target_hidden_states: remove padding
             # prefill input_ids: add padding and pcp split
             # prefill target_hidden_states: pcp split
-            num_tokens_d = num_decode_reqs * self.decode_threshold
+            num_tokens_d = query_lens_d.sum().item()
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
             target_hidden_states_d_padded = \
                 target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                target_hidden_states_d = target_hidden_states_d_padded.reshape(
-                    [
-                        num_decode_reqs, self.decode_threshold * self.pcp_size,
-                        -1
-                    ])[:, :self.decode_threshold, :].reshape(
-                        [num_tokens_d, -1])
+                mask_start_loc = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
+                ])
+                mask_len = query_lens_d
+                mask = []
+                for req_id in range(num_decode_reqs):
+                    mask += list(
+                        range(mask_start_loc[req_id],
+                              mask_start_loc[req_id] + mask_len[req_id]))
+                target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
             target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
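
For reference, here is a small standalone sketch (plain PyTorch, not part of the diff; the request sizes and hidden dim are made up) of what the variable-length mask above does: after the pcp all-gather, each decode request occupies query_len * pcp_size rows, and only the first query_len rows per request are real tokens, so the mask gathers exactly those rows.

import torch

pcp_size = 2
query_lens_d = torch.tensor([2, 3], dtype=torch.int32)   # two decode requests
hidden = 4                                                # toy hidden size
num_tokens_d = int(query_lens_d.sum())                    # 5 real decode tokens

# After the pcp all-gather each request owns query_len * pcp_size rows,
# of which only the first query_len rows are real tokens.
padded = torch.arange(num_tokens_d * pcp_size * hidden,
                      dtype=torch.float32).reshape(-1, hidden)

mask_start_loc = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]
])
mask = []
for req_id in range(len(query_lens_d)):
    start = int(mask_start_loc[req_id])
    mask += list(range(start, start + int(query_lens_d[req_id])))

unpadded = padded[mask]          # rows 0-1 (req 0) and rows 4-6 (req 1)
assert unpadded.shape == (num_tokens_d, hidden)
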
@@ -794,10 +816,15 @@ def _propose(
                 forward_context = get_forward_context()
                 if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                     if self.vllm_config.model_config.use_mla:
-                        update_mla_attn_params(
-                            self.update_stream, forward_context,
-                            num_input_tokens,
-                            self.vllm_config.speculative_config)
+                        if self.pcp_size * self.dcp_size > 1:
+                            update_mla_attn_dcp_pcp_params(
+                                self.update_stream, forward_context,
+                                num_input_tokens)
+                        else:
+                            update_mla_attn_params(
+                                self.update_stream, forward_context,
+                                num_input_tokens,
+                                self.vllm_config.speculative_config)

                 if self.enable_shared_expert_dp:
                     hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -816,6 +843,8 @@ def _propose(
                 (0, max_num_reqs_across_dp - num_indices))

         if self.pcp_size > 1:
+            # remove graph padding before all_gather
+            hidden_states = hidden_states[:num_tokens]
             hidden_states = get_pcp_group().all_gather(hidden_states, 0)
             hidden_states = torch.index_select(
                 hidden_states, 0, self.runner.
@@ -856,6 +885,51 @@ def _propose(
             last_token_indices = self.arange[:batch_size]
             if getattr(attn_metadata_i, "num_decode_tokens", 0):
                 attn_metadata_i.num_decode_tokens = batch_size
+            if self.pcp_size * self.dcp_size > 1:
+                positions = target_positions[ori_last_token_indices]
+                # For pcp/dcp, tokens are split across different cp ranks,
+                # so we cannot simply update slot_mapping by += 1.
+                # Instead, we pre-allocate the mtp slot_mapping in model_runner
+                # (_generate_pcp_mtp_input) and use the updated slot_indices
+                # to fetch the corresponding slot_mapping at each step.
+                num_reject_tokens = torch.tensor(
+                    self.runner.cu_num_tokens_pcp_full,
+                    dtype=torch.int32).to(
+                        self.device) - ori_last_token_indices - 1
+                num_accept_tokens = \
+                    query_lens_d.to(self.device) - num_reject_tokens
+                ori_seq_len = attn_metadata_i.seq_lens
+                mtp_slot_mapping = self.runner.mtp_slot_pad
+
+                # slot_mapping index base offset:
+                # scheduled tokens + pre-allocated mtp tokens + accepted tokens
+                slot_idx_base = (
+                    torch.cat([
+                        torch.tensor(
+                            [0], dtype=torch.int32, device=self.device),
+                        (torch.cumsum(query_lens_d, dim=0)[:-1] *
+                         self.pcp_size).to(self.device)
+                    ]) +
+                    torch.arange(num_decode_reqs, device=self.device) *
+                    (self.num_speculative_tokens - 1) * self.pcp_size +
+                    (num_accept_tokens - 1) * self.pcp_size)
+                slot_indices_list = []
+                for req_id in range(num_decode_reqs):
+                    slot_indices_list.append(
+                        torch.arange(slot_idx_base[req_id],
+                                     slot_idx_base[req_id] + self.pcp_size,
+                                     device=self.device))
+                slot_indices = torch.cat(slot_indices_list, dim=0)
+
+                # fold block_table (restore it to its original size before it was flattened)
+                block_indices = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d, dim=0)[:-1]
+                ])
+                attn_metadata_i.decode.block_table[:batch_size] = \
+                    attn_metadata_i.decode.block_table[block_indices]
+                attn_metadata_i.decode.block_table = \
+                    attn_metadata_i.decode.block_table[:batch_size]

             input_ids = draft_token_ids_list[-1].int()
             positions += 1
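
To make the slot-index arithmetic above concrete, here is a standalone sketch with made-up numbers (two decode requests, pcp_size = 2, three speculative tokens). It only reproduces the offset formula on CPU; the real code keeps everything on device and reads the counts from the runner, and my reading of the three terms follows the comment in the diff ("scheduled tokens + pre-allocated mtp tokens + accepted tokens").

import torch

pcp_size = 2
num_speculative_tokens = 3
query_lens_d = torch.tensor([2, 3], dtype=torch.int32)       # scheduled decode tokens
num_accept_tokens = torch.tensor([1, 2], dtype=torch.int32)  # accepted drafts per request
num_decode_reqs = query_lens_d.numel()

# base offset per request =
#   scheduled tokens of earlier requests (times pcp_size)
# + pre-allocated mtp slots of earlier requests
# + slots already consumed by this request's accepted tokens
slot_idx_base = (
    torch.cat([
        torch.tensor([0], dtype=torch.int32),
        torch.cumsum(query_lens_d, dim=0)[:-1] * pcp_size
    ]) +
    torch.arange(num_decode_reqs) * (num_speculative_tokens - 1) * pcp_size +
    (num_accept_tokens - 1) * pcp_size)

# each request then takes pcp_size consecutive slots starting at its base
slot_indices = torch.cat([
    torch.arange(int(slot_idx_base[r]), int(slot_idx_base[r]) + pcp_size)
    for r in range(num_decode_reqs)
])
print(slot_idx_base.tolist())   # [0, 10]
print(slot_indices.tolist())    # [0, 1, 10, 11]
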
@@ -902,13 +976,40 @@ def _propose(
             # Otherwise, the KV cache will be inadvertently updated with the
             # padding tokens.
             slot_mapping += 1
+            if self.pcp_size > 1:
+                exceeds_max_model_len = exceeds_max_model_len.repeat_interleave(
+                    slot_mapping.size(0) // exceeds_max_model_len.size(0))
             slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:hidden_states.shape[0]] = hidden_states
-            attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
+            if self.pcp_size * self.dcp_size > 1:
+                # update local seq_len and batch_seq_mask
+                num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
+                    ori_seq_len + step + 1,
+                    self.pcp_size,
+                    self.dcp_size,
+                    self.runner.parallel_config.cp_kv_cache_interleave_size,
+                )
+                cp_seq_len = \
+                    num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
+                batch_seq_mask = (cp_seq_len == 0)
+                builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                    batch_seq_mask, non_blocking=True)
+                batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.
+                                                            shape[0]]
+                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+                attn_metadata_i.decode.cp_seq_len = cp_seq_len
+                attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
+                # update slot_mapping
+                slot_indices += self.pcp_size
+                slot_mapping = mtp_slot_mapping[slot_indices]
+                attn_metadata_i.slot_mapping[:batch_size *
+                                             self.pcp_size] = slot_mapping
+            else:
+                attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
             if self.speculative_config.disable_padded_drafter_batch:
                 self.positions[batch_size:num_input_tokens] = 0
                 self.input_ids[batch_size:num_input_tokens] = 0
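
One more standalone sketch (values made up; PADDING_SLOT_ID stands in for vLLM's padding constant) illustrating why exceeds_max_model_len is repeat_interleave'd before the masked fill when pcp_size > 1: each draft token now owns pcp_size entries in slot_mapping, so the per-token overflow flag has to be expanded to per-slot before it can mask them out.

import torch

pcp_size = 2
PADDING_SLOT_ID = -1                                   # stand-in for the real constant
exceeds_max_model_len = torch.tensor([False, True])    # one flag per draft token
slot_mapping = torch.tensor([10, 11, 20, 21])          # pcp_size slots per token

mask = exceeds_max_model_len.repeat_interleave(
    slot_mapping.size(0) // exceeds_max_model_len.size(0))
slot_mapping = slot_mapping.masked_fill(mask, PADDING_SLOT_ID)
print(slot_mapping.tolist())   # [10, 11, -1, -1]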