1515from flax import nnx
1616from jax .experimental import mesh_utils
1717from jax .sharding import NamedSharding , PartitionSpec
18- from torchax .ops .mappings import j2t , j2t_dtype
18+ from torchax .ops .mappings import j2t_dtype
1919from vllm .config import VllmConfig
2020from vllm .distributed .kv_transfer import (get_kv_transfer_group ,
2121 has_kv_transfer_group )
@@ -154,6 +154,7 @@ class ExecuteModelState:
154154 spec_decode_metadata : Optional [SpecDecodeMetadata ]
155155 kv_connector_output : Optional [KVConnectorOutput ]
156156 logits_indices_selector : Optional [List [int ]] = None
157+ padded_num_reqs : Optional [int ] = None
157158
158159
159160@functools .partial (jax .jit , donate_argnums = (0 , 1 , 2 ))
@@ -191,19 +192,28 @@ def _substitute_placeholder_token(
191192 return input_ids .at [token_in_tpu_cur_input_indices ].set (update_values )
192193
193194
194- def _reorder_logits_indices (logprobs_lists : LogprobsLists ,
195- logits_indices_selector : List [int ]):
195+ def _jax_logprobs_to_lists (logprobs_tensors ,
196+ logits_indices_selector = None ,
197+ cu_num_generated_tokens = None ):
198+ """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
199+ log_token_ids_list = logprobs_tensors .logprob_token_ids .tolist ()
200+ logprobs_list = logprobs_tensors .logprobs .tolist ()
201+ selected_token_ranks_list = logprobs_tensors .selected_token_ranks .tolist ()
202+
203+ if logits_indices_selector is not None :
204+ log_token_ids_list = [
205+ log_token_ids_list [i ] for i in logits_indices_selector
206+ ]
207+ logprobs_list = [logprobs_list [i ] for i in logits_indices_selector ]
208+ selected_token_ranks_list = [
209+ selected_token_ranks_list [i ] for i in logits_indices_selector
210+ ]
211+
196212 return LogprobsLists (
197- logprob_token_ids = [
198- logprobs_lists .logprob_token_ids [i ]
199- for i in logits_indices_selector
200- ],
201- logprobs = [logprobs_lists .logprobs [i ] for i in logits_indices_selector ],
202- sampled_token_ranks = [
203- logprobs_lists .sampled_token_ranks [i ]
204- for i in logits_indices_selector
205- ],
206- cu_num_generated_tokens = logprobs_lists .cu_num_generated_tokens ,
213+ logprob_token_ids = np .asarray (log_token_ids_list ),
214+ logprobs = np .asarray (logprobs_list ),
215+ sampled_token_ranks = np .asarray (selected_token_ranks_list ),
216+ cu_num_generated_tokens = cu_num_generated_tokens ,
207217 )
208218
209219
@@ -552,16 +562,17 @@ def sample_tokens(
552562
553563 (scheduler_output , attn_metadata , input_ids , hidden_states , logits ,
554564 aux_hidden_states , spec_decode_metadata , kv_connector_output ,
555- logits_indices_selector ) = (
556- self .execute_model_state .scheduler_output ,
557- self .execute_model_state .attn_metadata ,
558- self .execute_model_state .input_ids ,
559- self .execute_model_state .hidden_states ,
560- self .execute_model_state .logits ,
561- self .execute_model_state .aux_hidden_states ,
562- self .execute_model_state .spec_decode_metadata ,
563- self .execute_model_state .kv_connector_output ,
564- self .execute_model_state .logits_indices_selector )
565+ logits_indices_selector ,
566+ padded_num_reqs ) = (self .execute_model_state .scheduler_output ,
567+ self .execute_model_state .attn_metadata ,
568+ self .execute_model_state .input_ids ,
569+ self .execute_model_state .hidden_states ,
570+ self .execute_model_state .logits ,
571+ self .execute_model_state .aux_hidden_states ,
572+ self .execute_model_state .spec_decode_metadata ,
573+ self .execute_model_state .kv_connector_output ,
574+ self .execute_model_state .logits_indices_selector ,
575+ self .execute_model_state .padded_num_reqs )
565576 self .execute_model_state = None
566577
567578 if grammar_output is not None :
@@ -575,12 +586,10 @@ def sample_tokens(
575586 logits ,
576587 arange ,
577588 )
578- return self ._sample_from_logits (scheduler_output , attn_metadata ,
579- input_ids , hidden_states , logits ,
580- aux_hidden_states ,
581- spec_decode_metadata ,
582- kv_connector_output ,
583- logits_indices_selector )
589+ return self ._sample_from_logits (
590+ scheduler_output , attn_metadata , input_ids , hidden_states , logits ,
591+ aux_hidden_states , spec_decode_metadata , kv_connector_output ,
592+ logits_indices_selector , padded_num_reqs )
584593
585594 def _modify_prev_results (self ):
586595 # If copy to host has not been done, we just wait.
@@ -694,6 +703,7 @@ def _execute_model(
694703 logits_indices ,
695704 spec_decode_metadata ,
696705 logits_indices_selector ,
706+ padded_num_reqs ,
697707 ) = self ._prepare_inputs (scheduler_output )
698708
699709 # multi-modal support
@@ -756,7 +766,8 @@ def _execute_model(
756766 aux_hidden_states = aux_hidden_states ,
757767 spec_decode_metadata = spec_decode_metadata ,
758768 kv_connector_output = kv_connector_output ,
759- logits_indices_selector = logits_indices_selector )
769+ logits_indices_selector = logits_indices_selector ,
770+ padded_num_reqs = padded_num_reqs )
760771 return attn_metadata , None
761772
762773 def _sample_from_logits (
@@ -770,11 +781,19 @@ def _sample_from_logits(
770781 spec_decode_metadata : Optional [SpecDecodeMetadata ],
771782 kv_connector_output : Optional [KVConnectorOutput ],
772783 logits_indices_selector : Optional [List [int ]] = None ,
784+ padded_num_reqs : Optional [int ] = None ,
773785 ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput :
774- padded_num_reqs = runner_utils .get_padded_num_reqs_with_upper_limit (
775- self .input_batch .num_reqs , self .max_num_reqs )
786+ if padded_num_reqs is None :
787+ padded_num_reqs = runner_utils .get_padded_num_reqs_with_upper_limit (
788+ self .input_batch .num_reqs , self .max_num_reqs )
789+
790+ sharding = None
791+ if self .dp_size > 1 :
792+ sharding = NamedSharding (self .mesh ,
793+ PartitionSpec (ShardingAxisName .ATTN_DATA ))
794+
776795 tpu_sampling_metadata = TPUSupportedSamplingMetadata .from_input_batch (
777- self .mesh , self .input_batch , padded_num_reqs )
796+ self .mesh , self .input_batch , padded_num_reqs , sharding = sharding )
778797 if spec_decode_metadata is None :
779798 next_tokens = sample (
780799 self .rng_params_for_sampling ,
@@ -806,8 +825,6 @@ def _sample_from_logits(
806825 if tpu_sampling_metadata .logprobs :
807826 logprobs = self ._compute_and_gather_logprobs (
808827 logits , next_tokens , self .model_config .max_logprobs )
809- logprobs_lists = jax .tree .map (lambda x : j2t (x .astype (jnp .float32 )),
810- logprobs ).tolists ()
811828 else :
812829 logprobs = None
813830
@@ -860,9 +877,8 @@ def _sample_from_logits(
860877
861878 if logprobs is not None :
862879 # Map logprobs back to the pre-dp shuffling order
863- if logits_indices_selector is not None :
864- logprobs_lists = _reorder_logits_indices (
865- logprobs_lists , logits_indices_selector )
880+ logprobs_lists = _jax_logprobs_to_lists (
881+ logprobs , logits_indices_selector )
866882
867883 else :
868884 logprobs_lists = None
@@ -934,9 +950,8 @@ def _sample_from_logits(
934950
935951 if logprobs is not None :
936952 # Map logprobs back to the pre-dp shuffling order
937- if logits_indices_selector is not None :
938- logprobs_lists = _reorder_logits_indices (
939- logprobs_lists , logits_indices_selector )
953+ logprobs_lists = _jax_logprobs_to_lists (logprobs ,
954+ logits_indices_selector )
940955 else :
941956 logprobs_lists = None
942957
@@ -1397,6 +1412,7 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
13971412 logits_indices ,
13981413 spec_decode_metadata ,
13991414 logits_indices_selector ,
1415+ padded_num_reqs ,
14001416 )
14011417
14021418 def _prepare_inputs_non_dp (self , scheduler_output : "VllmSchedulerOutput" ):
@@ -1563,7 +1579,8 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
15631579 attention_metadata .seq_lens_cpu = seq_lens_cpu
15641580 logits_indices_selector = None
15651581 return (input_ids , attention_metadata , sampling_metadata ,
1566- logits_indices , spec_decode_metadata , logits_indices_selector )
1582+ logits_indices , spec_decode_metadata , logits_indices_selector ,
1583+ padded_num_reqs )
15671584
15681585 def _get_input_ids_embeds (self , input_ids : jax .Array ,
15691586 mm_embeds : list [jax .Array ]):
0 commit comments