
Commit a289cc1

[Test] Batch Invariant: Rename and organize tests (#27421)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 95ae50b commit a289cc1

5 files changed: +248 additions, −80 deletions

tests/v1/determinism/conftest.py

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
+    """Automatically enable batch invariant kernel overrides for all tests."""
+    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
+    yield
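Because the fixture is autouse, every test collected under tests/v1/determinism/ now runs with VLLM_BATCH_INVARIANT=1 and needs no per-test setup. A minimal sketch (hypothetical test, not part of this commit) of what tests in this directory can rely on:

    # Hypothetical example: any test placed under tests/v1/determinism/
    # sees the environment variable set by the autouse fixture above.
    import os

    def test_batch_invariant_env_is_set():
        assert os.environ.get("VLLM_BATCH_INVARIANT") == "1"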

tests/v1/generation/test_batch_invariance.py renamed to tests/v1/determinism/test_batch_invariance.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,66 +6,9 @@
66

77
import pytest
88
import torch
9+
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
910

1011
from vllm import LLM, SamplingParams
11-
from vllm.platforms import current_platform
12-
13-
skip_unsupported = pytest.mark.skipif(
14-
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
15-
reason="Requires CUDA and >= Hopper (SM90)",
16-
)
17-
18-
19-
@pytest.fixture(autouse=True)
20-
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
21-
"""Automatically enable batch invariant kernel overrides for all tests."""
22-
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
23-
yield
24-
25-
26-
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
27-
# Generate more realistic prompts that will actually produce varied tokens
28-
# Use a mix of common English text patterns
29-
30-
prompt_templates = [
31-
# Question-answer style
32-
"Question: What is the capital of France?\nAnswer: The capital of France is",
33-
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
34-
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
35-
# Story/narrative style
36-
"Once upon a time in a distant galaxy, there lived",
37-
"The old man walked slowly down the street, remembering",
38-
"In the year 2157, humanity finally discovered",
39-
# Technical/code style
40-
"To implement a binary search tree in Python, first we need to",
41-
"The algorithm works by iterating through the array and",
42-
"Here's how to optimize database queries using indexing:",
43-
# Factual/informative style
44-
"The Renaissance was a period in European history that",
45-
"Climate change is caused by several factors including",
46-
"The human brain contains approximately 86 billion neurons which",
47-
# Conversational style
48-
"I've been thinking about getting a new laptop because",
49-
"Yesterday I went to the store and bought",
50-
"My favorite thing about summer is definitely",
51-
]
52-
53-
# Pick a random template
54-
base_prompt = random.choice(prompt_templates)
55-
56-
if max_words < min_words:
57-
max_words = min_words
58-
target_words = random.randint(min_words, max_words)
59-
60-
if target_words > 50:
61-
# For longer prompts, repeat context
62-
padding_text = (
63-
" This is an interesting topic that deserves more explanation. "
64-
* (target_words // 50)
65-
)
66-
base_prompt = base_prompt + padding_text
67-
68-
return base_prompt
6912

7013

7114
@skip_unsupported
@@ -204,22 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
204147
llm_bsN.shutdown()
205148

206149

207-
def _extract_step_logprobs(request_output):
208-
if getattr(request_output, "outputs", None):
209-
inner = request_output.outputs[0]
210-
if hasattr(inner, "logprobs") and inner.logprobs is not None:
211-
t = torch.tensor(
212-
[
213-
inner.logprobs[i][tid].logprob
214-
for i, tid in enumerate(inner.token_ids)
215-
],
216-
dtype=torch.float32,
217-
)
218-
return t, inner.token_ids
219-
220-
return None, None
221-
222-
223150
@skip_unsupported
224151
@pytest.mark.parametrize(
225152
"backend",
Lines changed: 161 additions & 0 deletions

@@ -0,0 +1,161 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+HTTP-based batch invariance test: send requests to a running
+vLLM server and compare BS=1 vs BS=N results (tokens and per-step logprobs).
+
+Environment variables:
+- VLLM_TEST_MODEL: served model name (e.g., Qwen/Qwen3-1.7B / DeepSeek-R1)
+- VLLM_TP_SIZE: tensor parallelism size (e.g., 4)
+
+"""
+
+import os
+import random
+import sys
+from typing import Any
+
+import openai
+from utils import _random_prompt, skip_unsupported
+
+from tests.utils import RemoteOpenAIServer
+
+
+def _request_completion(
+    client: openai.OpenAI,
+    model: str,
+    prompt: Any,
+    sp: dict[str, Any],
+    max_retries: int = 3,
+    retry_backoff: float = 0.5,
+) -> dict[str, Any] | None:
+    payload: dict[str, Any] = {"model": model, "prompt": prompt}
+    payload.update(sp)
+
+    for attempt in range(max_retries + 1):
+        try:
+            completion = client.completions.create(**payload)
+            # Convert to plain dict so downstream logic can keep using
+            # dict-style access just like with raw HTTP JSON.
+            return completion.model_dump()
+        except Exception as e:  # pragma: no cover
+            if attempt < max_retries:
+                import time as _t
+
+                _t.sleep(retry_backoff * (2**attempt))
+                continue
+            sys.stderr.write(f"Error: {e}\n")
+            return None
+    return None
+
+
+def _extract_tokens_and_logprobs(
+    choice: dict[str, Any],
+) -> tuple[list[Any], list[float] | None]:
+    tokens: list[Any] = []
+    token_logprobs: list[float] | None = None
+    lp = choice.get("logprobs")
+    if lp and isinstance(lp, dict):
+        tokens = lp.get("token_ids") or lp.get("tokens") or []
+        token_logprobs = lp.get("token_logprobs", None)
+    return tokens, token_logprobs
+
+
+def _compare_bs1_vs_bsn_single_process(
+    prompts: list[str],
+    sp_kwargs: dict[str, Any],
+    client: openai.OpenAI,
+    model_name: str,
+) -> None:
+    # BS=1
+    bs1_tokens_per_prompt: list[list[Any]] = []
+    bs1_logprobs_per_prompt: list[list[float] | None] = []
+    for p in prompts:
+        resp = _request_completion(client, model_name, p, sp_kwargs)
+        if resp is None or not resp.get("choices"):
+            raise AssertionError("BS=1 empty/failed response")
+        choice = resp["choices"][0]
+        toks, lps = _extract_tokens_and_logprobs(choice)
+        if lps is None:
+            raise AssertionError(
+                "logprobs not returned; ensure server supports 'logprobs'"
+            )
+        bs1_tokens_per_prompt.append(list(toks))
+        bs1_logprobs_per_prompt.append(list(lps))
+
+    # BS=N
+    bsN_tokens_per_prompt: list[list[Any]] = [None] * len(prompts)  # type: ignore[list-item]
+    bsN_logprobs_per_prompt: list[list[float] | None] = [None] * len(prompts)
+    resp = _request_completion(client, model_name, prompts, sp_kwargs)
+    if resp is None or not resp.get("choices"):
+        raise AssertionError("BS=N empty/failed batched response")
+    choices = resp.get("choices", [])
+    if len(choices) != len(prompts):
+        raise AssertionError(
+            f"BS=N choices length {len(choices)} != num prompts {len(prompts)}"
+        )
+    for idx, choice in enumerate(choices):
+        toks, lps = _extract_tokens_and_logprobs(choice)
+        if lps is None:
+            raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
+        bsN_tokens_per_prompt[idx] = list(toks)
+        bsN_logprobs_per_prompt[idx] = list(lps)
+
+    # compare
+    for i, (tokens_bs1, tokens_bsN, logprobs_bs1, logprobs_bsN) in enumerate(
+        zip(
+            bs1_tokens_per_prompt,
+            bsN_tokens_per_prompt,
+            bs1_logprobs_per_prompt,
+            bsN_logprobs_per_prompt,
+        )
+    ):
+        if tokens_bs1 != tokens_bsN:
+            raise AssertionError(
+                f"Prompt {i} (sampling): Different tokens sampled. "
+                f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
+            )
+        if logprobs_bs1 is None or logprobs_bsN is None:
+            raise AssertionError(f"Prompt {i}: Missing logprobs in one of the runs")
+        if len(logprobs_bs1) != len(logprobs_bsN):
+            raise AssertionError(
+                f"Prompt {i}: Different number of steps: "
+                f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)."
+            )
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
+            if a != b:
+                diff = abs(a - b)
+                raise AssertionError(
+                    f"Prompt {i} Step {t}: Bitwise mismatch "
+                    f"(abs diff={diff:.6e}). "
+                    f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
+                )
+
+
+@skip_unsupported
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
+    random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
+    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    prompts_all = [_random_prompt(10, 50) for _ in range(32)]
+
+    sp_kwargs: dict[str, Any] = {
+        "temperature": 0.6,
+        "top_p": 1.0,
+        "max_tokens": 8,
+        "seed": 42,
+        "logprobs": 5,
+    }
+
+    tp_size = os.getenv("VLLM_TP_SIZE", "1")
+    server_args: list[str] = []
+    if tp_size:
+        server_args += ["-tp", tp_size]
+
+    with RemoteOpenAIServer(model_name, server_args) as server:
+        client = server.get_client()
+        _compare_bs1_vs_bsn_single_process(
+            prompts=prompts_all,
+            sp_kwargs=sp_kwargs,
+            client=client,
+            model_name=model_name,
+        )
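To make the response parsing above concrete, here is a small synthetic sketch (not from the commit) of the choice dict that _extract_tokens_and_logprobs expects from the OpenAI-style completions endpoint; the token ids and logprob values are invented for illustration:

    # Synthetic illustration of the "choice" dict shape parsed above. The
    # field names mirror the OpenAI completions logprobs schema the helper
    # reads; the numeric values are made up.
    choice = {
        "text": " Paris",
        "logprobs": {
            "tokens": [" Paris", "."],
            "token_ids": [12095, 13],
            "token_logprobs": [-0.03125, -1.25],
        },
    }

    lp = choice.get("logprobs")
    tokens = lp.get("token_ids") or lp.get("tokens") or []
    token_logprobs = lp.get("token_logprobs")
    # With VLLM_BATCH_INVARIANT=1, these per-step values are expected to be
    # bitwise identical between the BS=1 and BS=N requests.
    assert tokens == [12095, 13]
    assert token_logprobs == [-0.03125, -1.25]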

tests/v1/generation/test_rms_norm_batch_invariant.py renamed to tests/v1/determinism/test_rms_norm_batch_invariant.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,10 @@
99

1010
import pytest
1111
import torch
12+
from utils import skip_unsupported
1213

1314
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
1415
from vllm.model_executor.layers.layernorm import RMSNorm
15-
from vllm.platforms import current_platform
16-
17-
skip_unsupported = pytest.mark.skipif(
18-
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
19-
reason="Requires CUDA and >= Hopper (SM90)",
20-
)
2116

2217

2318
@skip_unsupported
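These test files now share one definition of skip_unsupported from the utils module below instead of each repeating the skipif marker. A minimal hedged sketch of applying it (the test name here is illustrative, not from the commit):

    # Illustrative only: the shared marker skips a test unless it runs on
    # CUDA with device capability >= 9.0 (Hopper), as defined in utils.py.
    from utils import skip_unsupported


    @skip_unsupported
    def test_runs_only_on_hopper_or_newer():
        ...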

tests/v1/determinism/utils.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import random
+
+import pytest
+import torch
+
+from vllm.platforms import current_platform
+
+skip_unsupported = pytest.mark.skipif(
+    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
+    reason="Requires CUDA and >= Hopper (SM90)",
+)
+
+
+def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
+    # Generate more realistic prompts that will actually produce varied tokens
+    # Use a mix of common English text patterns
+
+    prompt_templates = [
+        # Question-answer style
+        "Question: What is the capital of France?\nAnswer: The capital of France is",
+        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
+        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
+        # Story/narrative style
+        "Once upon a time in a distant galaxy, there lived",
+        "The old man walked slowly down the street, remembering",
+        "In the year 2157, humanity finally discovered",
+        # Technical/code style
+        "To implement a binary search tree in Python, first we need to",
+        "The algorithm works by iterating through the array and",
+        "Here's how to optimize database queries using indexing:",
+        # Factual/informative style
+        "The Renaissance was a period in European history that",
+        "Climate change is caused by several factors including",
+        "The human brain contains approximately 86 billion neurons which",
+        # Conversational style
+        "I've been thinking about getting a new laptop because",
+        "Yesterday I went to the store and bought",
+        "My favorite thing about summer is definitely",
+    ]
+
+    # Pick a random template
+    base_prompt = random.choice(prompt_templates)
+
+    if max_words < min_words:
+        max_words = min_words
+    target_words = random.randint(min_words, max_words)
+
+    if target_words > 50:
+        # For longer prompts, repeat context
+        padding_text = (
+            " This is an interesting topic that deserves more explanation. "
+            * (target_words // 50)
+        )
+        base_prompt = base_prompt + padding_text
+
+    return base_prompt
+
+
+def _extract_step_logprobs(request_output):
+    if getattr(request_output, "outputs", None):
+        inner = request_output.outputs[0]
+        if hasattr(inner, "logprobs") and inner.logprobs is not None:
+            t = torch.tensor(
+                [
+                    inner.logprobs[i][tid].logprob
+                    for i, tid in enumerate(inner.token_ids)
+                ],
+                dtype=torch.float32,
+            )
+            return t, inner.token_ids
+
+    return None, None
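A hedged sketch (not part of the commit) of how these shared helpers fit together in the in-process determinism tests: generate the same prompts one at a time and as one batch, extract per-step logprobs, and require a bitwise match. The model name and sampling values are placeholders; the real tests go further, with needle prompts, separate engine configurations, and the HTTP server variant above.

    # Illustrative only: mirrors how the shared helpers are used by the
    # in-process determinism tests (conftest.py sets VLLM_BATCH_INVARIANT=1
    # when this runs under pytest in this directory).
    import torch

    from utils import _extract_step_logprobs, _random_prompt
    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen3-1.7B")  # assumption: any supported model works
    sp = SamplingParams(temperature=0.6, max_tokens=8, seed=42, logprobs=5)
    prompts = [_random_prompt(10, 50) for _ in range(4)]

    # BS=1: submit prompts one at a time.
    bs1_outputs = [llm.generate([p], sp)[0] for p in prompts]
    # BS=N: submit the same prompts as one batch.
    bsN_outputs = llm.generate(prompts, sp)

    for out_a, out_b in zip(bs1_outputs, bsN_outputs):
        logprobs_a, tokens_a = _extract_step_logprobs(out_a)
        logprobs_b, tokens_b = _extract_step_logprobs(out_b)
        # With the batch-invariant kernels enabled, sampled tokens and
        # per-step logprobs should match bitwise across batch sizes.
        assert list(tokens_a) == list(tokens_b)
        assert torch.equal(logprobs_a, logprobs_b)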
