171 changes: 171 additions & 0 deletions tests/e2e/multicard/long_sequence/test_mtp.py
@@ -0,0 +1,171 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#

import os

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner
from vllm_ascend.utils import vllm_version_is

os.environ["HCCL_BUFFSIZE"] = "512"


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_pcp_dcp_mtp1_eager():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 1,
                "method": "deepseek_mtp",
            },
            enforce_eager=True,
    ) as runner:
        runner.model.generate(prompts, sampling_params)


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_pcp_dcp_mtp3_eager():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            enforce_eager=True,
    ) as runner:
        runner.model.generate(prompts, sampling_params)


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_pcp_dcp_mtp3_piecewise_graph():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            compilation_config={
                "cudagraph_mode": "PIECEWISE",
                "cudagraph_capture_sizes": [4, 8, 16],
            },
    ) as runner:
        runner.model.generate(prompts, sampling_params)


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_pcp_dcp_mtp3_full_graph():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            compilation_config={
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "cudagraph_capture_sizes": [4, 8, 16],
            },
    ) as runner:
        runner.model.generate(prompts, sampling_params)


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_dcp_mtp3_full_graph():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            compilation_config={
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "cudagraph_capture_sizes": [4, 8, 16],
            },
    ) as runner:
        runner.model.generate(prompts, sampling_params)
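
The five tests above repeat the same parallel, batching, and speculative settings. Purely as an illustration, a minimal sketch of how that shared configuration could be factored into a helper (the helper name and structure are hypothetical, not part of this PR):

# Hypothetical helper sketching the configuration shared by the tests above;
# values mirror the new test file, nothing here is part of the PR itself.
from typing import Any, Dict, Optional


def mtp_runner_kwargs(num_speculative_tokens: int,
                      prefill_cp: Optional[int] = 2,
                      cudagraph_mode: Optional[str] = None) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = dict(
        max_model_len=1024,
        tensor_parallel_size=2,
        decode_context_parallel_size=2,
        max_num_batched_tokens=1024,
        enable_expert_parallel=True,
        block_size=128,
        speculative_config={
            "num_speculative_tokens": num_speculative_tokens,
            "method": "deepseek_mtp",
        },
    )
    if prefill_cp is not None:
        kwargs["prefill_context_parallel_size"] = prefill_cp
    if cudagraph_mode is None:
        kwargs["enforce_eager"] = True
    else:
        kwargs["compilation_config"] = {
            "cudagraph_mode": cudagraph_mode,
            "cudagraph_capture_sizes": [4, 8, 16],
        }
    return kwargs

With such a helper, test_pcp_dcp_mtp3_full_graph would reduce to VllmRunner(model, **mtp_runner_kwargs(3, cudagraph_mode="FULL_DECODE_ONLY")).
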
23 changes: 19 additions & 4 deletions tests/ut/spec_decode/test_mtp_proposer.py
@@ -11,6 +11,7 @@
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

 from vllm_ascend.ascend_config import init_ascend_config
@@ -216,10 +217,23 @@ def test_generate_token_ids(self, mock_cpu_gpu_buffer):
         mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
         mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
         mock_deps.runner.pcp_size = 2
-        mock_deps.runner.input_ids_pcp_full = torch.arange(32,
-                                                           dtype=torch.int32)
-        mock_deps.runner.query_start_loc_pcp_full_cpu = torch.tensor(
-            [0, 8, 16, 24, 32])
+        mock_deps.runner.dcp_size = 1
+        mock_deps.runner.input_ids_pcp_full = CpuGpuBuffer(
+            32,
+            dtype=torch.int32,
+            pin_memory=False,
+            device='cpu',
+        )
+        mock_deps.runner.input_ids_pcp_full.cpu = \
+            torch.arange(32, dtype=torch.int32)
+        mock_deps.runner.query_start_loc_pcp_full = CpuGpuBuffer(
+            5,
+            dtype=torch.int32,
+            pin_memory=False,
+            device='cpu',
+        )
+        mock_deps.runner.query_start_loc_pcp_full.cpu = \
+            torch.tensor([0, 8, 16, 24, 32])
         mock_deps.positions = torch.arange(16, dtype=torch.int32)
         mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
         mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
@@ -233,6 +247,7 @@ def test_generate_token_ids(self, mock_cpu_gpu_buffer):
         proposer.speculative_config = MagicMock(
             disable_padded_drafter_batch=False)
         proposer.pcp_size = mock_deps.runner.pcp_size
+        proposer.dcp_size = mock_deps.runner.dcp_size
         proposer.prepare_next_token_ids_padded = MagicMock(
             return_value=(torch.tensor([101, 200, 302]), 3))
         proposer.prepare_inputs_padded = MagicMock(
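
The updated unit test mocks the runner's pcp-full buffers with vLLM's CpuGpuBuffer instead of bare tensors. A minimal sketch of the pattern the test relies on, assuming only the constructor signature and the .cpu view that appear in the hunk above:

# Standalone illustration of the mock setup above (sketch, not PR code):
# build a host-side buffer, then overwrite its .cpu view with test data.
import torch
from vllm.v1.utils import CpuGpuBuffer

buf = CpuGpuBuffer(32, dtype=torch.int32, pin_memory=False, device='cpu')
buf.cpu = torch.arange(32, dtype=torch.int32)
print(buf.cpu[:4])  # tensor([0, 1, 2, 3], dtype=torch.int32)
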
22 changes: 15 additions & 7 deletions tests/ut/worker/test_model_runner_v1.py
@@ -50,6 +50,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,

     mock_runner.input_batch = MagicMock()
     mock_runner.input_batch.num_reqs = num_reqs
+    mock_runner.speculative_config = None

     num_computed_tokens = []
     num_prompt_tokens = []
@@ -169,23 +170,24 @@ def test_pcp_allgather_restore_idx_slicing():


 @pytest.mark.parametrize(
-    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
+    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens," \
+    "pcp_size, pcp_rank, decode_threshold, expected_pcp_tokens",
     [
         # Case 1: prefill only
-        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]),
+        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, 1, [2, 4, 4]),
-        # Case 2: mix prefill and decode
-        ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, [8, 4, 4]),
+        # Case 2: mix prefill and decode (with spec decode)
+        ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, 8, [8, 4, 4]),
         # Case 3: request which need to be padded
-        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]),
+        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, 1, [2, 2, 4]),
         # Case 4: single request
-        ([10], 1, [0], [10], 4, 0, [4]),
+        ([10], 1, [0], [10], 4, 0, 1, [4]),
     ])
 def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                                      num_prompt_tokens, pcp_size, pcp_rank,
-                                     expected_pcp_tokens):
+                                     decode_threshold, expected_pcp_tokens):
     mock_runner = MagicMock(spec=NPUModelRunner)
     mock_runner.pcp_size = pcp_size
     mock_runner.pcp_rank = pcp_rank
@@ -201,6 +203,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,

     mock_runner.num_pcp_pads = [0] * num_reqs
     mock_runner.arange_np = np.arange(10000)
+    mock_runner.decode_threshold = decode_threshold

     mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
         mock_runner, NPUModelRunner)
@@ -243,6 +246,7 @@ def test_update_tokens_for_pcp_with_padding():

     mock_runner.num_pcp_pads = [0, 0, 0]
     mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
+    mock_runner.decode_threshold = 1

     mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
         mock_runner, NPUModelRunner)
@@ -279,6 +283,7 @@ def test_update_tokens_for_pcp_unpad_mask():

     mock_runner.num_pcp_pads = [0, 0]
     mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
+    mock_runner.decode_threshold = 1

     mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
         mock_runner, NPUModelRunner)
@@ -369,6 +374,9 @@ def pcp_mtp_mock_runner():

     mock_runner.input_ids_pcp_full = NPUModelRunner._make_buffer(
         mock_runner, max_num_tokens, dtype=torch.int32)
+    mock_runner.query_lens_pcp_full = NPUModelRunner._make_buffer(
+        mock_runner, max_num_reqs, dtype=torch.int32)
+    mock_runner.decode_threshold = 1

     mock_runner.arange_np = np.arange(max_model_len)
     mock_runner.input_batch = MagicMock()
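
Several of these tests now set mock_runner.decode_threshold. As a reference point, and as an assumption about the usual vLLM spec-decode convention rather than something stated in this diff, decode_threshold is the largest query length still handled as a decode step, i.e. one verified token plus the draft tokens:

# Assumed convention, not taken from this PR:
# a request counts as a decode if its query length fits within
# 1 verified token + num_speculative_tokens draft tokens.
num_speculative_tokens = 3
decode_threshold = 1 + num_speculative_tokens


def is_decode(query_len: int) -> bool:
    return query_len <= decode_threshold


assert is_decode(4) and not is_decode(12)

Read this way, Case 2 above (decode_threshold=8) would keep the 8- and 4-token requests whole as decodes while splitting the 12-token prefill across pcp ranks, which matches the expected [8, 4, 4].
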
44 changes: 21 additions & 23 deletions vllm_ascend/attention/mla_cp.py
@@ -27,6 +27,7 @@
                                          split_decodes_and_prefills,
                                          wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               get_mtp_graph_params,
                                                update_graph_params_workspaces)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
@@ -92,6 +93,10 @@ def build(
         num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens
+        # In dcp only spec decode graph padding case,
+        # num_actual_tokens_pcp_padded may be less than num_actual_tokens
+        num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded,
+                                           num_actual_tokens)
         num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
         assert num_computed_tokens_of_pcp_dcp is not None

@@ -113,15 +118,6 @@ def build(
                 common_attn_metadata.block_table_tensor[:graph_pad_size])
         else:
             block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        # NOTE: Currently, MTP-fullgraph is incompatibility pcp
-        if self.pcp_size > 1:
-            num_decodes_flatten = num_decodes * self.decode_threshold
-            block_table = common_attn_metadata.block_table_tensor[:
-                                                                  num_decodes_flatten
-                                                                  +
-                                                                  num_prefills]
-
-        # NOTE: Currently, MTP-fullgraph is incompatibility pcp
         slot_mapping = common_attn_metadata.slot_mapping[:
                                                          num_actual_tokens_pcp_padded]
         input_positions = common_attn_metadata.positions[:
@@ -144,6 +140,13 @@ def build(
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (seq_lens - query_lens)

+        # For pcp + spec decode, we flatten seq_lens and block_table
+        # to avoid irregular spec_attn_mask shape
+        num_decodes_flatten = query_lens[:num_decodes].sum().item()
+        block_table = common_attn_metadata.block_table_tensor[:
+                                                              num_decodes_flatten
+                                                              + num_prefills]
+
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:
@@ -201,7 +204,7 @@ def build(
                 dtype=torch.int32)

             local_context_lens_allranks = torch.tensor(
-                num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
+                num_computed_tokens_of_pcp_dcp[num_decodes_flatten:]
             ).reshape(-1, self.dcp_size * self.pcp_size)
             # Note(qcs): The max local context lengths
             # padded to `cp_local_block_size`.
@@ -280,9 +283,8 @@ def build(
                 cos=cos,
                 pcp_metadata=pcp_metadata,
             )
-            if self.pcp_size > 1:
-                prefill_metadata.block_table = block_table[
-                    num_decodes_flatten:, ...]
+            prefill_metadata.block_table = \
+                block_table[num_decodes_flatten:, ...]

         decode_metadata = None
         if num_decodes > 0:
@@ -293,13 +295,7 @@ def build(
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            if self.pcp_size > 1:
-                # For pcp + spec decode, we flatten seq_lens and block_table
-                # to avoid irregular spec_attn_mask shape
-                block_table = block_table[:num_decodes_flatten, ...]
-            else:
-                block_table = block_table[:num_decodes, ...]
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
+            block_table = block_table[:num_decodes_flatten, ...]
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
             if graph_pad_size > num_decodes and \
                 self.speculative_config.disable_padded_drafter_batch:
@@ -308,8 +304,7 @@ def build(

             # [bs, pcp_size, dcp_size]
             num_computed_tokens_of_cp_dcp_array = np.array(
-                num_computed_tokens_of_pcp_dcp)[:num_decodes *
-                                                self.decode_threshold]
+                num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]

             cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
                                                              self.dcp_rank]
@@ -1057,8 +1052,11 @@ def _forward_decode_pcp_dcp(
             "return_lse": True,
             "calc_type": "calc_type_ring",
         }
-        graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
+        if forward_context.is_mtp_model:
+            graph_params = get_mtp_graph_params()
+        else:
+            graph_params = get_graph_params()
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
             event = torch.npu.ExternalEvent()
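
The last hunk makes the decode attention path pick a separate set of graph parameters when the forward pass belongs to the MTP draft model, which (as the dedicated get_mtp_graph_params accessor suggests) keeps the draft model's captured-graph bookkeeping separate from the target model's. A standalone sketch of that dispatch, using only the names visible in the diff; the import locations are assumed:

# Sketch of the graph-params dispatch shown above (assumed import paths).
from vllm.forward_context import ForwardContext, get_forward_context

from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                               get_mtp_graph_params)


def select_graph_params():
    # MTP (draft-model) forward passes get their own graph params; all other
    # passes fall back to the main model's params, as in the hunk above.
    forward_context: ForwardContext = get_forward_context()
    if forward_context.is_mtp_model:
        return get_mtp_graph_params()
    return get_graph_params()
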