We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4cac52f commit 103a581Copy full SHA for 103a581
tpu_inference/worker/tpu_worker_jax.py
@@ -165,8 +165,8 @@ def init_device(self):
165
if device_indexes is not None and len(device_indexes) > 0:
166
# Enforcing the devices sequence to be consistent with the specified device indexes
167
self.devices = [jax.local_devices()[i] for i in device_indexes]
168
- all_devices = jax.local_devices()
169
- device_dict = {device.id: device for device in all_devices}
+ all_local_devices = jax.local_devices()
+ device_dict = {device.id: device for device in all_local_devices}
170
self.devices = []
171
for device_index in device_indexes:
172
device = device_dict[device_index]
0 commit comments