From e7a59feb1f5d5897e16d50891326d7082ca29dbc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 2 Oct 2025 15:02:35 -0700 Subject: [PATCH] internal change PiperOrigin-RevId: 814399303 --- init2winit/checkpoint.py | 4 ++-- init2winit/trainer_lib/trainer_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py index d1758eef..e4d0f1ec 100644 --- a/init2winit/checkpoint.py +++ b/init2winit/checkpoint.py @@ -173,8 +173,8 @@ def save_unreplicated_checkpoint( # So we first all_gather it to the host and then call jax.device_get if jax.process_count() > 1: unreplicated_optimizer_state = jax.device_get( - process_allgather(optimizer_state)) - unreplicated_params = jax.device_get(process_allgather(params)) + process_allgather(optimizer_state, tiled=True)) + unreplicated_params = jax.device_get(process_allgather(params, tiled=True)) else: unreplicated_optimizer_state = jax.device_get(optimizer_state) unreplicated_params = jax.device_get(params) diff --git a/init2winit/trainer_lib/trainer_utils.py b/init2winit/trainer_lib/trainer_utils.py index c60373cf..be3b6cba 100644 --- a/init2winit/trainer_lib/trainer_utils.py +++ b/init2winit/trainer_lib/trainer_utils.py @@ -191,8 +191,8 @@ def evaluate( # `merge` aggregates the metrics across batches. metrics = metrics.merge(computed_metrics) - metrics = jax.device_get(process_allgather(metrics)) - metrics = jax.tree_util.tree_map(lambda x: x[0] if x.ndim > 0 else x, metrics) + metrics = jax.device_get(process_allgather(metrics, tiled=True)) + metrics = jax.tree_util.tree_map(lambda x: x[0] if x.ndim > 1 else x, metrics) # For data splits with no data (e.g. Imagenet no test set) no values # will appear for that split. if metrics is not None: