[Feature] Add guided decoding support for speculative decoding #4559
windreamer wants to merge 13 commits into
Conversation
Pull request overview
Adds guided decoding (JSON schema / regex / grammar via xgrammar) support to the PyTorch speculative decoding (MTP) path by propagating GuidedDecodingManager into spec decoding and applying grammar bitmasks during both draft proposal and target verification/rejection sampling.
Changes:
- Propagate `GuidedDecodingManager` into `SpecModelAgent` and spec proposers, and apply position-serial grammar masking in spec decode verification.
- Add draft-side grammar masking support for proposers that share the target vocab (e.g., `DeepseekMTP`), and a `supports_grammar_mask` capability flag.
- Add unit/integration/E2E tests and update EN/ZH docs for guided decoding with speculative decoding.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `lmdeploy/pytorch/engine/model_agent/agent.py` | Propagates `guided_decoding_manager` into the speculative decoding agent and proposer. |
| `lmdeploy/pytorch/spec_decode/spec_agent.py` | Implements guided masking in spec decode verification, and expands/slices additional `SamplingInputs` fields. |
| `lmdeploy/pytorch/spec_decode/proposers/base.py` | Adds `supports_grammar_mask` and `guided_decoding_manager` plumb-through; extends the `get_outputs` signature. |
| `lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py` | Applies the grammar bitmask to draft logits (when provided) and advances forked matchers. |
| `lmdeploy/pytorch/spec_decode/proposers/eagle3.py` | Disables draft-side grammar masking via `supports_grammar_mask = False`. |
| `tests/pytorch/spec_decode/test_guided_spec_decode.py` | Unit tests for expand/slice behavior and guided-spec decode grammar mechanics. |
| `tests/pytorch/spec_decode/test_guided_spec_integration.py` | Higher-level integration tests for guided masking + rejection sampling state consistency. |
| `tests/test_lmdeploy/test_mtp_guided_decoding.py` | GPU integration tests for pipeline + MTP + guided decoding (schema/regex/json_object + streaming). |
| `docs/en/advance/spec_decoding.md` | Documents guided decoding usage with speculative decoding (EN). |
| `docs/zh_cn/advance/spec_decoding.md` | Documents guided decoding usage with speculative decoding (ZH). |
Pull request overview
Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.
Pull request overview
Copilot reviewed 15 out of 15 changed files in this pull request and generated 1 comment.
Force-pushed from 84eac20 to bb48caf
Pull request overview
Copilot reviewed 23 out of 23 changed files in this pull request and generated no new comments.
- Implement position-serial grammar mask via forked GrammarMatchers
- Propagate guided_decoding_manager from ModelAgent to SpecModelAgent
- Apply grammar mask in DeepseekMTP proposer before token selection
- Advance forked matcher state in Eagle3 proposer (no mask due to vocab mismatch)
- Handle grammar state management after rejection sampling
- Expand/slice sampling_inputs for non-tensor fields (response_formats, session_ctx, etc.)
- Consolidate tests: 7 unit + 2 integration tests, 6 GPU e2e tests
- Add MTP + guided decoding usage docs (en/zh_cn)
Eagle3's draft vocabulary differs from the target vocabulary, so a target-vocab grammar mask is inapplicable to draft logits. Add supports_grammar_mask class attribute to BaseSpecProposer (default True); Eagle3 overrides to False. spec_agent now gates the fork on this flag, and Eagle3.get_outputs() no longer accepts or processes guided_processors. Co-authored-by: openhands <openhands@all-hands.dev>
…ation
- Eagle3.get_outputs() now applies the grammar mask before argmax and accept_token after d2t mapping, matching the DeepseekMTP pattern
- Add _translate_bitmask() to convert the target-vocab bitmask to a draft-vocab bitmask via scatter_add_ (vectorized, no loops)
- Remove the supports_grammar_mask flag; all proposers now support it
- Fork guided processors unconditionally in spec_agent._async_model_forward
- Move session_to_cleanup handling before get_processors in forward_decode
- Bump xgrammar>=0.1.33 (fork() requirement) in all 5 runtime requirements
- Add comprehensive tests: bitmask translation, Eagle3 get_outputs, fork independence, multi-step draft loop
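The scatter_add_ translation can be pictured with a short sketch. This is illustrative only: the function name is hypothetical, and it assumes d2t holds absolute target ids for each draft token (the real draft_id_to_target_id mapping in eagle3.py may store offsets or extra bookkeeping):

```python
import torch

def translate_bitmask(target_bitmask: torch.Tensor, d2t: torch.Tensor,
                      n_draft_vocab: int) -> torch.Tensor:
    """Hedged sketch of target->draft bitmask translation (xgrammar packs
    32 tokens per int32 word). Assumes d2t[i] is the absolute target id of
    draft token i."""
    # 1) For every draft token, read the allowed/forbidden bit of its target id.
    words = target_bitmask[..., d2t // 32]                    # (..., n_draft_vocab)
    bits = ((words >> (d2t % 32)) & 1).to(torch.int32)        # 1 = allowed
    # 2) Pack those bits into draft-vocab words. Each (word, bit) slot is
    #    unique per draft id, so scatter_add_ acts as a bitwise OR.
    n_words = (n_draft_vocab + 31) // 32
    draft_ids = torch.arange(n_draft_vocab, device=d2t.device)
    shifts = (draft_ids % 32).to(torch.int32)                 # stay in int32
    out = target_bitmask.new_zeros(*target_bitmask.shape[:-1], n_words)
    out.scatter_add_(-1, (draft_ids // 32).expand_as(bits), bits << shifts)
    return out
```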
…ng on CUDA
- Update the use_fa3 capability check from == 9 (SM90 only) to >= 8 (SM80+) in attention/__init__.py and configurations/utils.py
- Add an FA3 requirement check in graph_runner.py: speculative decoding on CUDA now raises a clear error if FA3 is unavailable, instead of crashing deep in the Triton paged attention kernel
- Update docstrings/error messages to reflect SM80+ (Ampere) support
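A minimal sketch of the relaxed gate, assuming a torch.cuda capability query; the actual helper in attention/__init__.py may be structured differently:

```python
import torch

def can_use_fa3() -> bool:
    # Sketch: FA3's multi-token decode path is now accepted on SM80+
    # (Ampere and newer), where the check previously required SM90 exactly.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8  # was: major == 9
```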
…tion
FA3 mha_fwd derives seqlen_k from page_table.shape[1] * page_size for paged KV without cu_seqlens_k, and get_scheduler_metadata must receive the same value to produce a consistent scheduler layout. Previously, max_seqlen_k was incorrectly set to step_context.max_kv_seqlen (runtime KV length) in op_backend.py, and to decode_query_len or attn_metadata.max_kv_seqlen in cudagraph.py. These values differ from what FA3 computes internally, causing scheduler_metadata to be misaligned with the actual kernel behavior.
- op_backend.py: use block_offsets.size(1) * block_size
- cudagraph.py: use graph_meta.num_blocks * graph_meta.block_size
Both now match FA3 internal: page_table.size(1) * page_size.
Co-authored-by: openhands <openhands@all-hands.dev>
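The invariant can be stated as a one-liner; the function and argument names below are illustrative, not the real API:

```python
import torch

def max_seqlen_k_for_scheduler(page_table: torch.Tensor, page_size: int) -> int:
    # Must match FA3's internal mha_fwd derivation for paged KV without
    # cu_seqlens_k: page_table.size(1) * page_size. Passing anything else to
    # get_scheduler_metadata misaligns the scheduler with the kernel.
    return page_table.size(1) * page_size
```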
… device-to-host sync
In _guided_spec_logits_process, forked matchers were advanced using the argmax of the masked target logits. In the non-greedy rejection sampling path, the actually accepted token can differ from the argmax, causing subsequent grammar masks (especially the bonus-position mask) to be computed from an incorrect grammar state. Fix: advance forks using the known draft tokens for positions 0..num_spec_tokens-1. Target logits are conditioned on draft tokens, and rejection sampling discards positions after the first rejection, so the draft-token path is the only reachable one. The bonus position needs no advancement; the fork is discarded after the loop.
… loop blocking
- Extract _fill_guided_bitmask and _accept_guided_tokens as sync methods
- Wrap both with asyncio.to_thread to prevent CPU-bound xgrammar ops (fill_bitmap, accept_token) from blocking the asyncio event loop
- Move result.cpu() to the caller side (agent.py) instead of storing it as a member
- Keep _wait_stream_once intact (confirmed not the root cause)
Extract the fill_bitmap/accept_token loops in spec_agent.py into standalone helper functions and wrap them with asyncio.to_thread() to prevent CPU-bound xgrammar operations from blocking the asyncio event loop, which caused streaming token stuttering.
- _accept_spec_rejection_tokens: accept tokens on the original matchers
- _fill_spec_bitmask: fill the grammar bitmask for forked matchers
- _accept_spec_forked_tokens: advance forked matchers with draft tokens
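A minimal sketch of the pattern, assuming xgrammar-style matcher helpers; the real helpers in spec_agent.py take engine-specific arguments:

```python
import asyncio

def _fill_spec_bitmask(guided_manager, forked_matchers, bitmask):
    # CPU-bound xgrammar work: fill one bitmask row per forked matcher.
    for i, matcher in enumerate(forked_matchers):
        guided_manager.fill_bitmap(matcher, bitmask, i)

async def _guided_draft_step(guided_manager, forked_matchers, bitmask):
    # Run the sync helper off the event loop so token streaming stays smooth.
    await asyncio.to_thread(_fill_spec_bitmask, guided_manager,
                            forked_matchers, bitmask)
```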
Pull request overview
Copilot reviewed 25 out of 25 changed files in this pull request and generated 9 comments.
Comments suppressed due to low confidence (3)
tests/pytorch/spec_decode/test_guided_spec_integration.py:320
- This simulates the bonus token by reusing `draft_tokens[0]`, but after accepting all draft tokens that token is not guaranteed to be valid in the current grammar state. The test can therefore advance the matcher along an impossible path or fail for reasons unrelated to rejection-sampling state; generate the bonus from the post-draft grammar mask instead.

```python
# Accept bonus token (simulate)
bonus_token = draft_tokens[0]  # placeholder
guided_manager.accept_token(original, bonus_token)
```
tests/pytorch/spec_decode/test_guided_spec_integration.py:365
- The replacement token for a rejection at position 1 is chosen from the initial grammar state, not from the state after accepting `draft_tokens[0]`. That can select a token that is invalid at the actual rejection point, making this test validate an impossible rejection path.

```python
bm = guided_manager.allocate_batched_bitmap(1)
guided_manager.fill_bitmap(original, bm, 0)
allowed = _allowed_ids(bm)
# Find a valid token that differs from draft
replacement_token = None
```
tests/pytorch/spec_decode/test_guided_spec_integration.py:451
- This end-to-end simulation advances the target-verification grammar fork with target argmax tokens. The production implementation advances that fork with draft tokens because the target logits and bonus position are conditioned on the draft-token path, so this simulation can validate a different grammar path than production uses.

```python
target_fork = original_matcher.fork()
target_tokens_per_pos = []
for pos in range(num_expand):
    bm = guided_manager.allocate_batched_bitmap(1)
    guided_manager.fill_bitmap(target_fork, bm, 0)
```
After moving accept_token out of FusedLogitsProcessor.sampling(), the prefill path in _rejection_sampling() was missing the call to advance the grammar matcher state. This caused guided decoding constraints to be silently ignored after the first prefill step, producing malformed JSON and non-matching regex output.
…date docs
- Fix session_ctx being incorrectly treated as global in _expand_sampling_inputs and _slice_sampling_inputs. Only session_to_cleanup is global; session_ctx is per-batch and must be expanded/sliced alongside response_formats.
- Cache device-specific bitmask translation constants in Eagle3 via _get_bitmask_constants(), eliminating repeated .to(device) calls in _translate_bitmask. Pre-compute _n_draft_words at init time.
- Rewrite test_rollback_then_accept_rejection_output as test_fork_strategy_rejection_output: replace the rollback+double-accept logic with the production fork strategy (accept exactly the rejection-sampled tokens on the original matcher, no rollback needed).
- Add clarifying comments in test_guided_spec_integration.py noting that production code advances forks with draft tokens while the simulation uses argmax as a stand-in.
- Update EN/ZH docs with a note about vocab translation for Eagle3 (target-vocab bitmask translated to draft-vocab via scatter-add).
Motivation
Fixes #4551
When speculative decoding and guided decoding (JSON schema / regex / grammar) are both enabled, guided constraints are silently ignored: the `GuidedDecodingManager` is never propagated into the speculative decoding path. This is a silent correctness issue: no error, no warning, just unconstrained output.
Modification
Core change: propagate & apply grammar mask in spec decode
- `agent.py`: After `build_spec_agent()`, propagate `GuidedDecodingManager` to both `SpecModelAgent` and its `proposer`.
- `spec_agent.py`: Main integration:
  - `_async_model_forward`: Fork `GrammarMatcher`s for the draft model from the original guided processors; forked matchers are advanced in-place by `get_outputs()` at each draft step; the originals remain untouched.
  - `_rejection_sampling`: `_guided_spec_logits_process()`: forked matchers provide per-position bitmasks for all `num_spec_tokens + 1` target logits. After rejection sampling, accept the final output tokens on the original matchers to advance their state correctly. Pass `guided_decoding_manager` to `FusedLogitsProcessor` (the standard path already handles it).
  - `_guided_spec_logits_process()`: New method that (1) runs non-grammar logits processing (temperature, penalties), (2) applies per-position grammar bitmasks using forked matchers, advancing each fork with draft tokens (not argmax: target logits are conditioned on draft tokens, so the grammar state must follow the draft-token path), and (3) returns processed logits for rejection sampling.
- `deepseek_mtp.py`: Accept `guided_processors` in `get_outputs()`. Apply the grammar bitmask to draft logits before `argmax`, then `accept_token` on each forked matcher to advance its state for the next draft position.
- `base.py`: Add a `guided_decoding_manager` attribute to `BaseSpecProposer` (set by `SpecModelAgent` after construction). Add a `guided_processors` parameter to the `get_outputs()` signature.
- `eagle3.py`: Support guided decoding via draft-to-target bitmask translation. Since the Eagle3 draft vocabulary differs from the target vocabulary, a target-vocab grammar mask cannot be applied directly. Instead, `_translate_bitmask()` converts the target-vocab bitmask into a draft-vocab bitmask using the `draft_id_to_target_id` mapping, then applies it to draft logits. After `argmax` + token mapping, `accept_token` advances each forked matcher's state.
- `attention/__init__.py`, `configurations/utils.py`, `graph_runner.py`, `attention/fa3.py`: Fix speculative decoding on non-SM90 CUDA GPUs: extend the FA3 capability check from `== 9` (SM90 only) to `>= 8` (SM80+, Ampere and above) so that speculative decoding can use FA3's multi-token decode path. Add an early check in `CUDAGraphRunner.__init__` that raises a clear `RuntimeError` when speculative decoding is requested but FA3 is unavailable, instead of crashing in the Triton paged attention kernel.

Streaming performance fix: `asyncio.to_thread` for CPU-bound xgrammar ops

CPU-heavy xgrammar operations (`fill_bitmap`, `accept_token`) were blocking the asyncio event loop during guided decoding, causing tokens to be returned in stuttered batches rather than smoothly streamed. Fix: extract these loops into standalone sync helpers and wrap calls with `asyncio.to_thread()` so they run off the event loop.

- `logits_process.py`: Extract `_fill_guided_bitmask()` and `_accept_guided_tokens()` as sync helpers; wrap calls with `asyncio.to_thread()`.
- `agent.py`: Wrap the `_accept_guided_tokens` call with `asyncio.to_thread()`.
- `spec_agent.py`: Extract `_fill_spec_bitmask()`, `_accept_spec_forked_tokens()`, and `_accept_spec_rejection_tokens()` as sync helpers; wrap calls with `asyncio.to_thread()`. Batch GPU→CPU tensor syncs before loops to avoid per-iteration device-to-host synchronization stalls.

Helper changes
- `_expand_sampling_inputs` / `_slice_sampling_inputs`: Handle additional `SamplingInputs` fields (`response_formats`, `session_ctx`, etc.) so that guided-decoding-related inputs survive the expand/slice round-trip during rejection sampling (see the sketch below).
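A hedged sketch of the expand half, with an illustrative stand-in dataclass (the real `SamplingInputs` has many more fields and the real helper covers all of them):

```python
from dataclasses import dataclass, field, replace

import torch

@dataclass
class SamplingInputsSketch:
    """Illustrative stand-in for lmdeploy's SamplingInputs."""
    temperature: torch.Tensor  # per-batch tensor field
    response_formats: list     # per-batch, non-tensor
    session_ctx: list          # per-batch, non-tensor
    session_to_cleanup: list = field(default_factory=list)  # global, engine-wide

def expand_sampling_inputs(inputs: SamplingInputsSketch, n: int) -> SamplingInputsSketch:
    """Repeat every per-batch entry n (= num_spec_tokens + 1) times."""
    return replace(
        inputs,
        temperature=inputs.temperature.repeat_interleave(n, dim=0),
        # Per-batch non-tensor fields must be repeated too, or guided
        # decoding loses its response_format after expansion.
        response_formats=[f for f in inputs.response_formats for _ in range(n)],
        session_ctx=[c for c in inputs.session_ctx for _ in range(n)],
        # session_to_cleanup stays as-is: it is global, not per-batch.
    )
```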
Tests
- `test_guided_spec_decode.py`: Unit tests for `_expand_sampling_inputs` / `_slice_sampling_inputs` with guided fields, `_guided_spec_logits_process` bitmask application, and `accept_token` state advancement.
- `test_guided_spec_integration.py`: Integration tests (require xgrammar + GPU).
- `test_mtp_guided_decoding.py`: End-to-end pipeline tests (require xgrammar + GPU).
Docs
Updated `spec_decoding.md` (EN & ZH) with guided decoding usage notes; a usage sketch follows.
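A hedged usage sketch: `response_format` follows lmdeploy's existing guided-decoding API, while the model name and the way speculative decoding is enabled are placeholders; consult `docs/en/advance/spec_decoding.md` for the exact configuration:

```python
# Sketch only: the model below and the spec-decoding setup are placeholders.
from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

schema = {
    'type': 'object',
    'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}},
    'required': ['name', 'age'],
}

pipe = pipeline(
    'a-model-with-mtp-weights',            # placeholder: an MTP-capable model
    backend_config=PytorchEngineConfig(),  # plus the spec-decoding options from the docs
)
gen_config = GenerationConfig(
    response_format=dict(type='json_schema',
                         json_schema=dict(name='person', schema=schema)),
)
print(pipe(['Describe a person as JSON.'], gen_config=gen_config))
```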
BC-breaking (Optional)
None. The `guided_processors` parameter in `get_outputs()` defaults to `None`, so existing proposers that don't override it are unaffected.
Checklist
`_guided_spec_logits_process`, expand/slice with guided fields).