Skip to content

Commit 841c5df

Browse files
authored
[Spec Decoding] Fix precompilation (#960)
Signed-off-by: Lihao Ran <imlihao.ran@gmail.com>
1 parent e8620b3 commit 841c5df

File tree

3 files changed: 38 additions (+38) and 32 deletions (-32)

tests/runner/test_speculative_decoding_manager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,14 +302,15 @@ def test_propose_eagle3_draft_token_ids(self,
302302
# Mock drafter methods
303303
mock_attn_metadata = MagicMock()
304304
mock_target_token_ids = MagicMock()
305+
mock_last_token_indices = MagicMock()
305306
mock_target_hidden_states = MagicMock()
306307
self.runner.drafter.prepare_inputs.return_value = (
307-
mock_attn_metadata,
308-
mock_target_token_ids,
309308
mock_target_hidden_states,
309+
mock_target_token_ids,
310+
mock_last_token_indices,
311+
mock_attn_metadata,
310312
)
311-
mock_draft_token_ids = MagicMock()
312-
mock_draft_token_ids.tolist.return_value = [[10, 11], [20, 21]]
313+
mock_draft_token_ids = [[10, 11], [20, 21]]
313314
self.runner.drafter.propose.return_value = (
314315
self.runner.kv_caches,
315316
mock_draft_token_ids,

tests/spec_decode/test_eagle3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_prepare_inputs():
8686
proposer.state = None # Mock state
8787
proposer.runner.input_batch.block_table = [mock.MagicMock()]
8888
# Mock the block table return value (2D array)
89-
(proposer.runner.input_batch.block_table[0].get_device_tensor.return_value
89+
(proposer.runner.input_batch.block_table[0].get_cpu_tensor.return_value
9090
) = jnp.zeros((num_reqs, max_num_blocks_per_req), dtype=jnp.int32)
9191

9292
# --- Setup sequence data ---
@@ -289,6 +289,8 @@ def mock_combine_hidden_states_fn(state, hidden_states):
289289
target_hidden_states,
290290
)
291291

292+
if draft_token_ids.ndim == 1:
293+
draft_token_ids = jnp.expand_dims(draft_token_ids, axis=-1)
292294
# Assertions
293295
assert draft_token_ids.shape == (batch_size, num_speculative_tokens)
294296

tpu_inference/runner/compilation_manager.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,7 @@ def _precompile_eagle3_helpers(self) -> None:
428428
draft_kv_cache_group_id = num_kv_cache_groups - 1
429429
block_tables = self.runner.input_batch.block_table[
430430
draft_kv_cache_group_id].get_cpu_tensor().reshape(-1)
431-
block_tables_first_spec = jax.device_put(
432-
block_tables, NamedSharding(self.runner.mesh, PartitionSpec()))
433-
block_tables_loop = jax.device_put(
431+
block_tables = jax.device_put(
434432
block_tables, NamedSharding(self.runner.mesh,
435433
PartitionSpec(None, )))
436434

@@ -447,7 +445,7 @@ def _precompile_eagle3_helpers(self) -> None:
447445
self._run_compilation(
448446
"_update_inputs_for_loop_speculation for the subsequent loops",
449447
self.runner.drafter._update_inputs_for_loop_speculation,
450-
selected_positions, seq_lens, block_tables_loop)
448+
selected_positions, seq_lens, block_tables)
451449

452450
request_distribution = np.array([0, 0, 0], dtype=np.int32)
453451
request_distribution = device_array(self.runner.mesh,
@@ -498,7 +496,7 @@ def _precompile_eagle3_helpers(self) -> None:
498496
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
499497
attention_metadata = AttentionMetadata(
500498
input_positions=positions,
501-
block_tables=block_tables_first_spec,
499+
block_tables=block_tables,
502500
seq_lens=seq_lens,
503501
query_start_loc=query_start_loc,
504502
request_distribution=request_distribution,
@@ -520,11 +518,7 @@ def filter_token_and_prepare_initial_inputs_wrapper(
520518
num_reqs)
521519
return target_hidden_states, input_ids, last_token_indices
522520

523-
token_indices = self._create_dummy_tensor((num_tokens, ),
524-
jnp.int32)
525-
input_ids = self._create_dummy_tensor(
526-
(num_tokens, ), jnp.int32,
527-
NamedSharding(self.runner.mesh, PartitionSpec()))
521+
input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
528522
aux_hidden_states = [
529523
self._create_dummy_tensor(
530524
(num_tokens, hidden_size), jnp.bfloat16,
@@ -539,22 +533,29 @@ def filter_token_and_prepare_initial_inputs_wrapper(
539533
NamedSharding(self.runner.mesh, PartitionSpec(None,
540534
None))),
541535
]
542-
self._run_compilation(
543-
"eagle3_filter_token_and_prepare_initial_inputs",
544-
filter_token_and_prepare_initial_inputs_wrapper,
545-
token_indices,
546-
query_start_loc,
547-
seq_lens,
548-
input_ids,
549-
aux_hidden_states,
550-
attention_metadata,
551-
next_token_ids,
552-
device_array(
553-
self.runner.mesh,
554-
np.asarray([self.runner.input_batch.num_reqs],
555-
dtype=jnp.int32)),
556-
num_tokens=num_tokens,
557-
)
536+
# TODO(ranlihao): Compiling once per padded token_indices length increases precompilation latency. Find the proper range of token_indices lengths to precompile.
537+
for padded_total_num_tokens in [
538+
num_tokens,
539+
min(num_tokens * 2, self.runner.num_tokens_paddings[-1])
540+
]:
541+
token_indices = self._create_dummy_tensor(
542+
(padded_total_num_tokens, ), jnp.int32)
543+
self._run_compilation(
544+
"eagle3_filter_token_and_prepare_initial_inputs",
545+
filter_token_and_prepare_initial_inputs_wrapper,
546+
token_indices,
547+
query_start_loc,
548+
seq_lens,
549+
input_ids,
550+
aux_hidden_states,
551+
attention_metadata,
552+
next_token_ids,
553+
device_array(
554+
self.runner.mesh,
555+
np.asarray([self.runner.input_batch.num_reqs],
556+
dtype=jnp.int32)),
557+
num_tokens=num_tokens,
558+
)
558559

559560
def draft_model_fn_wrapper(
560561
state,
@@ -572,6 +573,9 @@ def draft_model_fn_wrapper(
572573
target_hidden_states = self._create_dummy_tensor(
573574
(num_tokens, hidden_size), dtype,
574575
NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
576+
input_ids = self._create_dummy_tensor(
577+
(num_tokens, ), jnp.int32,
578+
NamedSharding(self.runner.mesh, PartitionSpec()))
575579
self._run_compilation(
576580
"eagle3_draft_model_fn",
577581
draft_model_fn_wrapper,
@@ -602,7 +606,6 @@ def draft_model_fn_wrapper(
602606
attention_metadata.query_start_loc = jax.device_put(
603607
attention_metadata.query_start_loc,
604608
NamedSharding(self.runner.mesh, PartitionSpec()))
605-
attention_metadata.block_tables = block_tables_loop
606609
attention_metadata.input_positions = self._create_dummy_tensor(
607610
(self.runner.max_num_reqs, ), jnp.int32)
608611
self._run_compilation(

0 commit comments

Comments
 (0)