diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index c77bc944b5..fec5e15628 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -806,8 +806,14 @@ def _save_checkpoint(self, step: int) -> None: log.info(f"Trained model has been saved to: {ckpt_path!s}") _link_checkpoint(ckpt_path, Path(f"{self.save_ckpt}.jax")) self._cleanup_old_checkpoints() - with open("checkpoint", "w") as fp: - fp.write(f"{self.save_ckpt}.jax") + # Write the pointer next to the checkpoint prefix, with a value relative + # to that directory (basename only). The freeze entrypoint looks for the + # pointer inside the folder it is given and resolves the value relative + # to it, so a directory-valued save_ckpt would otherwise be unresolvable. + ckpt_dir = Path(self.save_ckpt).parent + ckpt_dir.mkdir(parents=True, exist_ok=True) + with open(ckpt_dir / "checkpoint", "w") as fp: + fp.write(f"{Path(self.save_ckpt).name}.jax") def _save_full_validation_checkpoint( self, diff --git a/source/tests/jax/test_checkpoint_pointer.py b/source/tests/jax/test_checkpoint_pointer.py new file mode 100644 index 0000000000..8f2faf44b6 --- /dev/null +++ b/source/tests/jax/test_checkpoint_pointer.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Test the JAX trainer writes its checkpoint pointer beside save_ckpt. + +JAX training writes checkpoint directories and the stable ``.jax`` link relative +to ``save_ckpt`` (which may include a directory), but used to always write the +``checkpoint`` pointer file to the current working directory with a value that +still carried the directory prefix. The freeze entrypoint expects the pointer +inside the folder it is given and resolves its value relative to that folder, so +a directory-valued ``save_ckpt`` broke freeze/restart tooling. 
+"""
+
+import os
+import tempfile
+import unittest
+from pathlib import (
+    Path,
+)
+from unittest import (
+    mock,
+)
+
+from deepmd.jax.train.trainer import (
+    DPTrainer,
+)
+
+
+class TestCheckpointPointer(unittest.TestCase):
+    def setUp(self) -> None:
+        # Cleanups (unlike tearDown) still run if setUp raises; they run LIFO.
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.addCleanup(self.tmpdir.cleanup)
+        self.cwd = os.getcwd()
+        self.addCleanup(os.chdir, self.cwd)
+        os.chdir(self.tmpdir.name)
+
+    def _save(self, save_ckpt: str) -> None:
+        """Run ``_save_checkpoint(1)`` on a bare trainer with helpers patched out."""
+        trainer = DPTrainer.__new__(DPTrainer)  # skip __init__
+        trainer.save_ckpt = save_ckpt
+        with (
+            mock.patch.object(DPTrainer, "_write_checkpoint"),
+            mock.patch.object(DPTrainer, "_cleanup_old_checkpoints"),
+            mock.patch("deepmd.jax.train.trainer._link_checkpoint"),
+        ):
+            trainer._save_checkpoint(1)
+
+    def test_pointer_beside_ckpt_for_subdir(self) -> None:
+        # save_ckpt with a directory: pointer must land in that directory with a
+        # value relative to it (basename only), so freeze(subdir) resolves it.
+        self._save("subdir/model.ckpt")
+        pointer = Path("subdir") / "checkpoint"
+        self.assertTrue(pointer.is_file())
+        self.assertEqual(pointer.read_text(), "model.ckpt.jax")
+        self.assertFalse(Path("checkpoint").exists())
+
+    def test_pointer_in_cwd_for_bare_name(self) -> None:
+        # default/bare save_ckpt: pointer stays in the CWD (parent is ".").
+        self._save("model.ckpt")
+        pointer = Path("checkpoint")
+        self.assertTrue(pointer.is_file())
+        self.assertEqual(pointer.read_text(), "model.ckpt.jax")
+
+
+if __name__ == "__main__":
+    unittest.main()