Skip to content

Commit 0a3b4f0

Browse files
committed
experimental multi-host device_put
1 parent b45ad8c commit 0a3b4f0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tpu_inference/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
243243
"""
244244
if sharding is None:
245245
sharding = NamedSharding(mesh, PartitionSpec(None))
246-
return jax.make_array_from_process_local_data(sharding=sharding, *args, **kwargs)
246+
return jax.make_array_from_process_local_data(sharding, *args, **kwargs)
247247

248248

249249
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:

0 commit comments

Comments
 (0)