Skip to content

Commit 41b8cd4

Browse files
committed
resolve comments
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent 3a40604 commit 41b8cd4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tpu_inference/worker/tpu_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ def init_device(self):
164164
device_indexes = sharding_config.device_indexes
165165
if device_indexes is not None and len(device_indexes) > 0:
166166
# Enforcing the devices sequence to be consistent with the specified device indexes
167-
all_devices = jax.local_devices()
168-
device_dict = {device.id: device for device in all_devices}
167+
all_local_devices = jax.local_devices()
168+
device_dict = {device.id: device for device in all_local_devices}
169169
self.devices = []
170170
for device_index in device_indexes:
171171
device = device_dict[device_index]

0 commit comments

Comments
 (0)