From 97763a87541e95eac99d44d7de96dd56055b5456 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 2 Feb 2026 14:25:13 -0800 Subject: [PATCH] Update jax.host_id() and jax.host_count() to jax.process_index() and jax.process_count(), respectively. PiperOrigin-RevId: 864521508 --- hessian/model_debugger.py | 2 +- init2winit/dataset_lib/imagenet_dataset.py | 2 +- init2winit/mt_eval/main.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hessian/model_debugger.py b/hessian/model_debugger.py index 93337f16..943f0b07 100644 --- a/hessian/model_debugger.py +++ b/hessian/model_debugger.py @@ -555,7 +555,7 @@ def full_eval(self, self._stored_metrics, remove_leaf_tuples(flax.core.unfreeze(all_metrics))) - if self._metrics_logger and jax.host_id() == 0: + if self._metrics_logger and jax.process_index() == 0: self._maybe_save_metrics(step) return all_metrics diff --git a/init2winit/dataset_lib/imagenet_dataset.py b/init2winit/dataset_lib/imagenet_dataset.py index 51cd2b1e..e0ab4c6d 100644 --- a/init2winit/dataset_lib/imagenet_dataset.py +++ b/init2winit/dataset_lib/imagenet_dataset.py @@ -134,7 +134,7 @@ def map(self, features): # Grain currently doesn't support the `ds.enumerate()` # functionality, they suggested moving mixup to the training loop where we # can access the step number. - batch_index = features[grain.INDEX][0] // jax.host_count() + batch_index = features[grain.INDEX][0] // jax.process_count() seed = tf.random.experimental.stateless_fold_in( self.initial_seed, batch_index ) diff --git a/init2winit/mt_eval/main.py b/init2winit/mt_eval/main.py index 99d9d7fb..ce0fe033 100644 --- a/init2winit/mt_eval/main.py +++ b/init2winit/mt_eval/main.py @@ -101,8 +101,8 @@ def main(unused_argv): if jax.process_index() == 0: logging.info('argv:\n%s', ' '.join(sys.argv)) logging.info('device_count: %d', jax.device_count()) - logging.info('num_hosts : %d', jax.host_count()) - logging.info('host_id : %d', jax.host_id()) + logging.info('num_hosts : %d', jax.process_count()) + logging.info('host_id : %d', jax.process_index()) model_class = models.get_model(model_name) dataset_builder = datasets.get_dataset(dataset_name)