Qw35 mtp #1371
Closed
+1,471
−297
18 commits
89a27fd feat(mtp): MTP verify-decode infrastructure (sufubao)
61e4dcb feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP draft models (sufubao)
89a163a feat(qwen3next): GDN spec-decode verify path + linear-att cache split (sufubao)
47ddb6c feat(scheduler): MTP verify backend + accept-len transport (sufubao)
085a185 test(mtp): MTP unit tests + static benchmark (sufubao)
db50f25 Fix Qwen3Next MTP linear-att page moves (sufubao)
45ec253 revert formatting churn on pre-existing code (sufubao)
5883b41 revert(mtp): drop eagle reduced-batch draft optimization (sufubao)
82522e6 revert(mtp): run the MTP draft on upstream's grouped verify layout (sufubao)
cd6b918 clean code (sufubao)
fe9ac22 clean code (sufubao)
10473dd refactor(mtp): GPU-resident req_to_accept_len + simplify verify-decod… (sufubao)
45831a2 revert: drop all test/ and unit_tests/ changes from this branch (sufubao)
31fa641 style: black-format fp8.py k/v_descale lines (pre-commit) (sufubao)
c4c3c2f clean code (sufubao)
6f78b54 Merge upstream/main into qw35_mtp_feature (sufubao)
10a6f38 Merge qw35_mtp_feature (Qwen3.5 MTP support) into rl_verl_rebase_main (sufubao)
9f3ecaa fix(mtp): always allocate req_to_accept_len so non-MTP hybrid models … (sufubao)
@@ -7,4 +7,7 @@ dist
 .vscode
 tmp/
 requirements-musa.txt
-logs/
+logs/
+
+/benchmark/
+artifacts/
@@ -45,9 +45,12 @@ def init_state(self):
             torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
         )
         # To reduce inference-time computation, initialize k_descale and v_descale outside the inference path
-        self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-        self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-
+        self.k_descale = (
+            offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+        )
+        self.v_descale = (
+            offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+        )

     def prefill_att(
         self,
@@ -116,20 +119,19 @@ def init_state(self):
         super().init_state()
         self.backend: Fp8Fa3AttBackend = self.backend

-        args_mtp_step = get_env_start_args().mtp_step
-        att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
-        assert self.infer_state.batch_size % (args_mtp_step + 1) == 0

         device = self.infer_state.input_ids.device
-        batch_size = att_batch_size
+        batch_size = self.b_att_seq_len.shape[0]
         mem_manager = self.backend.model.mem_manager

         offline_scales: torch.Tensor = mem_manager.scales
         head_num = mem_manager.head_num

         # To reduce inference-time computation, initialize k_descale and v_descale outside the inference path
-        self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-        self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+        self.k_descale = (
+            offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+        )
+        self.v_descale = (
+            offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+        )

         return

@@ -180,11 +182,11 @@ def _fp8_decode_att(
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=self.page_table,
-            cache_seqlens=self.infer_state.b_seq_len,
+            cache_seqlens=self.b_att_seq_len,
             cu_seqlens_q=self.cu_seqlens_q,
             cu_seqlens_k_new=self.cu_seqlens_k,
             max_seqlen_q=self.decode_max_q_seq_len,
-            causal=False,
+            causal=True,
             window_size=(-1, -1),
             softcap=0.0,
             q_descale=q_scale.view(self.infer_state.batch_size, k_head_num),
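The decode hunk switches `causal=False` to `causal=True`, and the `init_state` hunk drops the code that derived the request count from `batch_size // (mtp_step + 1)`. The sketch below is one possible reading of why both matter for MTP verify decode; it is not taken from the PR. It assumes each request contributes `mtp_step + 1` query rows during verification, and that the attention kernel aligns its causal mask to the bottom-right corner when the query is shorter than the key sequence (as the FlashAttention documentation describes). The helper names `att_batch_size` and `verify_causal_mask` are hypothetical.

```python
from typing import List


def att_batch_size(token_batch_size: int, mtp_step: int) -> int:
    """Number of requests when every request contributes mtp_step + 1 rows.

    Mirrors the arithmetic of the removed lines; the function name is
    hypothetical and exists only for this illustration.
    """
    group = mtp_step + 1
    if token_batch_size % group != 0:
        raise ValueError("token batch must be a multiple of mtp_step + 1")
    return token_batch_size // group


def verify_causal_mask(q_len: int, kv_len: int) -> List[List[bool]]:
    """Bottom-right aligned causal mask (assumed kernel behaviour).

    Query row i may attend to key column j only if j <= i + (kv_len - q_len),
    so the last query row sees every key and earlier rows see fewer.
    """
    if q_len > kv_len:
        raise ValueError("q_len must not exceed kv_len")
    offset = kv_len - q_len
    return [[j <= i + offset for j in range(kv_len)] for i in range(q_len)]


if __name__ == "__main__":
    # 12 flattened rows with mtp_step = 2 correspond to 4 requests.
    print(att_batch_size(12, 2))
    # 3 verify tokens against 5 cached keys: each draft token must not
    # see the draft tokens that come after it.
    for row in verify_causal_mask(3, 5):
        print(["x" if allowed else "." for allowed in row])
```

With a single query token per request the mask is all-true, so `causal=False` and `causal=True` give the same result; the two only differ once several tokens per request are verified in one call.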
Contributor review comment on `cache_seqlens=self.b_att_seq_len`:

Accessing `self.b_att_seq_len` directly on `self` will raise an `AttributeError` because `b_att_seq_len` is initialized on `self.infer_state` (via `init_mtp_verify_extra_state`), not on the attention layer object itself. It should be accessed via `self.infer_state.b_att_seq_len`.
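The failure mode described in the review comment can be reproduced in isolation. This is a minimal sketch with stand-in classes, not the project's code: `InferState` and `AttState` are hypothetical, and the body of `init_mtp_verify_extra_state` is an assumption based only on the reviewer's statement that it sets `b_att_seq_len` on the inference state.

```python
class InferState:
    """Hypothetical stand-in for the per-batch inference state."""

    def __init__(self, seq_lens):
        self.b_seq_len = list(seq_lens)

    def init_mtp_verify_extra_state(self):
        # Assumption: the real method stores the attention sequence
        # lengths on the inference state, as the reviewer describes.
        self.b_att_seq_len = list(self.b_seq_len)


class AttState:
    """Hypothetical stand-in for the attention state object."""

    def __init__(self, infer_state):
        self.infer_state = infer_state

    def seqlens_as_in_diff(self):
        # The attribute was never set on this object, so the lookup fails.
        return self.b_att_seq_len

    def seqlens_as_suggested(self):
        # Read the attribute from the object that actually owns it.
        return self.infer_state.b_att_seq_len


if __name__ == "__main__":
    infer_state = InferState([5, 7])
    infer_state.init_mtp_verify_extra_state()
    att_state = AttState(infer_state)

    try:
        att_state.seqlens_as_in_diff()
    except AttributeError as exc:
        print("as in diff:", exc)

    print("as suggested:", att_state.seqlens_as_suggested())
```

The same reasoning would apply to the `batch_size = self.b_att_seq_len.shape[0]` line in the `init_state` hunk, if the attribute is likewise only set on `self.infer_state` there.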