Skip to content

Commit db36644

Browse files
committed
experimental multi-host device_put
1 parent 5ed4cba commit db36644

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tpu_inference/runner/tpu_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1557,7 +1557,7 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
15571557
input_tuple_single_device = jax.device_put(
15581558
(input_ids, positions, block_tables, query_start_loc, seq_lens,
15591559
logits_indices, request_distribution),
1560-
device=self.devices[0],
1560+
device=jax.local_devices()[0],
15611561
)
15621562
(input_ids, positions, block_tables, query_start_loc, seq_lens,
15631563
logits_indices, request_distribution) = device_array(

0 commit comments

Comments
 (0)