We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3a40604 commit 41b8cd4Copy full SHA for 41b8cd4
tpu_inference/worker/tpu_worker.py
@@ -164,8 +164,8 @@ def init_device(self):
164
device_indexes = sharding_config.device_indexes
165
if device_indexes is not None and len(device_indexes) > 0:
166
# Enforcing the devices sequence to be consistent with the specified device indexes
167
- all_devices = jax.local_devices()
168
- device_dict = {device.id: device for device in all_devices}
+ all_local_devices = jax.local_devices()
+ device_dict = {device.id: device for device in all_local_devices}
169
self.devices = []
170
for device_index in device_indexes:
171
device = device_dict[device_index]
0 commit comments