diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py index d1758eef..e4d0f1ec 100644 --- a/init2winit/checkpoint.py +++ b/init2winit/checkpoint.py @@ -173,8 +173,8 @@ def save_unreplicated_checkpoint( # So we first all_gather it to the host and then call jax.device_get if jax.process_count() > 1: unreplicated_optimizer_state = jax.device_get( - process_allgather(optimizer_state)) - unreplicated_params = jax.device_get(process_allgather(params)) + process_allgather(optimizer_state, tiled=True)) + unreplicated_params = jax.device_get(process_allgather(params, tiled=True)) else: unreplicated_optimizer_state = jax.device_get(optimizer_state) unreplicated_params = jax.device_get(params) diff --git a/init2winit/trainer_lib/trainer_utils.py b/init2winit/trainer_lib/trainer_utils.py index c60373cf..be3b6cba 100644 --- a/init2winit/trainer_lib/trainer_utils.py +++ b/init2winit/trainer_lib/trainer_utils.py @@ -191,8 +191,8 @@ def evaluate( # `merge` aggregates the metrics across batches. metrics = metrics.merge(computed_metrics) - metrics = jax.device_get(process_allgather(metrics)) - metrics = jax.tree_util.tree_map(lambda x: x[0] if x.ndim > 0 else x, metrics) + metrics = jax.device_get(process_allgather(metrics, tiled=True)) + metrics = jax.tree_util.tree_map(lambda x: x[0] if x.ndim > 1 else x, metrics) # For data splits with no data (e.g. Imagenet no test set) no values # will appear for that split. if metrics is not None: