Skip to content

Commit 103a581

Browse files
committed
resolve comments
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent 4cac52f commit 103a581

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tpu_inference/worker/tpu_worker_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ def init_device(self):
165165
if device_indexes is not None and len(device_indexes) > 0:
166166
# Enforcing the devices sequence to be consistent with the specified device indexes
167167
self.devices = [jax.local_devices()[i] for i in device_indexes]
168-
all_devices = jax.local_devices()
169-
device_dict = {device.id: device for device in all_devices}
168+
all_local_devices = jax.local_devices()
169+
device_dict = {device.id: device for device in all_local_devices}
170170
self.devices = []
171171
for device_index in device_indexes:
172172
device = device_dict[device_index]

0 commit comments

Comments
 (0)