
Commit ae06584

[Feature] Code implementation of Async Scheduler (#924)
Signed-off-by: cychiuak <andersonchiu@google.com>
1 parent 78131cf commit ae06584

File tree

3 files changed: +479 −5 lines


tests/e2e/test_async_scheduler.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
from __future__ import annotations

import random
import string
import time

import pytest
from vllm import LLM, SamplingParams


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0,
                          max_tokens=120,
                          ignore_eos=True,
                          repetition_penalty=1,
                          frequency_penalty=0,
                          presence_penalty=0,
                          min_p=0,
                          logprobs=None)


@pytest.fixture
def model_name():
    return "Qwen/Qwen2.5-1.5B-Instruct"


def get_test_prompts():
    """
    Generates a list of prompts with a specific word count.

    Uses num_prompts = 500 and input_len_words = 120, so it returns a
    list of 500 strings, each containing 120 repeated words.
    """
    num_prompts = 500
    input_len_words = 120
    prompts = []

    # For example, if w == 's', the generated prompt will be
    # "Keep repeating: s s s ..."
    num_repetitions = input_len_words
    prefix = "Keep repeating: "

    for _ in range(num_prompts):
        # 1. Pick a random lowercase letter.
        w = random.choice(string.ascii_lowercase)

        # 2. Create the string of repeated words
        #    (num_repetitions words in total).
        repeating_part = " ".join([w] * num_repetitions)

        # 3. Combine with the prefix.
        prompts.append(f"{prefix}{repeating_part}")

    return prompts


def _test_performance_helper(monkeypatch: pytest.MonkeyPatch,
                             sampling_config: SamplingParams, model_name: str,
                             min_speedup: float):
    '''
    Helper function to test async scheduler decoding performance.
    Compares timing between the reference LLM and the async LLM using
    Qwen2.5-1.5B.
    '''

    with monkeypatch.context():
        test_prompts = get_test_prompts()  # num_prompts=500, input_len=120

        # Time the reference LLM (synchronous scheduling).
        ref_llm = LLM(model=model_name,
                      max_model_len=800,
                      max_num_seqs=24,
                      max_num_batched_tokens=512,
                      enable_prefix_caching=False,
                      async_scheduling=0)

        start_time = time.time()
        _ = ref_llm.generate(test_prompts, sampling_config)
        ref_time = time.time() - start_time

        del ref_llm
        # Wait for TPUs to be released.
        time.sleep(10)

        # Time the async LLM (async scheduling enabled).
        async_llm = LLM(model=model_name,
                        max_model_len=800,
                        max_num_seqs=24,
                        max_num_batched_tokens=512,
                        enable_prefix_caching=False,
                        async_scheduling=1)

        start_time = time.time()
        _ = async_llm.generate(test_prompts, sampling_config)
        async_time = time.time() - start_time

        del async_llm
        # Wait for TPUs to be released.
        time.sleep(10)

        speedup = ref_time / async_time
        print(f"Reference LLM time: {ref_time:.2f}s")
        print(f"Async LLM time: {async_time:.2f}s")
        print(f"Speedup: {speedup:.2f}x")

        assert speedup >= min_speedup, (
            f"Expected at least {min_speedup}x speedup for async scheduler, "
            f"got {speedup:.2f}x")


def test_performance(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Test that async scheduler decoding provides a significant performance
    improvement. Compares timing between the reference LLM and the async
    LLM using Qwen2.5-1.5B, and expects async_llm to be at least 1.3x
    faster than ref_llm.
    '''
    min_speed_up = 1.3
    _test_performance_helper(monkeypatch, sampling_config, model_name,
                             min_speed_up)


def _test_correctness_helper(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Helper function to test async scheduler correctness.
    The outputs of the original LLM and the async LLM should be identical
    when async scheduler decoding is used.

    Known Edge Case (KV Cache Swapping):
    In that case, even though the temperature is set to 0, the output is
    still slightly different every time. This is expected behaviour, as
    the normal scheduler behaves the same way, and hence it is difficult
    to design a test for such a scenario.
    '''
    with monkeypatch.context():
        test_prompts = get_test_prompts()

        ref_llm = LLM(model=model_name,
                      max_model_len=1024,
                      max_num_seqs=100,
                      async_scheduling=0)
        ref_outputs = ref_llm.generate(test_prompts, sampling_config)

        del ref_llm

        # Wait for TPUs to be released.
        time.sleep(10)

        async_llm = LLM(model=model_name,
                        max_model_len=1024,
                        max_num_seqs=100,
                        async_scheduling=1)
        async_outputs = async_llm.generate(test_prompts, sampling_config)

        matches = 0
        misses = 0
        for ref_output, async_output in zip(ref_outputs, async_outputs):
            if ref_output.outputs[0].text == async_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                # Log the mismatched pair for debugging.
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"async_output: {async_output.outputs[0].text}")

        assert misses == 0
        del async_outputs

        # Wait for TPUs to be released.
        time.sleep(10)


def test_async_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    The outputs of the original LLM and the async LLM should be the same
    when using the async scheduler.
    '''
    _test_correctness_helper(monkeypatch, sampling_config, model_name)
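
For reference, this e2e suite targets a TPU host with vLLM and the tpu_inference plugin installed; an invocation along the following lines (illustrative, not part of this commit) would exercise both tests with their print output visible:

    pytest tests/e2e/test_async_scheduler.py -s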

tpu_inference/runner/compilation_manager.py

Lines changed: 37 additions & 0 deletions
@@ -75,6 +75,8 @@ def capture_model(self) -> None:
         self._precompile_backbone_text_only()
         if self.runner.is_multimodal_model:
             self._precompile_backbone_with_inputs_embeds()
+        if self.runner.scheduler_config.async_scheduling:
+            self._precompile_substitute_placeholder_token()
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()
@@ -148,6 +150,41 @@ def model_fn_wrapper(
             num_tokens=num_tokens,
         )
 
+    def _precompile_substitute_placeholder_token(self) -> None:
+        """Precompiles the token substitution function for all expected
+        input shapes.
+
+        It iterates through all potential padded token lengths
+        (`num_tokens_paddings`) and request batch sizes
+        (`num_reqs_paddings`) that the scheduler is expected to handle,
+        ensuring a compiled version is ready for each combination.
+        """
+        for num_tokens in self.runner.num_tokens_paddings:
+            padded_token_in_tpu_cur_input_indices = np.zeros((num_tokens, ),
+                                                             dtype=np.int32)
+            padded_token_in_tpu_pre_next_tokens_indices = np.zeros(
+                (num_tokens, ), dtype=np.int32)
+            for num_reqs in self.runner.num_reqs_paddings:
+                input_ids = self._create_dummy_tensor((num_tokens, ),
+                                                      jnp.int32)
+                # Needs to align with the sampling output.
+                next_tokens = self._create_dummy_tensor(
+                    (num_reqs, ),
+                    jnp.int32,
+                    sharding=NamedSharding(self.runner.mesh, PartitionSpec()))
+                placeholder_num = 1
+                self._run_compilation(
+                    "_substitute_placeholder_token_fn",
+                    self.runner._substitute_placeholder_token_fn,
+                    input_ids,
+                    padded_token_in_tpu_cur_input_indices,
+                    padded_token_in_tpu_pre_next_tokens_indices,
+                    next_tokens,
+                    placeholder_num,
+                    num_tokens=num_tokens,
+                    num_reqs=num_reqs,
+                )
+
     def _precompile_backbone_text_only(self) -> None:
         for num_tokens in self.runner.num_tokens_paddings:
             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)