 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_mtp_graph_params,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -99,6 +100,7 @@ def __init__(
         self.pcp_size = self.runner.pcp_size
         self.dcp_size = self.runner.dcp_size
         self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank
 
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
@@ -272,6 +274,13 @@ def dummy_run(self,
                 cos=self.runner.cos,
                 sin=self.runner.sin,
             )
+            if self.pcp_size * self.dcp_size > 1:
+                # update long_seq related params and flatten block_table
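+                # keep one flattened block-table row per decode token
+                # (num_reqs * decode_threshold rows) for graph capture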
+                common_attn_metadata.prefill_context_parallel_metadata = \
+                    self.runner.long_seq_metadata
+                common_attn_metadata.block_table_tensor = \
+                    self.runner.input_batch.block_table[0].get_device_tensor()[
+                        :num_reqs * self.decode_threshold]
 
             builder = self.runner.attn_groups[0][0].get_metadata_builder()
             attn_metadata_mtp = builder.build_for_graph_capture(
@@ -317,9 +326,15 @@ def dummy_run(self,
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                     not forward_context.capturing:
                 if self.vllm_config.model_config.use_mla:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context, num_tokens,
-                        self.vllm_config.speculative_config)
+                    if self.pcp_size * self.dcp_size > 1:
+                        update_mla_attn_dcp_pcp_params(
+                            self.update_stream, forward_context,
+                            num_tokens)
+                    else:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context,
+                            num_tokens,
+                            self.vllm_config.speculative_config)
             if self.enable_shared_expert_dp:
                 positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                     positions, True)
@@ -373,7 +388,7 @@ def generate_token_ids(self,
             valid_sampled_tokens_count)
 
         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = self.runner.long_seq_metadata
             input_ids_pcp_full = self.runner.input_ids_pcp_full
             query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
@@ -409,7 +424,6 @@ def generate_token_ids(self,
             common_attn_metadata.query_start_loc = \
                 query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
-            # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices = \
                 self._prepare_inputs(
@@ -642,28 +656,36 @@ def _propose(
         self.input_ids[last_token_indices] = next_token_ids
 
         # update pcp related params
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             assert long_seq_metadata is not None
             common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            ori_last_token_indices = last_token_indices.clone()
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+        if self.pcp_size > 1:
             # 1. preprocess decode/prefill input_ids & target_hidden_states
             # decode input_ids: keep unchanged
             # decode target_hidden_states: remove padding
             # prefill input_ids: add padding and pcp split
             # prefill target_hidden_states: pcp split
-            num_tokens_d = num_decode_reqs * self.decode_threshold
+            num_tokens_d = query_lens_d.sum().item()
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
             target_hidden_states_d_padded = \
                 target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                target_hidden_states_d = target_hidden_states_d_padded.reshape(
-                    [
-                        num_decode_reqs, self.decode_threshold * self.pcp_size,
-                        -1
-                    ])[:, :self.decode_threshold, :].reshape(
-                        [num_tokens_d, -1])
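+                # after the pcp all-gather each decode request owns
+                # query_len * pcp_size rows, of which only the first
+                # query_len hold real tokens; gather just those rows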
+                mask_start_loc = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
+                ])
+                mask_len = query_lens_d
+                mask = []
+                for req_id in range(num_decode_reqs):
+                    mask += list(
+                        range(mask_start_loc[req_id],
+                              mask_start_loc[req_id] + mask_len[req_id]))
+                target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
             target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
@@ -798,10 +820,15 @@ def _propose(
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                 if self.vllm_config.model_config.use_mla:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context,
-                        num_input_tokens,
-                        self.vllm_config.speculative_config)
+                    if self.pcp_size * self.dcp_size > 1:
+                        update_mla_attn_dcp_pcp_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens)
+                    else:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens,
+                            self.vllm_config.speculative_config)
 
         if self.enable_shared_expert_dp:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -820,6 +847,8 @@ def _propose(
                 (0, max_num_reqs_across_dp - num_indices))
 
         if self.pcp_size > 1:
+            # remove graph padding before all_gather
+            hidden_states = hidden_states[:num_tokens]
             hidden_states = get_pcp_group().all_gather(hidden_states, 0)
             hidden_states = torch.index_select(
                 hidden_states, 0, self.runner.
@@ -860,6 +889,51 @@ def _propose(
                 last_token_indices = self.arange[:batch_size]
             if getattr(attn_metadata_i, "num_decode_tokens", 0):
                 attn_metadata_i.num_decode_tokens = batch_size
+            if self.pcp_size * self.dcp_size > 1:
+                positions = target_positions[ori_last_token_indices]
+                # For pcp/dcp, tokens are split across different cp ranks,
+                # so we cannot simply update slot_mapping by += 1.
+                # Instead, we pre-allocate the mtp slot_mapping in model_runner
+                # (_generate_pcp_mtp_input) and use updated slot_indices
+                # to look up the corresponding slot_mapping in each step.
+                num_reject_tokens = torch.tensor(
+                    self.runner.cu_num_tokens_pcp_full,
+                    dtype=torch.int32).to(
+                        self.device) - ori_last_token_indices - 1
+                num_accept_tokens = \
+                    query_lens_d.to(self.device) - num_reject_tokens
+                ori_seq_len = attn_metadata_i.seq_lens
+                mtp_slot_mapping = self.runner.mtp_slot_pad
+
+                # slot_mapping index base offset:
+                # scheduled tokens + pre-allocated mtp tokens + accepted tokens
+                slot_idx_base = (
+                    torch.cat([
+                        torch.tensor(
+                            [0], dtype=torch.int32, device=self.device),
+                        (torch.cumsum(query_lens_d, dim=0)[:-1] *
+                         self.pcp_size).to(self.device)
+                    ]) +
+                    torch.arange(num_decode_reqs, device=self.device) *
+                    (self.num_speculative_tokens - 1) * self.pcp_size +
+                    (num_accept_tokens - 1) * self.pcp_size)
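+                # each request reserves pcp_size consecutive slots
+                # starting at its base offset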
+                slot_indices_list = []
+                for req_id in range(num_decode_reqs):
+                    slot_indices_list.append(
+                        torch.arange(slot_idx_base[req_id],
+                                     slot_idx_base[req_id] + self.pcp_size,
+                                     device=self.device))
+                slot_indices = torch.cat(slot_indices_list, dim=0)
+
+                # fold block_table back to one row per request
+                # (it was flattened to one row per token)
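+                # block_indices = first flattened row of every request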
+                block_indices = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d, dim=0)[:-1]
+                ])
+                attn_metadata_i.decode.block_table[:batch_size] = \
+                    attn_metadata_i.decode.block_table[block_indices]
+                attn_metadata_i.decode.block_table = \
+                    attn_metadata_i.decode.block_table[:batch_size]
 
             input_ids = draft_token_ids_list[-1].int()
             positions += 1
@@ -906,13 +980,40 @@ def _propose(
             # Otherwise, the KV cache will be inadvertently updated with the
             # padding tokens.
             slot_mapping += 1
+            if self.pcp_size > 1:
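+                # slot_mapping is pcp-expanded, so repeat the mask to the
+                # same length before masking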
+                exceeds_max_model_len = exceeds_max_model_len.repeat_interleave(
+                    slot_mapping.size(0) // exceeds_max_model_len.size(0))
             slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
 
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:hidden_states.shape[0]] = hidden_states
-            attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
+            if self.pcp_size * self.dcp_size > 1:
+                # update local seq_len and batch_seq_mask
+                num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
+                    ori_seq_len + step + 1,
+                    self.pcp_size,
+                    self.dcp_size,
+                    self.runner.parallel_config.cp_kv_cache_interleave_size,
+                )
+                cp_seq_len = \
+                    num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
+                batch_seq_mask = (cp_seq_len == 0)
+                builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                    batch_seq_mask, non_blocking=True)
+                batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.
+                                                            shape[0]]
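+                # requests with no KV on this (pcp_rank, dcp_rank) are masked
+                # via batch_seq_mask; clamp their cp_seq_len to 1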
+                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+                attn_metadata_i.decode.cp_seq_len = cp_seq_len
+                attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
+                # update slot_mapping
+                slot_indices += self.pcp_size
+                slot_mapping = mtp_slot_mapping[slot_indices]
+                attn_metadata_i.slot_mapping[:batch_size *
+                                             self.pcp_size] = slot_mapping
+            else:
+                attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
             if self.speculative_config.disable_padded_drafter_batch:
                 self.positions[batch_size:num_input_tokens] = 0
                 self.input_ids[batch_size:num_input_tokens] = 0