From e8c36e0c8b09e5c1a184211ca537c3300e426cc9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 4 Jul 2026 01:34:57 +0800 Subject: [PATCH] fix(jax): write the checkpoint pointer beside save_ckpt JAX training writes checkpoint directories and the stable .jax link relative to save_ckpt (which may include a directory), but always wrote the "checkpoint" pointer file to the current working directory with a value that still carried the directory prefix (e.g. "runs/water/model.ckpt.jax"). The freeze entrypoint looks for the pointer inside the folder it is given and resolves the value relative to that folder, so a directory-valued save_ckpt both misplaced the pointer and double-prefixed the resolved path, breaking freeze and restart-style tooling. Write the pointer into Path(save_ckpt).parent and store a value relative to that directory (the basename only). For the default bare save_ckpt (parent == "."), the pointer stays in the CWD with the same value, so existing behavior is unchanged. Adds source/tests/jax/test_checkpoint_pointer.py, which drives _save_checkpoint with the checkpoint I/O mocked: the directory case asserts the pointer lands beside the checkpoint with a basename value and not in the CWD (fails on master), and a bare-name control asserts the pointer stays in the CWD. The trainer's pointer writing previously had no test; the existing freeze test hand-wrote a correct pointer and never exercised it. Fix #5678 --- deepmd/jax/train/trainer.py | 10 +++- source/tests/jax/test_checkpoint_pointer.py | 65 +++++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 source/tests/jax/test_checkpoint_pointer.py diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index c77bc944b5..fec5e15628 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -806,8 +806,14 @@ def _save_checkpoint(self, step: int) -> None: log.info(f"Trained model has been saved to: {ckpt_path!s}") _link_checkpoint(ckpt_path, Path(f"{self.save_ckpt}.jax")) self._cleanup_old_checkpoints() - with open("checkpoint", "w") as fp: - fp.write(f"{self.save_ckpt}.jax") + # Write the pointer next to the checkpoint prefix, with a value relative + # to that directory (basename only). The freeze entrypoint looks for the + # pointer inside the folder it is given and resolves the value relative + # to it, so a directory-valued save_ckpt would otherwise be unresolvable. + ckpt_dir = Path(self.save_ckpt).parent + ckpt_dir.mkdir(parents=True, exist_ok=True) + with open(ckpt_dir / "checkpoint", "w") as fp: + fp.write(f"{Path(self.save_ckpt).name}.jax") def _save_full_validation_checkpoint( self, diff --git a/source/tests/jax/test_checkpoint_pointer.py b/source/tests/jax/test_checkpoint_pointer.py new file mode 100644 index 0000000000..8f2faf44b6 --- /dev/null +++ b/source/tests/jax/test_checkpoint_pointer.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Test the JAX trainer writes its checkpoint pointer beside save_ckpt. + +JAX training writes checkpoint directories and the stable ``.jax`` link relative +to ``save_ckpt`` (which may include a directory), but used to always write the +``checkpoint`` pointer file to the current working directory with a value that +still carried the directory prefix. The freeze entrypoint expects the pointer +inside the folder it is given and resolves its value relative to that folder, so +a directory-valued ``save_ckpt`` broke freeze/restart tooling. +""" + +import os +import tempfile +import unittest +from pathlib import ( + Path, +) +from unittest import ( + mock, +) + +from deepmd.jax.train.trainer import ( + DPTrainer, +) + + +class TestCheckpointPointer(unittest.TestCase): + def setUp(self) -> None: + self.tmpdir = tempfile.TemporaryDirectory() + self.cwd = os.getcwd() + os.chdir(self.tmpdir.name) + + def tearDown(self) -> None: + os.chdir(self.cwd) + self.tmpdir.cleanup() + + def _save(self, save_ckpt: str) -> None: + trainer = DPTrainer.__new__(DPTrainer) + trainer.save_ckpt = save_ckpt + with ( + mock.patch.object(DPTrainer, "_write_checkpoint"), + mock.patch.object(DPTrainer, "_cleanup_old_checkpoints"), + mock.patch("deepmd.jax.train.trainer._link_checkpoint"), + ): + trainer._save_checkpoint(1) + + def test_pointer_beside_ckpt_for_subdir(self) -> None: + # save_ckpt with a directory: pointer must land in that directory with a + # value relative to it (basename only), so freeze(subdir) resolves it. + self._save("subdir/model.ckpt") + pointer = Path("subdir") / "checkpoint" + self.assertTrue(pointer.is_file()) + self.assertEqual(pointer.read_text(), "model.ckpt.jax") + self.assertFalse(Path("checkpoint").exists()) + + def test_pointer_in_cwd_for_bare_name(self) -> None: + # default/bare save_ckpt: pointer stays in the CWD (parent is "."). + self._save("model.ckpt") + pointer = Path("checkpoint") + self.assertTrue(pointer.is_file()) + self.assertEqual(pointer.read_text(), "model.ckpt.jax") + + +if __name__ == "__main__": + unittest.main()