@@ -146,8 +146,7 @@ def slice_device_count(self, slice_index: int) -> int:
146146 f"Slice { slice_index = } not found in { self .slice_to_devices = } "
147147 ) from error
148148
149- @classmethod
150- def is_error_due_to_slice_down (cls , error : Exception ) -> bool :
149+ def is_error_due_to_slice_down (self , error : Exception ) -> bool :
151150 """Returns True if the error is due to slice down.
152151
153152 The error types that are considered due to slice down are
@@ -160,7 +159,8 @@ def is_error_due_to_slice_down(cls, error: Exception) -> bool:
160159 error: The error to check.
161160 """
162161 return_value = isinstance (error , jax .errors .JaxRuntimeError ) and any (
163- error_type in str (error ) for error_type in cls ._ELASTIC_DOWN_ERROR_TYPES
162+ error_type in str (error )
163+ for error_type in self ._ELASTIC_DOWN_ERROR_TYPES
164164 )
165165 if return_value :
166166 _logger .info ("Caught an error due to slice down" )
@@ -171,8 +171,7 @@ def is_error_due_to_slice_down(cls, error: Exception) -> bool:
171171
172172 return return_value
173173
174- @classmethod
175- def _simple_execution (cls , devices : Sequence [jax .Device ]) -> jax .Array :
174+ def _simple_execution (self , devices : Sequence [jax .Device ]) -> jax .Array :
176175 """Simple execution to test if a slice is available.
177176
178177 This function is used to test if a slice is available. It executes a simple
@@ -192,7 +191,7 @@ def _simple_execution(cls, devices: Sequence[jax.Device]) -> jax.Array:
192191 raise ValueError ("No devices" )
193192
194193 test_input = np .zeros (len (devices ), dtype = float ) + (
195- cls ._SIMPLE_EXECUTION_TEST_VALUE - 1
194+ self ._SIMPLE_EXECUTION_TEST_VALUE - 1
196195 )
197196
198197 return jax .pmap (lambda x : x + 1 , devices = devices )(test_input )
@@ -374,7 +373,8 @@ def pop_snapshot(self) -> tuple[int, PyTree | None, PyTree | None]:
374373 the manager. Calls will raise an error if there are no snapshot to pop.
375374
376375 Returns:
377- A tuple of the step and the snapshot.
376+ A tuple of the step, the snapshot of jax arrays, and the snapshot of
377+ controller variables.
378378
379379 Raises:
380380 ElasticRuntimeError: If there is no snapshot to pop.
@@ -391,46 +391,6 @@ def pop_snapshot(self) -> tuple[int, PyTree | None, PyTree | None]:
391391
392392 return step , snapshot_jax_arrays , snapshot_controller
393393
394- @staticmethod
395- def _get_snapshot_jax_arrays_size (snapshot_jax_arrays : PyTree | None ) -> int :
396- """Returns the size of a snapshot.
397-
398- Args:
399- snapshot_jax_arrays: The snapshot to get the size of.
400- """
401- return sum (leaf .nbytes for leaf in jax .tree .leaves (snapshot_jax_arrays ))
402-
403- @staticmethod
404- def _put_snapshot_jax_arrays_on_host (
405- snapshot_jax_arrays : PyTree | None ,
406- ) -> PyTree | None :
407- """Puts a copy of the snapshot on the host.
408-
409- Args:
410- snapshot_jax_arrays: The snapshot to move to the host. Must be a PyTree of
411- JAX arrays or None.
412-
413- Returns:
414- A copy of the snapshot on the host.
415- """
416-
417- sharding_pinned_host = jax .tree .map (
418- lambda x : x .sharding .with_memory_kind ("pinned_host" ),
419- snapshot_jax_arrays ,
420- )
421- return jax .device_put (
422- snapshot_jax_arrays ,
423- sharding_pinned_host ,
424- donate = False ,
425- may_alias = False ,
426- )
427-
428- @staticmethod
429- def _put_snapshot_on_controller (
430- snapshot : PyTree | None ,
431- ) -> PyTree | None :
432- return copy .deepcopy (snapshot )
433-
434394 # TODO: b/407772100 - Support multiple snapshots.
435395 @timing .timeit
436396 def maybe_snapshot (
@@ -459,22 +419,30 @@ def maybe_snapshot(
459419 _logger .info ("Not saving a snapshot" )
460420 return
461421
462- total_nbytes = self ._get_snapshot_jax_arrays_size (snapshot_jax_arrays )
422+ total_nbytes = sum (
423+ leaf .nbytes for leaf in jax .tree .leaves (snapshot_jax_arrays )
424+ )
463425
464426 _logger .info ("Saving a snapshot of %s bytes on host" , total_nbytes )
465427
466- snapshot_jax_arrays_host = self ._put_snapshot_jax_arrays_on_host (
467- snapshot_jax_arrays
428+ sharding_pinned_host = jax .tree .map (
429+ lambda x : x .sharding .with_memory_kind ("pinned_host" ),
430+ snapshot_jax_arrays ,
431+ )
432+ snapshot_jax_arrays_host = jax .device_put (
433+ snapshot_jax_arrays ,
434+ sharding_pinned_host ,
435+ donate = False ,
436+ may_alias = False ,
468437 )
469438 _logger .info ("Snapshot dispatched" )
470439
471440 if block :
472441 jax .block_until_ready (snapshot_jax_arrays_host )
473442 _logger .info ("Snapshot completed" )
474443
475- snapshot_on_controller = self ._put_snapshot_on_controller (
476- snapshot_controller
477- )
444+ snapshot_on_controller = copy .deepcopy (snapshot_controller )
445+
478446 self ._snapshot = {
479447 "step" : step ,
480448 "snapshot_jax_arrays" : snapshot_jax_arrays_host ,
@@ -523,9 +491,7 @@ def get_resharded_snapshot(
523491 may_alias = False ,
524492 )
525493
526- snapshot_on_controller = self ._put_snapshot_on_controller (
527- snapshot_controller
528- )
494+ snapshot_on_controller = copy .deepcopy (snapshot_controller )
529495
530496 self ._snapshot = {
531497 "step" : step ,
0 commit comments