
Commit 64c24c1

[feature] support pcp + mtp in full graph
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
1 parent: 134e011

File tree

6 files changed: +264 -58 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 27 additions & 17 deletions
@@ -404,8 +404,17 @@ def build(
         num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
         num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None

+        query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \
+            if long_seq_metadata else None
+        max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \
+            if long_seq_metadata else None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
-            split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
+            split_decodes_and_prefills(
+                common_attn_metadata,
+                decode_threshold=self.decode_threshold,
+                query_lens_pcp_full=query_lens_pcp_full,
+                max_query_len_pcp_full=max_query_len_pcp_full,
+            )
         assert num_decodes + num_prefills == num_reqs
         assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

@@ -422,17 +431,9 @@ def build(
                 common_attn_metadata.block_table_tensor[:graph_pad_size])
         else:
             block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
-            if self.pcp_size > 1:
-                num_decodes_flatten = num_decodes * self.decode_threshold
-                block_table = common_attn_metadata.block_table_tensor[:
-                                                                      num_decodes_flatten
-                                                                      +
-                                                                      num_prefills]
         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens

-        # NOTE: Currently, MTP-fullgraph is incompatibility pcp
         slot_mapping = common_attn_metadata.slot_mapping[:
                                                          num_actual_tokens_pcp_padded]
         input_positions = common_attn_metadata.positions[:

@@ -455,6 +456,13 @@ def build(
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (seq_lens - query_lens)

+        if self.pcp_size * self.dcp_size > 1:
+            num_decodes_flatten = query_lens[:num_decodes].sum().item()
+            block_table = common_attn_metadata.block_table_tensor[:
+                                                                  num_decodes_flatten
+                                                                  +
+                                                                  num_prefills]
+
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:

@@ -519,8 +527,9 @@ def build(
             if self.dcp_size * self.pcp_size > 1:
                 if num_computed_tokens_of_pcp_dcp is not None:
                     local_context_lens_allranks = torch.tensor(
-                        num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
-                    ).reshape(-1, self.dcp_size * self.pcp_size)
+                        num_computed_tokens_of_pcp_dcp[
+                            num_decodes_flatten:]).reshape(
+                                -1, self.dcp_size * self.pcp_size)
                     # Note(qcs): The max local context lengths
                     # padded to `cp_local_block_size`.
                     padded_local_context_lens_cpu = (cdiv(

@@ -614,7 +623,7 @@ def build(
                 cos=cos,
                 pcp_metadata=pcp_metadata,
             )
-            if self.pcp_size > 1:
+            if self.pcp_size * self.dcp_size > 1:
                 prefill_metadata.block_table = block_table[
                     num_decodes_flatten:, ...]

@@ -628,13 +637,12 @@ def build(
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            if self.pcp_size > 1:
+            if self.pcp_size * self.dcp_size > 1:
                 # For pcp + spec decode, we flatten seq_lens and block_table
                 # to avoid irregular spec_attn_mask shape
                 block_table = block_table[:num_decodes_flatten, ...]
             else:
                 block_table = block_table[:num_decodes, ...]
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
             if graph_pad_size > num_decodes and \
                     self.speculative_config.disable_padded_drafter_batch:

@@ -644,8 +652,7 @@ def build(
             if num_computed_tokens_of_pcp_dcp is not None:
                 # [bs, pcp_size, dcp_size]
                 num_computed_tokens_of_cp_dcp_array = np.array(
-                    num_computed_tokens_of_pcp_dcp)[:num_decodes *
-                                                    self.decode_threshold]
+                    num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]

                 cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
                                                                  self.pcp_rank,

@@ -1872,8 +1879,11 @@ def _forward_decode_pcp_dcp(
             "return_lse": True,
             "calc_type": "calc_type_ring",
         }
-        graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
+        if forward_context.is_mtp_model:
+            graph_params = get_mtp_graph_params()
+        else:
+            graph_params = get_graph_params()
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
             event = torch.npu.ExternalEvent()
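
A minimal, self-contained sketch (not part of this commit) of what the new num_decodes_flatten computation does when context parallel (pcp/dcp) is combined with MTP speculative decoding: instead of assuming every decode request contributes exactly decode_threshold tokens, the flattened decode count is the sum of the actual per-request query lengths, and the block table appears to keep one row per flattened decode entry plus one per prefill request. All tensors and sizes below are illustrative stand-ins, not the metadata built above.

import torch

decode_threshold = 4                       # 1 target token + 3 MTP draft tokens (illustrative)
query_lens = torch.tensor([4, 2, 4, 37])   # 3 decode requests, 1 prefill request
num_decodes, num_prefills = 3, 1

# Old assumption: every decode request is padded to decode_threshold tokens.
flatten_old = num_decodes * decode_threshold                # 12

# New computation: sum the real query lengths of the decode requests,
# which also covers partially filled speculative steps (here 4 + 2 + 4 = 10).
num_decodes_flatten = query_lens[:num_decodes].sum().item()

# One block-table row per flattened decode entry plus one per prefill request.
block_table = torch.zeros(64, 16, dtype=torch.int32)        # dummy block table
rows = block_table[:num_decodes_flatten + num_prefills]

print(flatten_old, num_decodes_flatten, rows.shape)         # 12 10 torch.Size([11, 16])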

vllm_ascend/attention/utils.py

Lines changed: 10 additions & 2 deletions
@@ -42,6 +42,10 @@ class AscendPrefillContextParallelMetadata:

    pcp_prefill_mask: torch.Tensor = None

+    query_lens_pcp_full_cpu: torch.Tensor = None
+
+    max_query_len_pcp_full: int = None
+

 @dataclass
 class AscendCommonAttentionMetadata:

@@ -135,6 +139,8 @@ def filter_chunked_req_indices(
 def split_decodes_and_prefills(
     common_attn_metadata: AscendCommonAttentionMetadata,
     decode_threshold: int = 1,
+    query_lens_pcp_full: torch.Tensor = None,
+    max_query_len_pcp_full: int = None,
 ) -> tuple[int, int, int, int]:
     """
     Assuming a reordered batch, finds the boundary between prefill and decode

@@ -151,15 +157,17 @@ def split_decodes_and_prefills(
         num_decode_tokens: The number of tokens in the decode requests.
         num_prefill_tokens: The number of tokens in the prefill requests.
     """
-    max_query_len = common_attn_metadata.max_query_len
+    max_query_len = common_attn_metadata.max_query_len \
+        if max_query_len_pcp_full is None else max_query_len_pcp_full
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu

     if max_query_len <= decode_threshold:
         return num_reqs, 0, num_tokens, 0

-    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
+        if query_lens_pcp_full is None else query_lens_pcp_full
     is_prefill = query_lens > decode_threshold
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
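
A simplified, stand-alone sketch (not the library code) of why split_decodes_and_prefills gains the two optional parameters: under prefill context parallel each rank presumably only holds its slice of a prefill query, so locally derived query lengths can fall below decode_threshold and misclassify prefills as decodes; passing the full (unsliced) lengths and max query length restores the correct split. The function and values below are illustrative assumptions, not the vllm_ascend implementation.

import torch

def split_sketch(query_start_loc, decode_threshold=1, query_lens_pcp_full=None):
    # Per-request query lengths as seen by this rank.
    local_lens = query_start_loc[1:] - query_start_loc[:-1]
    # Prefer the full lengths when the caller provides them (pcp case).
    query_lens = local_lens if query_lens_pcp_full is None else query_lens_pcp_full
    is_prefill = query_lens > decode_threshold
    return int((~is_prefill).sum()), int(is_prefill.sum())

# Two decode requests (1 token each) plus one prefill whose local slice on this
# pcp rank is a single token, although the full prefill query has 33 tokens.
query_start_loc = torch.tensor([0, 1, 2, 3])
full_lens = torch.tensor([1, 1, 33])

print(split_sketch(query_start_loc))                                 # (3, 0): prefill misclassified
print(split_sketch(query_start_loc, query_lens_pcp_full=full_lens))  # (2, 1)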

vllm_ascend/compilation/acl_graph.py

Lines changed: 4 additions & 1 deletion
@@ -369,7 +369,10 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):

 def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
                                    runtime_shape):
-    graph_params = get_graph_params()
+    if forward_context.is_mtp_model:
+        graph_params = get_mtp_graph_params()
+    else:
+        graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):
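
A hypothetical sketch of the dispatch pattern introduced above: the main model and the MTP draft model each keep their own captured-graph parameter store, and the forward context's is_mtp_model flag selects which store to read or update, so draft layers do not clobber the main model's captured attention arguments. The classes and globals below are illustrative stand-ins, not the actual vllm_ascend objects.

from dataclasses import dataclass, field

@dataclass
class GraphParams:
    # runtime_shape -> recorded attention op arguments (illustrative payload)
    attn_params: dict = field(default_factory=dict)

_MAIN_GRAPH_PARAMS = GraphParams()
_MTP_GRAPH_PARAMS = GraphParams()

def get_graph_params() -> GraphParams:
    return _MAIN_GRAPH_PARAMS

def get_mtp_graph_params() -> GraphParams:
    return _MTP_GRAPH_PARAMS

@dataclass
class ForwardContext:
    is_mtp_model: bool = False

def select_graph_params(forward_context: ForwardContext) -> GraphParams:
    # Same branch as in update_mla_attn_dcp_pcp_params and _forward_decode_pcp_dcp.
    if forward_context.is_mtp_model:
        return get_mtp_graph_params()
    return get_graph_params()

assert select_graph_params(ForwardContext(is_mtp_model=True)) is _MTP_GRAPH_PARAMS
assert select_graph_params(ForwardContext()) is _MAIN_GRAPH_PARAMS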
