[feature] support pcp + mtp in full graph #4572
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces support for PCP (Prefill Context Parallelism) and MTP (Multi-Token Prediction) in full graph mode, along with several related bug fixes. The changes correctly generalize PCP-only logic to accommodate DCP (Decode Context Parallelism) as well. A notable improvement is the handling of variable query lengths in speculative decoding batches, which replaces assumptions of fixed lengths with more robust logic. However, I've identified one critical issue in the implementation that needs to be addressed.
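To make the variable-length point concrete, here is a minimal sketch of the idea, not the PR's actual code; the batch contents, `num_decode_reqs`, and `decode_threshold` values are illustrative, with names mirroring the diff quoted below.

```python
import numpy as np
import torch

# Hypothetical mixed batch: 3 decode requests followed by 2 prefill requests.
num_decode_reqs = 3
decode_threshold = 3  # max tokens per spec-decode request (1 target + drafts)

# Per-request scheduled token counts. With MTP, each decode request carries a
# variable number of tokens (1 + accepted draft tokens), not a fixed constant.
num_scheduled_tokens = np.array([2, 1, 3, 128, 64], dtype=np.int32)
query_lens = torch.from_numpy(num_scheduled_tokens)  # host (CPU) tensor

# Fixed-length assumption (fragile): pads every decode request to the maximum
# and over-counts whenever fewer draft tokens were accepted.
num_tokens_d_fixed = num_decode_reqs * decode_threshold  # -> 9

# Variable-length logic: sum the actual per-request query lengths.
query_lens_d = query_lens[:num_decode_reqs]
num_tokens_d = query_lens_d.sum().item()  # -> 6
assert num_tokens_d <= num_tokens_d_fixed
```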
Force-pushed from b08104c to b330b75
```diff
  # prefill target_hidden_states: pcp split
- num_tokens_d = num_decode_reqs * self.decode_threshold
+ query_lens_d = self.runner.query_lens[:num_decode_reqs]
+ num_tokens_d = query_lens_d.sum().item()
```
Is `query_lens_d` a device tensor? If so, calling `query_lens_d.sum().item()` will incur a blocking host-device sync; please fix it.
This is a host tensor; refer to model_runner_v1.py: `self.query_lens = torch.from_numpy(num_scheduled_tokens)`. It will not trigger a host-device sync. The `_d` suffix is an abbreviation of `_decode`.
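For readers following along, a minimal sketch of the distinction under discussion; the device string is a stand-in (vLLM Ascend targets NPUs, but the sync behavior is the same for any device backend):

```python
import numpy as np
import torch

num_scheduled_tokens = np.array([2, 1, 3], dtype=np.int32)
host_lens = torch.from_numpy(num_scheduled_tokens)
# torch.from_numpy always produces a CPU tensor sharing memory with the numpy
# array, so .item() here is a plain host memory read -- nothing blocks.
total = host_lens.sum().item()

# On a device tensor, .item() must copy the scalar back to the host, which
# implicitly synchronizes with (drains) the device's work queue first:
if torch.cuda.is_available():  # substitute the NPU backend as appropriate
    dev_lens = host_lens.to("cuda")
    total_dev = dev_lens.sum().item()  # blocking device->host sync
```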
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Force-pushed from bf5139c to 7cc6d7b
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Force-pushed from 64c24c1 to df5a7d2
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This pull request has conflicts, please resolve those before we can evaluate the pull request. |