fix(jax): write the checkpoint pointer beside save_ckpt#5726
fix(jax): write the checkpoint pointer beside save_ckpt#5726wanghan-iapcm wants to merge 1 commit into
Conversation
JAX training writes checkpoint directories and the stable .jax link relative to save_ckpt (which may include a directory), but always wrote the "checkpoint" pointer file to the current working directory with a value that still carried the directory prefix (e.g. "runs/water/model.ckpt.jax"). The freeze entrypoint looks for the pointer inside the folder it is given and resolves the value relative to that folder, so a directory-valued save_ckpt both misplaced the pointer and double-prefixed the resolved path, breaking freeze and restart-style tooling. Write the pointer into Path(save_ckpt).parent and store a value relative to that directory (the basename only). For the default bare save_ckpt (parent == "."), the pointer stays in the CWD with the same value, so existing behavior is unchanged. Adds source/tests/jax/test_checkpoint_pointer.py, which drives _save_checkpoint with the checkpoint I/O mocked: the directory case asserts the pointer lands beside the checkpoint with a basename value and not in the CWD (fails on master), and a bare-name control asserts the pointer stays in the CWD. The trainer's pointer writing previously had no test; the existing freeze test hand-wrote a correct pointer and never exercised it. Fix deepmodeling#5678
|
Warning Review limit reached
Next review available in: 18 minutes Enable usage-based reviews in Billing to review now. Otherwise, wait until the next included review is available. How can I continue?After more reviews become available, a review can be triggered using the To avoid repeated limits, reduce automatic review volume by pausing incremental auto-reviews earlier, using label-based review opt-in, excluding WIP or generated PR titles, or requesting reviews manually when the PR is ready. If your team needs uninterrupted high-volume reviews, an organization admin can enable usage-based reviews. How do review limits work?CodeRabbit enforces per-developer PR review limits for each organization. Most developers receive the normal plan review availability. For paid Pro and Pro+ PR reviews, CodeRabbit uses adaptive limits for sustained high-volume activity. When a developer's recent PR review activity reaches the 95th percentile or higher among CodeRabbit users, additional reviews become available more gradually as earlier reviews age out of the rolling window. Please refer docs for additional details. Review details⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (2)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
|
||
| import os | ||
| import tempfile | ||
| import unittest |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #5726 +/- ##
==========================================
- Coverage 81.26% 80.78% -0.48%
==========================================
Files 988 988
Lines 110877 110887 +10
Branches 4234 4232 -2
==========================================
- Hits 90103 89580 -523
- Misses 19249 19782 +533
Partials 1525 1525 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Problem
Fixes #5678. JAX training writes checkpoint directories and the stable
.jaxlink relative tosave_ckpt(which may include a directory), but always wrote thecheckpointpointer file to the current working directory with a value that still carried the directory prefix, e.g.runs/water/model.ckpt.jax. The freeze entrypoint looks for the pointer inside the folder it is given and resolves the pointer's value relative to that folder (checkpoint_folder / pointer). So forsave_ckpt = runs/water/model.ckpt, the pointer was written to./checkpoint(notruns/water/checkpoint) and, even if relocated, its value would have double-prefixed toruns/water/runs/water/model.ckpt.jax. Passingruns/waterto freeze or restart-style tooling could not find or resolve the checkpoint, even though the matching checkpoint directory and.jaxlink were written there.Fix
Write the pointer into
Path(save_ckpt).parentand store a value relative to that directory (the basename only). For the default baresave_ckpt(parent is.) the pointer stays in the CWD with the same value, so existing behavior is unchanged; only directory-valuedsave_ckptis affected.Test
Adds
source/tests/jax/test_checkpoint_pointer.py, which drives_save_checkpointwith the checkpoint I/O mocked. The directory case asserts the pointer lands beside the checkpoint (subdir/checkpoint) with a basename value (model.ckpt.jax) and not in the CWD — this fails on master — and a bare-name control asserts the pointer stays in the CWD. The trainer's pointer writing previously had no coverage; the existing freeze test hand-wrote a correct pointer and never exercised the writer.