Qw35 mtp#1371

Open
sufubao wants to merge 18 commits into rl_verl_rebase_main from qw35_mtp_rl_verl_rebase

Conversation

@sufubao sufubao commented Jun 29, 2026

Collaborator

No description provided.

sufubao added 18 commits June 16, 2026 15:54
Model-agnostic verify-decode machinery: MTP-verify dispatch in TpPartBaseModel, dedicated decode CUDA-graph capture/replay for the (mtp_step+1)-expanded verify layout, a shared mtp_verify_extra_state block on infer_struct/batch_objs, fa3 decode attention narrowed to the verify layout (b_att_seq_len + causal) for fp/fp8/mla, and env/kv-cache helpers for MTP added-layer accounting.

Self-contained dense (qwen3_5_mtp) and MoE (qwen3_5_moe_mtp) MTP draft packages: each carries its own draft wiring (reuse the main model's req/mem managers + rope caches, is_mtp_draft_model marker) and shares a weight-retarget mixin (mtp.* head, embeddings shared with the main model) plus the MTP pre-layer fuse. No shared model base class.

Gated-delta-net (linear attention) speculative-decode verify path for qwen3next: a per-sequence spec causal_conv1d kernel; a widened conv working slot split from the committed (narrow) persisted slot; MTP draft full-attn KV-slot accounting across the linear-att cache config, mem operator and req manager; and removal of the dead gen_b_req_mtp_start_loc kernel.

Wire the verify path through the inference backends: a single draft-model factory keyed on (model_type, mtp_mode); build the (mtp_step+1)-expanded verify decode batch; run the eagle + vanilla draft decode; verify accepted tokens; and thread per-request accept-lengths (b_num_accepted_tokens) from the chunked-prefill and dp backends into the model verify forward.

Behavioural/CUDA coverage for the subtle MTP paths: verify-extra-state metadata, decode CUDA-graph verify layouts, fa3 fp8 verify narrowing, GDN verify equivalence, the spec causal_conv1d kernel and its prefill->decode roundtrip, and the linear-att conv/SSM widened-slot split + snapshot + CPU-cache persistence. Also extends the static-inference MTP benchmark and anchors the .gitignore benchmark-output rule to /benchmark.

Restore blank lines that were stripped from pre-existing definitions
(black-induced reformatting of upstream code that this PR didn't
functionally change). Keeps the diff focused on the MTP feature;
fixing historical formatting is out of scope for this PR.

Scope this branch to Qwen3.5 MTP support only by rolling back the
EAGLE-mode draft optimization. The draft model again runs the full
(mtp_step+1)-expanded verify layout instead of being narrowed to the
single accepted row per request.

- dp/chunked _draft_decode_eagle: restore full-layout draft (copy.copy +
  b_num_accepted_tokens=None so it routes to the (bs, False) graph); drop
  the per-rank padding helpers and accepted-row narrowing.
- base_backend: remove _build_eagle_accepted_draft_input /
  _scatter_accepted_next_token_ids.
- cuda_graph: the draft runs at multiples of (mtp_step+1) again, so
  collapse the dual batch-size sets to one and delete the now-redundant
  _get_graph_batch_sizes routing. Keep the (bs, is_mtp_verify_decode)
  graph key + verify-layout warmup (core GDN verify support, not the
  optimization).
- static benchmark: eagle path now measures the full-layout draft cost.
- tests: drop the two narrowed-draft tests; rewrite the dual-set tests to
  the single-set model (still cover the verify/normal key distinction).

Drop the remaining draft-side divergence from upstream so this branch is
scoped to Qwen3.5 MTP support only. The draft decode no longer clears
b_num_accepted_tokens to force a flat/normal layout; it reuses the main
model_input (still copy.copy'd to isolate per-step input_ids/b_seq_len/
mem_indexes mutations) and runs the same (mtp_step+1)-grouped verify
decode layout as the main model — exactly as upstream does.

For the pure-full-attention draft (qwen3_5_mtp: full_attention_interval=1,
no GDN) grouped and flat are numerically identical: each position k sees
KV [0, s+k) either way, same page-table entries, same RoPE positions; the
main verify forward already uses this geometry and is the validated path.
The earlier flat-draft only added an unnecessary (bs, False) cudagraph
layout + b_num_accepted_tokens gating; nothing the draft computes needs it.
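
The equivalence claimed above can be checked with a small sketch. It assumes the grouped layout is one query block of (mtp_step+1) rows over a cache of length s + mtp_step with a bottom-right-aligned causal mask, and that the flat layout gives row k its own sequence length s + k. The function names are illustrative and do not come from the codebase.

```python
def grouped_visible_kv(s, mtp_step):
    """Number of KV entries each row sees in the grouped verify layout."""
    q_len = mtp_step + 1
    kv_len = s + mtp_step  # assumed cache length once the whole group is written
    # Bottom-right-aligned causal mask: row k may attend to key j
    # iff j <= (kv_len - q_len) + k, i.e. it sees kv_len - q_len + k + 1 keys.
    return [kv_len - q_len + k + 1 for k in range(q_len)]


def flat_visible_kv(s, mtp_step):
    """Number of KV entries each row sees when rows are independent decode rows."""
    return [s + k for k in range(mtp_step + 1)]


# Row k sees KV [0, s + k) in both layouts.
assert grouped_visible_kv(100, 3) == flat_visible_kv(100, 3)
```

Under these assumptions the two layouts differ only in how the rows are batched, not in what each row attends to.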

- chunked_prefill/dp_backend: 6 draft fns (vanilla/eagle + dp overlap
  variants) stop clearing b_num_accepted_tokens.
- cuda_graph: draft warms up the verify graph key too (mtp_step>0 -> verify
  for both main and draft); delete the now-dead _is_mtp_draft_model.
- tests: rewrite the warmup-layout test (main+draft both verify; mtp_step==0
  -> normal) and drop the stale "draft uses normal layout" framing.
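
A minimal sketch of the graph-key rule described in these bullets, assuming the key is the pair (batch size, is_mtp_verify_decode) and that verify is selected whenever mtp_step > 0 on a decode step. The helper names and the `max_requests` parameter are hypothetical.

```python
def graph_key(batch_size, mtp_step, is_prefill=False):
    """Key used to look up a captured decode graph."""
    is_mtp_verify_decode = mtp_step > 0 and not is_prefill
    return (batch_size, is_mtp_verify_decode)


def warmup_batch_sizes(max_requests, mtp_step):
    """Batch sizes to capture: multiples of (mtp_step + 1), one set for main and draft."""
    rows_per_request = mtp_step + 1
    return [n * rows_per_request for n in range(1, max_requests + 1)]


# With mtp_step=1 both the main and the draft model warm up verify-layout graphs.
keys = [graph_key(bs, mtp_step=1) for bs in warmup_batch_sizes(3, mtp_step=1)]
```

With mtp_step=0 the same helpers fall back to the normal layout, since the verify flag is False and the batch sizes are not expanded.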

Keep is_mtp_verify_decode (main-model GDN verify still needs it) and the
committed fp8.py causal=True fix.

Verified live (QW35-122B-A10B, eagle_with_att, mtp_step=1, tp4): GSM8K
acc 0.964 / Invalid 0.000, accept 1.956/2.0 — matches pre-revert baseline
(no regression). Codex independent pass concurred (high confidence).

…e plumbing

- is_mtp_verify: drop the redundant `b_num_accepted_tokens is not None` clause
  (post grouped-revert it's implied by mtp_step>0 ∧ ¬prefill).
- Replace the per-step host round-trip for b_num_accepted_tokens with a
  GPU-resident ReqManager.req_to_accept_len: a triton scatter_mtp_accept_len
  after verify + a GDN-only gather in init_mtp_verify_extra_state. Removes the
  gen_from_list H2D rebuild, the phase-2 req.mtp_accept_len writeback, and the
  host attr (linear-att offload + resets now read/write the buffer).
- Drop the redundant `if mtp_step>0` guard inside decode_mtp/decode_overlap_mtp.
- config_objs: inline the mtp draft-layer count, dropping the _mtp_added_layer_num
  helper (kept get_added_mtp_kv_layer_num inlined in envs_utils).
- cpu_cache_meta: don't bump layer_num for linear-att models (the draft full-att
  slots are already in LinearAttCacheConfig.get_cpu_cache_big_page_bytes()).
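
The scatter-after-verify and gather-before-verify pattern for req_to_accept_len can be sketched as follows. This is a list-backed stand-in, not the Triton kernel: the real buffer is described as GPU-resident, and the class and method names here are assumptions for illustration.

```python
class AcceptLenBuffer:
    """Per-request-slot accept lengths, indexed by request slot id."""

    def __init__(self, max_requests):
        # 1 means "no draft tokens accepted" (only the main model's token).
        self.buf = [1] * max_requests

    def scatter(self, req_idxs, accept_lens):
        """After verify: write each request's accept length into its slot."""
        for idx, n in zip(req_idxs, accept_lens):
            self.buf[idx] = n

    def gather(self, req_idxs):
        """Before the next verify forward: read accept lengths for the batch."""
        return [self.buf[idx] for idx in req_idxs]


buf = AcceptLenBuffer(max_requests=8)
buf.scatter(req_idxs=[5, 2], accept_lens=[2, 1])
next_step = buf.gather([2, 5, 7])
```

Keeping the values in one slot-indexed buffer is what lets the next step read them without rebuilding a host-side list for every batch.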

Static checks pass (ast, flake8). The req_to_accept_len refactor is not yet
runtime-verified; pending a hybrid GSM8K + cudagraph-ON parity run.

Brings in PR #1349 (perf(qwen3next): drop q/k/v/a/b contiguous copies in GDN
fused_recurrent decode).

Semantic reconciliation in fused_recurrent.py (git auto-merged it silently
since this branch never textually touched the file, but the auto-resolution
would have broken MTP verify):

  #1349 makes the fused recurrent kernel decode-only by asserting
  `cu_seqlens is None`. This branch's MTP `_gdn_verify_kernel` drives the same
  kernel with `cu_seqlens` (variable-length verify chunks), so the bare
  auto-merge would crash verify-decode on that assert.

Resolution keeps BOTH: #1349's per-token strided no-copy decode path AND the
MTP verify varlen path. The strided kernel arithmetic is already general
(bos * stride_tok), so only the host wrapper needed fixing:
  - drop the two `assert cu_seqlens is None` guards
  - restore `N = B if cu_seqlens is None else len(cu_seqlens) - 1`
  - generalize `_ensure_qkv_token_strided` to accept the verify layout
    [1, tokens, head, dim] (token dim = dim 1) in addition to the decode
    layout [tokens, 1, head, dim] (token dim = dim 0); both are contiguous-tail
    column views, so no copy is needed in either case.
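
The host-wrapper logic in these bullets can be sketched in plain Python. The `N` expression is taken from the commit message; `sequence_bounds` and `token_dim` are hypothetical helpers that only illustrate how cu_seqlens and the two layouts are interpreted.

```python
def num_sequences(batch_size, cu_seqlens=None):
    """N = B for plain decode, or the number of varlen chunks for verify."""
    return batch_size if cu_seqlens is None else len(cu_seqlens) - 1


def sequence_bounds(cu_seqlens):
    """(start, end) token offsets of each variable-length chunk."""
    return [(cu_seqlens[i], cu_seqlens[i + 1]) for i in range(len(cu_seqlens) - 1)]


def token_dim(shape):
    """Axis that indexes tokens for a 4-D q/k/v view."""
    if len(shape) != 4:
        raise ValueError("expected a 4-D shape")
    if shape[1] == 1:
        return 0  # decode layout [tokens, 1, head, dim]
    if shape[0] == 1:
        return 1  # verify layout [1, tokens, head, dim]
    raise ValueError("unsupported layout")
```

For example, three requests with two verify rows each would pass cu_seqlens = [0, 2, 4, 6], giving N = 3 even though the leading batch dimension is 1.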

unit_tests/.../test_fused_recurrent_strided.py: dropped the
`test_cu_seqlens_is_not_supported` negative test (it asserted the lifted
decode-only contract); kept the decode strided-views equivalence test. Varlen
verify correctness is covered E2E by the MTP GSM8K accuracy check.

Claude-Session: https://claude.ai/code/session_01J3SiYM55DJg8ht4dmEPmUx

Brings Qwen3.5 / Qwen3.5-MoE MTP draft models and the GDN spec-decode
verify path onto the RL/verl rebase branch. fused_recurrent.py resolves to
the reconciled keep-both version (decode + variable-length verify); the 6
auto-merged core files were verified to layer MTP changes cleanly on top of
the RL branch with no overlap loss.

…don't crash

req_to_accept_len was allocated only when mtp_step > 0 (else None), but the
linear-att cache-copy paths in infer_batch.py (lines 409/436) index it
unconditionally — gated on linear-att + radix cache, not on mtp_step. A
non-MTP (mtp_step=0) qwen3next/Qwen3.5 run with prompt caching therefore hit
None[...] -> TypeError on the first prefill crossing a page boundary.

Always allocate the tensor (init 1 = no draft tokens accepted); the value and
behavior are unchanged for the MTP path (it already init'd with torch.ones and
overwrites per-step via scatter). Removes the three now-redundant
is-not-None guards.
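
The failure mode and the fix can be reproduced in a few lines. This sketch uses a Python list in place of the tensor, and the function names and the `always_allocate` switch are hypothetical.

```python
def alloc_req_to_accept_len(max_requests, mtp_step, always_allocate):
    """List-backed stand-in for the per-request accept-length buffer."""
    if always_allocate or mtp_step > 0:
        return [1] * max_requests  # 1 = no draft tokens accepted
    return None  # old behaviour when mtp_step == 0


def copy_path_read(req_to_accept_len, req_idx):
    """Mimics a cache-copy path that indexes the buffer unconditionally."""
    return req_to_accept_len[req_idx]


def old_behaviour_crashes():
    """True if the old allocation rule fails on a non-MTP run."""
    buf = alloc_req_to_accept_len(4, mtp_step=0, always_allocate=False)
    try:
        copy_path_read(buf, 0)
    except TypeError:  # 'NoneType' object is not subscriptable
        return True
    return False
```

With `always_allocate=True` the same read returns 1 for a non-MTP run, so callers no longer need an is-not-None guard.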

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request implements Multi-Token Prediction (MTP) speculative decoding support for Qwen3.5 and Qwen3.5-MoE models, refactors MTP state management to unify verify extra states, and adapts CUDA graph capture and Triton kernels for MTP-expanded decode layouts. The review identified several issues: accessing self.b_att_seq_len directly in fp8.py will cause an AttributeError and should go through self.infer_state; calling .min() and .max() on CUDA tensors in linear_att_copy.py introduces a synchronous D2H bottleneck on the hot path; and a potential ZeroDivisionError in causal_conv1d_spec.py should be prevented with an early return when batch is zero.



  device = self.infer_state.input_ids.device
- batch_size = att_batch_size
+ batch_size = self.b_att_seq_len.shape[0]


critical

Accessing self.b_att_seq_len directly on self will raise an AttributeError because b_att_seq_len is initialized on self.infer_state (via init_mtp_verify_extra_state), not on the attention layer object itself. It should be accessed via self.infer_state.b_att_seq_len.

Suggested change
- batch_size = self.b_att_seq_len.shape[0]
+ batch_size = self.infer_state.b_att_seq_len.shape[0]

  v_cache=cache_v,
  page_table=self.page_table,
- cache_seqlens=self.infer_state.b_seq_len,
+ cache_seqlens=self.b_att_seq_len,


critical

Accessing self.b_att_seq_len directly on self will raise an AttributeError because b_att_seq_len is initialized on self.infer_state, not on the attention layer object itself. It should be accessed via self.infer_state.b_att_seq_len.

Suggested change
- cache_seqlens=self.b_att_seq_len,
+ cache_seqlens=self.infer_state.b_att_seq_len,

Comment on lines +100 to 103
assert int(b_num_accepted_tokens.min()) >= 1 and int(b_num_accepted_tokens.max()) <= mtp_step + 1, (
    f"b_num_accepted_tokens out of range [1, {mtp_step + 1}]: "
    f"min={int(b_num_accepted_tokens.min())} max={int(b_num_accepted_tokens.max())}"
)


medium

Calling .min(), .max(), and converting them to int on a CUDA tensor (b_num_accepted_tokens) causes a synchronous host-device transfer (D2H copy). Since copy_linear_att_state_to_kv_buffer is called on the hot path of every decode step (via copy_linear_att_state_to_cache_buffer), this assertion will introduce a CPU-GPU synchronization bottleneck. Consider removing this assertion or wrapping it under a debug/warmup-only check to maintain high inference throughput.
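
One way to follow this suggestion is to gate the range check behind an opt-in flag so the reductions are never issued on the hot path. The sketch below is an assumption-laden illustration: the environment variable name is hypothetical, and `CountingValues` is a stand-in for a device tensor that only counts how many reductions are requested.

```python
import os

# Hypothetical opt-in flag; the name is an assumption for this sketch.
DEBUG_CHECKS = os.environ.get("MTP_DEBUG_CHECKS", "0") == "1"


class CountingValues:
    """Stand-in for a device tensor: counts requested reductions."""

    def __init__(self, values):
        self.values = list(values)
        self.reductions = 0

    def min(self):
        self.reductions += 1
        return min(self.values)

    def max(self):
        self.reductions += 1
        return max(self.values)


def check_accept_range(b_num_accepted_tokens, mtp_step, enabled=DEBUG_CHECKS):
    """Validate accept lengths only when checks are enabled."""
    if not enabled:
        return False  # no .min()/.max() call, so nothing forces a sync
    lo = int(b_num_accepted_tokens.min())
    hi = int(b_num_accepted_tokens.max())
    assert 1 <= lo and hi <= mtp_step + 1, (
        f"b_num_accepted_tokens out of range [1, {mtp_step + 1}]: min={lo} max={hi}"
    )
    return True
```

Computing `lo` and `hi` once also means that a failing check formats its message without issuing further reductions.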

Comment on lines +384 to +385
batch = conv_state_indices.size(0)
dim = x.size(1)


medium

If batch is 0 (e.g., during an empty batch or edge-case execution), x.size(0) % batch will raise a ZeroDivisionError. Adding an early return guard when batch == 0 is a safer defensive programming practice.

Suggested change
- batch = conv_state_indices.size(0)
- dim = x.size(1)
+ batch = conv_state_indices.size(0)
+ if batch == 0:
+     return x
+ dim = x.size(1)
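
The guard can be illustrated without tensors. In this sketch `tokens_per_sequence` is a hypothetical helper standing in for the `x.size(0) % batch` computation; it returns early for an empty batch instead of reaching the modulo.

```python
def tokens_per_sequence(num_tokens, batch):
    """Rows per sequence, with an early return for an empty batch."""
    if batch == 0:
        return 0  # nothing to process; avoids num_tokens % 0
    if num_tokens % batch != 0:
        raise ValueError("num_tokens must be a multiple of batch")
    return num_tokens // batch


def unguarded_modulo_fails(num_tokens):
    """True if the unguarded expression raises for batch == 0."""
    try:
        num_tokens % 0
    except ZeroDivisionError:
        return True
    return False
```

Whether an empty batch can actually reach this kernel wrapper is not stated in the review, so the guard is best read as defensive.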
