@@ -404,8 +404,17 @@ def build(
404404 num_actual_tokens_pcp_padded = long_seq_metadata .num_actual_tokens_pcp_padded if long_seq_metadata else None
405405 num_computed_tokens_of_pcp_dcp = long_seq_metadata .num_computed_tokens_of_pcp_dcp if long_seq_metadata else None
406406
407+ query_lens_pcp_full = long_seq_metadata .query_lens_pcp_full_cpu \
408+ if long_seq_metadata else None
409+ max_query_len_pcp_full = long_seq_metadata .max_query_len_pcp_full \
410+ if long_seq_metadata else None
407411 num_decodes , num_prefills , num_decode_tokens , num_prefill_tokens = \
408- split_decodes_and_prefills (common_attn_metadata , decode_threshold = self .decode_threshold )
412+ split_decodes_and_prefills (
413+ common_attn_metadata ,
414+ decode_threshold = self .decode_threshold ,
415+ query_lens_pcp_full = query_lens_pcp_full ,
416+ max_query_len_pcp_full = max_query_len_pcp_full ,
417+ )
409418 assert num_decodes + num_prefills == num_reqs
410419 assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
411420
@@ -422,17 +431,9 @@ def build(
422431 common_attn_metadata .block_table_tensor [:graph_pad_size ])
423432 else :
424433 block_table = (common_attn_metadata .block_table_tensor [:num_reqs ])
425- # NOTE: Currently, MTP-fullgraph is incompatibility pcp
426- if self .pcp_size > 1 :
427- num_decodes_flatten = num_decodes * self .decode_threshold
428- block_table = common_attn_metadata .block_table_tensor [:
429- num_decodes_flatten
430- +
431- num_prefills ]
432434 if num_actual_tokens_pcp_padded is None :
433435 num_actual_tokens_pcp_padded = num_actual_tokens
434436
435- # NOTE: Currently, MTP-fullgraph is incompatibility pcp
436437 slot_mapping = common_attn_metadata .slot_mapping [:
437438 num_actual_tokens_pcp_padded ]
438439 input_positions = common_attn_metadata .positions [:
@@ -455,6 +456,13 @@ def build(
455456 seq_lens = common_attn_metadata .seq_lens_cpu [:num_reqs ]
456457 num_computed_tokens_cpu = (seq_lens - query_lens )
457458
459+ if self .pcp_size * self .dcp_size > 1 :
460+ num_decodes_flatten = query_lens [:num_decodes ].sum ().item ()
461+ block_table = common_attn_metadata .block_table_tensor [:
462+ num_decodes_flatten
463+ +
464+ num_prefills ]
465+
458466 prefill_metadata = None
459467 chunked_context_metadata = None
460468 if num_prefills > 0 :
@@ -519,8 +527,9 @@ def build(
519527 if self .dcp_size * self .pcp_size > 1 :
520528 if num_computed_tokens_of_pcp_dcp is not None :
521529 local_context_lens_allranks = torch .tensor (
522- num_computed_tokens_of_pcp_dcp [reqs_start :num_reqs ]
523- ).reshape (- 1 , self .dcp_size * self .pcp_size )
530+ num_computed_tokens_of_pcp_dcp [
531+ num_decodes_flatten :]).reshape (
532+ - 1 , self .dcp_size * self .pcp_size )
524533 # Note(qcs): The max local context lengths
525534 # padded to `cp_local_block_size`.
526535 padded_local_context_lens_cpu = (cdiv (
@@ -614,7 +623,7 @@ def build(
614623 cos = cos ,
615624 pcp_metadata = pcp_metadata ,
616625 )
617- if self .pcp_size > 1 :
626+ if self .pcp_size * self . dcp_size > 1 :
618627 prefill_metadata .block_table = block_table [
619628 num_decodes_flatten :, ...]
620629
@@ -628,13 +637,12 @@ def build(
628637 max_seq_lens = seq_lens [:num_decodes ].max ().item ()
629638 seq_lens = seq_lens [:num_decodes ]
630639 input_positions = input_positions [:num_decode_tokens ]
631- if self .pcp_size > 1 :
640+ if self .pcp_size * self . dcp_size > 1 :
632641 # For pcp + spec decode, we flatten seq_lens and block_table
633642 # to avoid irregular spec_attn_mask shape
634643 block_table = block_table [:num_decodes_flatten , ...]
635644 else :
636645 block_table = block_table [:num_decodes , ...]
637- # NOTE: Currently, MTP-fullgraph is incompatibility pcp
638646 # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
639647 if graph_pad_size > num_decodes and \
640648 self .speculative_config .disable_padded_drafter_batch :
@@ -644,8 +652,7 @@ def build(
644652 if num_computed_tokens_of_pcp_dcp is not None :
645653 # [bs, pcp_size, dcp_size]
646654 num_computed_tokens_of_cp_dcp_array = np .array (
647- num_computed_tokens_of_pcp_dcp )[:num_decodes *
648- self .decode_threshold ]
655+ num_computed_tokens_of_pcp_dcp )[:num_decodes_flatten ]
649656
650657 cp_seq_len = num_computed_tokens_of_cp_dcp_array [:,
651658 self .pcp_rank ,
@@ -1872,8 +1879,11 @@ def _forward_decode_pcp_dcp(
18721879 "return_lse" : True ,
18731880 "calc_type" : "calc_type_ring" ,
18741881 }
1875- graph_params = get_graph_params ()
18761882 forward_context : ForwardContext = get_forward_context ()
1883+ if forward_context .is_mtp_model :
1884+ graph_params = get_mtp_graph_params ()
1885+ else :
1886+ graph_params = get_graph_params ()
18771887 if forward_context .capturing :
18781888 stream = torch_npu .npu .current_stream ()
18791889 event = torch .npu .ExternalEvent ()
0 commit comments