Commit 6d54336

Flechman and russellb authored
[Bugfix] Fix llguidance backend, rollback when EOS was encountered (#25905)
Signed-off-by: Rémi Delacourt <remi@mistral.ai>
Signed-off-by: remi <remi@mistral.ai>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
1 parent 34553b9 commit 6d54336

2 files changed: +126 -2 lines changed
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer

from vllm.config import StructuredOutputsConfig, VllmConfig
from vllm.config.model import ModelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import StructuredOutputOptions

TOKENIZER = "gpt2"


def test_backend_guidance_rollback_terminated():
    # Test that the guidance backend successfully rolls back from a
    # terminated state. This can happen with speculative decoding,
    # where the draft model proposes EOS and it is verified by the
    # guidance backend. In that case we are in a stopped state, but
    # it should be reverted in case EOS is not accepted by the target
    # model.
    vllm_config = VllmConfig(
        decoding_config=StructuredOutputsConfig(
            backend="guidance",
        )
    )
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    backend = GuidanceBackend(
        vllm_config,
        tokenizer=tokenizer,
        vocab_size=50257,
    )

    grammar = backend.compile_grammar(
        StructuredOutputOptions.JSON, '{"type": "object"}'
    )

    prompt = tokenizer.encode('{"a": "b"}')
    assert len(prompt) > 1
    dummy_wrong = tokenizer.encode('{"a"}')
    for token in prompt:
        assert grammar.accept_tokens("", [token])
    assert not grammar.is_terminated()
    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
    assert grammar.is_terminated()
    # Giving any other token should also be accepted
    assert grammar.accept_tokens("", dummy_wrong)
    # Rollback is done from where the state was terminated, so from '}' not EOS
    grammar.rollback(len(prompt) - 1)
    assert not grammar.is_terminated()
    assert grammar.validate_tokens([tokenizer.eos_token_id]) == []
    assert grammar.validate_tokens(dummy_wrong) != dummy_wrong
    assert grammar.accept_tokens("", prompt[1:])
    assert not grammar.is_terminated()
    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
    assert grammar.is_terminated()
    # Rollback of <= 0 should not change the terminated state
    grammar.rollback(0)
    assert grammar.is_terminated()
    grammar.rollback(-1)
    assert grammar.is_terminated()


def test_grammar_bitmask_with_specdec():
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    prompt = tokenizer.encode('{"a": "b"}')
    vllm_config = VllmConfig(
        model_config=ModelConfig(tokenizer=TOKENIZER),
        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
        speculative_config=SpeculativeConfig(model="[ngram]", num_speculative_tokens=3),
    )
    structured_output_manager = StructuredOutputManager(vllm_config)

    for i in range(1, 2):
        sampling_params = SamplingParams(
            structured_outputs=StructuredOutputsParams(
                json='{"type": "object"}',
            ),
        )
        sampling_params.structured_outputs._backend = "guidance"

        my_req_id = f"my_req_id_{i}"
        request = Request(
            my_req_id,
            prompt_token_ids=prompt[:i],
            sampling_params=sampling_params,
            pooling_params=None,
            eos_token_id=tokenizer.eos_token_id,
        )

        structured_output_manager.grammar_init(request)

        def grammar_bitmask(req: Request, tokens: list[int]) -> None:
            structured_output_manager.grammar_bitmask(
                requests={req.request_id: req},
                structured_output_request_ids={req.request_id: 0},
                scheduled_spec_decode_tokens={req.request_id: tokens},
            )
            # At this point, we rolled back, so we should not be terminated
            assert not req.structured_output_request.grammar.is_terminated()

        # The grammar might not yet be compiled, so we wait for it
        while not request.structured_output_request._check_grammar_completion():
            continue

        assert request.structured_output_request.grammar.accept_tokens(
            request.request_id, prompt[:i]
        )

        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
        grammar_bitmask(
            request, prompt[i:] + [tokenizer.eos_token_id] + prompt
        )  # EOS not the final token
        grammar_bitmask(request, prompt[i:])  # EOS not present
        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])

vllm/v1/structured_output/backend_guidance.py

Lines changed: 8 additions & 2 deletions
@@ -111,6 +111,7 @@ class GuidanceGrammar(StructuredOutputGrammar):
     vocab_size: int
     printed_error: bool = False
     terminated: bool = False
+    rollback_lag: int = 0

     def check_error(self):
         if not self.printed_error:
@@ -127,6 +128,8 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
         """

         if self.ll_tokenizer.eos_token in tokens:
+            if self.ll_matcher.is_stopped() and not self.terminated:
+                self.rollback_lag = 1
             self.terminated = True

         if self.ll_matcher.is_stopped():
@@ -163,8 +166,11 @@ def validate_tokens(self, tokens: list[int]) -> list[int]:
         return tokens[:num_tokens]

     def rollback(self, num_tokens: int) -> None:
-        self.ll_matcher.rollback(num_tokens)
-        self.check_error()
+        if num_tokens > 0:
+            self.ll_matcher.rollback(num_tokens - self.rollback_lag)
+            self.terminated = False
+            self.rollback_lag = 0
+            self.check_error()

     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         # this will automatically return [EOS] mask if the matcher is stopped
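
The fix covers the case the first test constructs: the underlying ll_matcher is already stopped when a speculated EOS shows up among the accepted tokens, so the matcher never consumes that EOS and only the grammar's terminated flag changes. If the target model then rejects the EOS, the later rollback counts that EOS as one of the tokens to undo, so the matcher has to be rolled back one token fewer and terminated has to be cleared. The toy sketch below mimics only that bookkeeping; ToyMatcher, its two-token "grammar", and the EOS value are assumptions made purely for illustration and are not vLLM or llguidance APIs.

# Toy sketch of the rollback_lag bookkeeping above. ToyMatcher is NOT
# llguidance's LLMatcher; its behaviour (stopping once a two-token "grammar"
# is complete and ignoring everything afterwards) is assumed for illustration.
EOS = 99


class ToyMatcher:
    def __init__(self) -> None:
        self.consumed: list[int] = []

    def is_stopped(self) -> bool:
        # The toy grammar is complete after two tokens.
        return len(self.consumed) >= 2

    def consume_tokens(self, tokens: list[int]) -> None:
        for t in tokens:
            if self.is_stopped():
                return  # a stopped matcher consumes nothing more, not even EOS
            self.consumed.append(t)

    def rollback(self, num_tokens: int) -> None:
        if num_tokens > 0:
            del self.consumed[-num_tokens:]


class ToyGrammar:
    """Mirrors the terminated / rollback_lag handling from the diff above."""

    def __init__(self) -> None:
        self.matcher = ToyMatcher()
        self.terminated = False
        self.rollback_lag = 0

    def accept_tokens(self, tokens: list[int]) -> None:
        if EOS in tokens:
            # EOS arrived while the matcher was already stopped: it will never
            # be consumed, so remember a one-token lag for a later rollback.
            if self.matcher.is_stopped() and not self.terminated:
                self.rollback_lag = 1
            self.terminated = True
        self.matcher.consume_tokens(tokens)

    def rollback(self, num_tokens: int) -> None:
        if num_tokens > 0:
            self.matcher.rollback(num_tokens - self.rollback_lag)
            self.terminated = False
            self.rollback_lag = 0


g = ToyGrammar()
g.accept_tokens([1, 2])    # grammar completes; the toy matcher is now stopped
g.accept_tokens([EOS])     # speculated EOS: terminated set, matcher unchanged
assert g.terminated and g.rollback_lag == 1
g.rollback(1)              # target model rejected EOS: undo that logical token
assert not g.terminated
assert g.matcher.consumed == [1, 2]  # matcher.rollback(1 - 1) was a no-op

In this sketch, the old behaviour (an unconditional ll_matcher.rollback(num_tokens) with no reset of terminated) would have popped a real token in place of the never-consumed EOS and left the grammar terminated, which is what the tests above guard against.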

0 commit comments

Comments
 (0)