# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer

from vllm.config import StructuredOutputsConfig, VllmConfig
from vllm.config.model import ModelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import StructuredOutputOptions

TOKENIZER = "gpt2"


def test_backend_guidance_rollback_terminated():
    # Test that the guidance backend successfully rolls back from a
    # terminated state. This can happen with speculative decoding,
    # where the draft model proposes EOS and it is verified by the
    # guidance backend. In that case we are in a stopped state, but
    # it should be reverted if EOS is not accepted by the target
    # model.
    vllm_config = VllmConfig(
        structured_outputs_config=StructuredOutputsConfig(
            backend="guidance",
        )
    )
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    backend = GuidanceBackend(
        vllm_config,
        tokenizer=tokenizer,
        vocab_size=50257,
    )

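    # Compile a grammar that accepts any JSON object.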
    grammar = backend.compile_grammar(
        StructuredOutputOptions.JSON, '{"type": "object"}'
    )

    prompt = tokenizer.encode('{"a": "b"}')
    assert len(prompt) > 1
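    # Tokens for '{"a"}', which is not valid JSON under the grammar.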
    dummy_wrong = tokenizer.encode('{"a"}')
    for token in prompt:
        assert grammar.accept_tokens("", [token])
    assert not grammar.is_terminated()
    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
    assert grammar.is_terminated()
    # Once terminated, any further tokens (even invalid ones) are accepted
    assert grammar.accept_tokens("", dummy_wrong)
    # Rollback counts from where the state was terminated, so from '}' not EOS
    grammar.rollback(len(prompt) - 1)
    assert not grammar.is_terminated()
    assert grammar.validate_tokens([tokenizer.eos_token_id]) == []
    assert grammar.validate_tokens(dummy_wrong) != dummy_wrong
    assert grammar.accept_tokens("", prompt[1:])
    assert not grammar.is_terminated()
    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
    assert grammar.is_terminated()
    # A rollback of <= 0 tokens should not change the terminated state
    grammar.rollback(0)
    assert grammar.is_terminated()
    grammar.rollback(-1)
    assert grammar.is_terminated()


def test_grammar_bitmask_with_specdec():
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    prompt = tokenizer.encode('{"a": "b"}')
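    # ngram speculative decoding proposes up to 3 draft tokens per step,
    # which exercises the bitmask path for speculated tokens.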
    vllm_config = VllmConfig(
        model_config=ModelConfig(tokenizer=TOKENIZER),
        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
        speculative_config=SpeculativeConfig(model="[ngram]", num_speculative_tokens=3),
    )
    structured_output_manager = StructuredOutputManager(vllm_config)

    for i in range(1, 2):
        sampling_params = SamplingParams(
            structured_outputs=StructuredOutputsParams(
                json='{"type": "object"}',
            ),
        )
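        # The backend is normally resolved by the engine; set the private
        # field directly since no engine is running in this test.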
        sampling_params.structured_outputs._backend = "guidance"

        my_req_id = f"my_req_id_{i}"
        request = Request(
            my_req_id,
            prompt_token_ids=prompt[:i],
            sampling_params=sampling_params,
            pooling_params=None,
            eos_token_id=tokenizer.eos_token_id,
        )

        structured_output_manager.grammar_init(request)

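        # Helper: build the grammar bitmask for a batch of scheduled draft
        # tokens. The manager advances the grammar through the drafts and
        # rolls back afterwards, so the grammar must never be left terminated.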
        def grammar_bitmask(req: Request, tokens: list[int]) -> None:
            structured_output_manager.grammar_bitmask(
                requests={req.request_id: req},
                structured_output_request_ids={req.request_id: 0},
                scheduled_spec_decode_tokens={req.request_id: tokens},
            )
            # At this point we have rolled back, so it must not be terminated
            assert not req.structured_output_request.grammar.is_terminated()

        # The grammar might not yet be compiled, so we wait for it
        while not request.structured_output_request._check_grammar_completion():
            continue

        assert request.structured_output_request.grammar.accept_tokens(
            request.request_id, prompt[:i]
        )

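        # Exercise drafts that end exactly at EOS, overshoot past EOS, and
        # stop short of EOS; each call must leave the grammar active.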
        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
        grammar_bitmask(
            request, prompt[i:] + [tokenizer.eos_token_id] + prompt
        )  # EOS not the final token
        grammar_bitmask(request, prompt[i:])  # EOS not present
        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
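
# To run these tests locally (assuming pytest is installed and the gpt2
# tokenizer can be downloaded from the Hugging Face Hub; the path below is
# illustrative):
#   pytest tests/v1/structured_output/test_backend_guidance.py -v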