@@ -428,9 +428,7 @@ def _precompile_eagle3_helpers(self) -> None:
428428 draft_kv_cache_group_id = num_kv_cache_groups - 1
429429 block_tables = self .runner .input_batch .block_table [
430430 draft_kv_cache_group_id ].get_cpu_tensor ().reshape (- 1 )
431- block_tables_first_spec = jax .device_put (
432- block_tables , NamedSharding (self .runner .mesh , PartitionSpec ()))
433- block_tables_loop = jax .device_put (
431+ block_tables = jax .device_put (
434432 block_tables , NamedSharding (self .runner .mesh ,
435433 PartitionSpec (None , )))
436434
@@ -447,7 +445,7 @@ def _precompile_eagle3_helpers(self) -> None:
447445 self ._run_compilation (
448446 "_update_inputs_for_loop_speculation for the subsequent loops" ,
449447 self .runner .drafter ._update_inputs_for_loop_speculation ,
450- selected_positions , seq_lens , block_tables_loop )
448+ selected_positions , seq_lens , block_tables )
451449
452450 request_distribution = np .array ([0 , 0 , 0 ], dtype = np .int32 )
453451 request_distribution = device_array (self .runner .mesh ,
@@ -498,7 +496,7 @@ def _precompile_eagle3_helpers(self) -> None:
498496 positions = self ._create_dummy_tensor ((num_tokens , ), jnp .int32 )
499497 attention_metadata = AttentionMetadata (
500498 input_positions = positions ,
501- block_tables = block_tables_first_spec ,
499+ block_tables = block_tables ,
502500 seq_lens = seq_lens ,
503501 query_start_loc = query_start_loc ,
504502 request_distribution = request_distribution ,
@@ -520,11 +518,7 @@ def filter_token_and_prepare_initial_inputs_wrapper(
520518 num_reqs )
521519 return target_hidden_states , input_ids , last_token_indices
522520
523- token_indices = self ._create_dummy_tensor ((num_tokens , ),
524- jnp .int32 )
525- input_ids = self ._create_dummy_tensor (
526- (num_tokens , ), jnp .int32 ,
527- NamedSharding (self .runner .mesh , PartitionSpec ()))
521+ input_ids = self ._create_dummy_tensor ((num_tokens , ), jnp .int32 )
528522 aux_hidden_states = [
529523 self ._create_dummy_tensor (
530524 (num_tokens , hidden_size ), jnp .bfloat16 ,
@@ -539,22 +533,29 @@ def filter_token_and_prepare_initial_inputs_wrapper(
539533 NamedSharding (self .runner .mesh , PartitionSpec (None ,
540534 None ))),
541535 ]
542- self ._run_compilation (
543- "eagle3_filter_token_and_prepare_initial_inputs" ,
544- filter_token_and_prepare_initial_inputs_wrapper ,
545- token_indices ,
546- query_start_loc ,
547- seq_lens ,
548- input_ids ,
549- aux_hidden_states ,
550- attention_metadata ,
551- next_token_ids ,
552- device_array (
553- self .runner .mesh ,
554- np .asarray ([self .runner .input_batch .num_reqs ],
555- dtype = jnp .int32 )),
556- num_tokens = num_tokens ,
557- )
536+ # TODO(ranlihao): This will increase the precompilation latency. Find proper range for token_indices.
537+ for padded_total_num_tokens in [
538+ num_tokens ,
539+ min (num_tokens * 2 , self .runner .num_tokens_paddings [- 1 ])
540+ ]:
541+ token_indices = self ._create_dummy_tensor (
542+ (padded_total_num_tokens , ), jnp .int32 )
543+ self ._run_compilation (
544+ "eagle3_filter_token_and_prepare_initial_inputs" ,
545+ filter_token_and_prepare_initial_inputs_wrapper ,
546+ token_indices ,
547+ query_start_loc ,
548+ seq_lens ,
549+ input_ids ,
550+ aux_hidden_states ,
551+ attention_metadata ,
552+ next_token_ids ,
553+ device_array (
554+ self .runner .mesh ,
555+ np .asarray ([self .runner .input_batch .num_reqs ],
556+ dtype = jnp .int32 )),
557+ num_tokens = num_tokens ,
558+ )
558559
559560 def draft_model_fn_wrapper (
560561 state ,
@@ -572,6 +573,9 @@ def draft_model_fn_wrapper(
572573 target_hidden_states = self ._create_dummy_tensor (
573574 (num_tokens , hidden_size ), dtype ,
574575 NamedSharding (self .runner .mesh , PartitionSpec (None , "model" )))
576+ input_ids = self ._create_dummy_tensor (
577+ (num_tokens , ), jnp .int32 ,
578+ NamedSharding (self .runner .mesh , PartitionSpec ()))
575579 self ._run_compilation (
576580 "eagle3_draft_model_fn" ,
577581 draft_model_fn_wrapper ,
@@ -602,7 +606,6 @@ def draft_model_fn_wrapper(
602606 attention_metadata .query_start_loc = jax .device_put (
603607 attention_metadata .query_start_loc ,
604608 NamedSharding (self .runner .mesh , PartitionSpec ()))
605- attention_metadata .block_tables = block_tables_loop
606609 attention_metadata .input_positions = self ._create_dummy_tensor (
607610 (self .runner .max_num_reqs , ), jnp .int32 )
608611 self ._run_compilation (
0 commit comments