Commit 58c99a0

[Bug fix] Fix log probabilities handling (#1114)
1 parent 4a5f9b8 commit 58c99a0

4 files changed: +81, -68 lines


tests/e2e/test_data_parallel.py

Lines changed: 16 additions & 18 deletions
@@ -9,12 +9,6 @@
 from vllm import LLM, EngineArgs, SamplingParams
 
 
-@pytest.fixture
-def model_name():
-    """Small model for faster testing."""
-    return "Qwen/Qwen2.5-1.5B-Instruct"
-
-
 @pytest.fixture(autouse=True)
 def setup_new_model_design():
     """Automatically set NEW_MODEL_DESIGN=True for all tests."""
@@ -56,21 +50,23 @@ def _run_inference_with_config(model_name: str,
                                data_parallel_size: int = 1,
                                additional_config: dict = {},
                                kv_cache_dtype: str = "auto",
-                               enable_prefix_caching: bool = False) -> list:
+                               enable_prefix_caching: bool = False,
+                               async_scheduling: bool = False) -> list:
     """Helper function to run inference with specified configuration."""
 
     # Create LLM args using parser-based approach similar to offline_inference.py
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=128,
+        max_model_len=32,
         tensor_parallel_size=tensor_parallel_size,
         data_parallel_size=data_parallel_size,
-        gpu_memory_utilization=0.95,
+        gpu_memory_utilization=0.98,
         max_num_batched_tokens=128,
         max_num_seqs=16,
         enable_prefix_caching=enable_prefix_caching,
         additional_config=additional_config,
         kv_cache_dtype=kv_cache_dtype,
+        async_scheduling=async_scheduling,
     )
 
     engine_args_dict = asdict(engine_args)
@@ -86,7 +82,6 @@ def _run_inference_with_config(model_name: str,
 
 
 def test_model_data_parallelism(
-    model_name: str,
     test_prompts: list,
     sampling_params: SamplingParams,
 ):
@@ -98,9 +93,12 @@ def test_model_data_parallelism(
     Equivalent to:
     python examples/offline_inference.py --tensor_parallel_size=4 --data_parallel_size=2
     """
+    # Use Llama 1B for this test
+    test_model = "meta-llama/Llama-3.2-1B-Instruct"
+
     # Test with data parallelism enabled
     outputs = _run_inference_with_config(
-        model_name=model_name,
+        model_name=test_model,
         test_prompts=test_prompts,
         sampling_params=sampling_params,
         tensor_parallel_size=1,
@@ -119,7 +117,6 @@ def test_model_data_parallelism(
 
 
 def test_attention_data_parallelism(
-    model_name: str,
     test_prompts: list,
     sampling_params: SamplingParams,
 ):
@@ -132,6 +129,9 @@ def test_attention_data_parallelism(
     python examples/offline_inference.py --tensor_parallel_size=8 --kv-cache-dtype=fp8 \
     --additional_config='{"sharding":{"sharding_strategy": {"enable_dp_attention":1}}}'
     """
+    # Use Qwen3 0.6B for this test
+    test_model = "Qwen/Qwen3-0.6B"
+
     additional_config = {
         "sharding": {
             "sharding_strategy": {
@@ -142,7 +142,7 @@ def test_attention_data_parallelism(
 
     # Test with attention data parallelism enabled
     outputs = _run_inference_with_config(
-        model_name=model_name,
+        model_name=test_model,
         test_prompts=test_prompts,
         sampling_params=sampling_params,
         tensor_parallel_size=8,
@@ -165,7 +165,6 @@ def test_attention_data_parallelism(
 
 
 def test_data_parallelism_correctness(
-    model_name: str,
     test_prompts: list,
     sampling_params: SamplingParams,
 ):
@@ -176,7 +175,7 @@ def test_data_parallelism_correctness(
     """
     os.environ['SKIP_JAX_PRECOMPILE'] = '1'
     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
-
+    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    # Use a smaller subset of prompts for correctness testing
     small_prompts = test_prompts[:10]
 
@@ -187,6 +186,7 @@ def test_data_parallelism_correctness(
         sampling_params=sampling_params,
         tensor_parallel_size=1,
         data_parallel_size=1,
+        async_scheduling=True,
     )
 
     # Run with model data parallelism and async scheduling
@@ -196,9 +196,7 @@ def test_data_parallelism_correctness(
         sampling_params=sampling_params,
         tensor_parallel_size=1,
         data_parallel_size=2,
-        additional_config={"scheduler_config": {
-            "async_scheduling": True
-        }},
+        async_scheduling=True,
     )
 
     # Compare outputs - they should be identical for greedy sampling
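Note: the correctness test now enables async scheduling through the first-class `async_scheduling` engine argument instead of tunneling it through `additional_config`. A minimal sketch of the equivalent offline invocation, not part of this commit; the `LLM(**asdict(...))` construction is assumed from the helper's `asdict` usage:

```python
# Sketch only: mirrors what _run_inference_with_config now builds.
from dataclasses import asdict

from vllm import LLM, EngineArgs, SamplingParams

engine_args = EngineArgs(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # illustrative model choice
    max_model_len=32,
    data_parallel_size=2,
    async_scheduling=True,  # replaces additional_config={"scheduler_config": {...}}
)

llm = LLM(**asdict(engine_args))  # assumed construction; the test goes through its own helper
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
```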

tests/runner/test_tpu_runner_dp.py

Lines changed: 5 additions & 6 deletions
@@ -102,8 +102,8 @@ def test_prepare_inputs_dp_basic_functionality(self,
         result = self.runner._prepare_inputs_dp(scheduler_output)
 
         # Basic assertions
-        assert len(result) == 6
-        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector = result
+        assert len(result) == 7
+        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
 
         # Verify utility functions were called
         mock_runner_utils.get_padded_token_len.assert_called()
@@ -380,8 +380,7 @@ def mock_get_padded_token_len(paddings_list, val):
 
         # Execute the method
         result = self.runner._prepare_inputs_dp(scheduler_output)
-        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector = result
-
+        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
         # 1. Verify input_ids content
         expected_input_ids = np.zeros(16, dtype=np.int32)
         expected_input_ids[:2] = [1006, 1007]
@@ -495,7 +494,7 @@ def mock_get_padded_token_len(paddings_list, val):
 
         # Execute the method
         result = self.runner._prepare_inputs_dp(scheduler_output)
-        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector = result
+        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
 
         # 1. Verify input_ids
         expected_input_ids = np.zeros(16, dtype=np.int32)
@@ -724,7 +723,7 @@ def test_prepare_inputs_routing_to_non_dp(self):
 
         self.runner.dp_size = 1
         self.runner._prepare_inputs_non_dp = MagicMock(
-            return_value=(None, None, None, None, None, None))
+            return_value=(None, None, None, None, None, None, None))
 
         scheduler_output = MagicMock()
         self.runner._prepare_inputs(scheduler_output)
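These test updates track the runner's new 7-tuple: a stubbed `_prepare_inputs_non_dp` that still returned six values would fail at the unpack site rather than in any assertion. A standalone illustration, not from the test file:

```python
# A stale 6-value stub fails at the unpack site with a ValueError,
# which is why the MagicMock return_value gains a seventh None.
stale_stub = (None, None, None, None, None, None)
try:
    (input_ids, attn_md, sampling_md, logits_idx,
     spec_md, selector, padded_num_reqs) = stale_stub
except ValueError as err:
    print(err)  # "not enough values to unpack (expected 7, got 6)"
```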

tpu_inference/runner/kv_cache_manager.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List
 
 import jax
@@ -190,7 +189,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         num_blocks = kv_cache_tensor.size // page_size_bytes
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         # num_blocks must be a multiple of dp_size
-        num_blocks = math.ceil(num_blocks / dp_size) * dp_size
+        num_blocks = (num_blocks // dp_size) * dp_size
         # NOTE: we'll multiply the num_kv_heads by 2 in the function
         kv_cache = create_kv_caches(
             num_blocks=num_blocks,
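The rounding direction is the point of this fix: `num_blocks` is derived from the size of an already-allocated tensor, so rounding *up* to a multiple of `dp_size` could report more blocks than the allocation backs, while rounding *down* stays within it and keeps the count divisible. A quick illustration with made-up numbers:

```python
# Illustrative numbers only: the tensor backs exactly `num_blocks` blocks.
import math

num_blocks, dp_size = 1003, 4
rounded_up = math.ceil(num_blocks / dp_size) * dp_size  # 1004 -> overruns the allocation
rounded_down = (num_blocks // dp_size) * dp_size        # 1000 -> safe and still divisible

assert rounded_down <= num_blocks < rounded_up
assert rounded_down % dp_size == 0
```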

tpu_inference/runner/tpu_runner.py

Lines changed: 59 additions & 42 deletions
@@ -15,7 +15,7 @@
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t, j2t_dtype
+from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -154,6 +154,7 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
+    padded_num_reqs: Optional[int] = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -191,19 +192,28 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
-def _reorder_logits_indices(logprobs_lists: LogprobsLists,
-                            logits_indices_selector: List[int]):
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
     return LogprobsLists(
-        logprob_token_ids=[
-            logprobs_lists.logprob_token_ids[i]
-            for i in logits_indices_selector
-        ],
-        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
-        sampled_token_ranks=[
-            logprobs_lists.sampled_token_ranks[i]
-            for i in logits_indices_selector
-        ],
-        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
     )
 
 
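`_jax_logprobs_to_lists` replaces `_reorder_logits_indices` and folds the host transfer (`.tolist()` on the JAX arrays) together with the optional pre-DP-shuffle reorder. A standalone sketch of the selector semantics, with numpy arrays standing in for the `LogprobsTensors` fields:

```python
# Stand-in sketch: numpy arrays play the role of the JAX LogprobsTensors
# fields; the selector restores the pre-DP-shuffle request order.
import numpy as np

logprob_token_ids = np.array([[11], [22], [33]])  # rows in post-shuffle order
logits_indices_selector = [2, 0, 1]               # maps rows back to request order

rows = logprob_token_ids.tolist()                 # device->host copy in the real code
reordered = np.asarray([rows[i] for i in logits_indices_selector])
print(reordered.ravel())  # [33 11 22]
```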
@@ -552,16 +562,17 @@ def sample_tokens(
 
         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector) = (
-             self.execute_model_state.scheduler_output,
-             self.execute_model_state.attn_metadata,
-             self.execute_model_state.input_ids,
-             self.execute_model_state.hidden_states,
-             self.execute_model_state.logits,
-             self.execute_model_state.aux_hidden_states,
-             self.execute_model_state.spec_decode_metadata,
-             self.execute_model_state.kv_connector_output,
-             self.execute_model_state.logits_indices_selector)
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
         self.execute_model_state = None
 
         if grammar_output is not None:
@@ -575,12 +586,10 @@
                 logits,
                 arange,
             )
-        return self._sample_from_logits(scheduler_output, attn_metadata,
-                                        input_ids, hidden_states, logits,
-                                        aux_hidden_states,
-                                        spec_decode_metadata,
-                                        kv_connector_output,
-                                        logits_indices_selector)
+        return self._sample_from_logits(
+            scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+            aux_hidden_states, spec_decode_metadata, kv_connector_output,
+            logits_indices_selector, padded_num_reqs)
 
     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -694,6 +703,7 @@ def _execute_model(
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)
 
         # multi-modal support
@@ -756,7 +766,8 @@
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector)
+            logits_indices_selector=logits_indices_selector,
+            padded_num_reqs=padded_num_reqs)
         return attn_metadata, None
 
     def _sample_from_logits(
@@ -770,11 +781,19 @@
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
+        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
-            self.input_batch.num_reqs, self.max_num_reqs)
+        if padded_num_reqs is None:
+            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                self.input_batch.num_reqs, self.max_num_reqs)
+
+        sharding = None
+        if self.dp_size > 1:
+            sharding = NamedSharding(self.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs)
+            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
         if spec_decode_metadata is None:
             next_tokens = sample(
                 self.rng_params_for_sampling,
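With attention data parallelism the sampling metadata is now laid out with an explicit `NamedSharding` over the DP axis. A minimal, self-contained sketch of that construction; the axis name "data" is illustrative, while the runner uses `ShardingAxisName.ATTN_DATA` on its existing mesh:

```python
# Requires >= 2 JAX devices; the axis name "data" is illustrative.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:2]), axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

x = jax.device_put(np.arange(8, dtype=np.float32), sharding)
print(x.sharding)  # first dimension split across the "data" axis
```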
@@ -806,8 +825,6 @@
         if tpu_sampling_metadata.logprobs:
             logprobs = self._compute_and_gather_logprobs(
                 logits, next_tokens, self.model_config.max_logprobs)
-            logprobs_lists = jax.tree.map(lambda x: j2t(x.astype(jnp.float32)),
-                                          logprobs).tolists()
         else:
             logprobs = None
 
@@ -860,9 +877,8 @@
 
         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            if logits_indices_selector is not None:
-                logprobs_lists = _reorder_logits_indices(
-                    logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(
+                logprobs, logits_indices_selector)
 
         else:
             logprobs_lists = None
@@ -934,9 +950,8 @@
 
         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            if logits_indices_selector is not None:
-                logprobs_lists = _reorder_logits_indices(
-                    logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                    logits_indices_selector)
         else:
             logprobs_lists = None
 
@@ -1397,6 +1412,7 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         )
 
     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1563,7 +1579,8 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
         attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
         return (input_ids, attention_metadata, sampling_metadata,
-                logits_indices, spec_decode_metadata, logits_indices_selector)
+                logits_indices, spec_decode_metadata, logits_indices_selector,
+                padded_num_reqs)
 
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
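The thread running through this file: `padded_num_reqs` is computed once during input preparation and carried via `ExecuteModelState` into `_sample_from_logits`, so sampling pads to the same request count the (possibly DP-shuffled) input path used; recomputing it from `self.input_batch.num_reqs` remains only as a fallback. A hypothetical stand-in for the padding helper, for intuition only (the real `runner_utils.get_padded_num_reqs_with_upper_limit` may differ):

```python
# Hypothetical stand-in: pad the request count to the next power of two,
# capped at the configured maximum, so compiled shapes can be reused.
def get_padded_num_reqs_with_upper_limit(num_reqs: int, upper_limit: int) -> int:
    padded = 1
    while padded < num_reqs:
        padded *= 2
    return min(padded, upper_limit)

assert get_padded_num_reqs_with_upper_limit(5, 16) == 8
assert get_padded_num_reqs_with_upper_limit(20, 16) == 16
```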
