Skip to content

Fix optimizer momentum reset on checkpoint resume#115

Open
amazloumi wants to merge 1 commit into
mainfrom
fix/resume-optimizer-determinism
Open

Fix optimizer momentum reset on checkpoint resume#115
amazloumi wants to merge 1 commit into
mainfrom
fix/resume-optimizer-determinism

Conversation

@amazloumi
Copy link
Copy Markdown
Member

Summary

  • Bug: CheckpointManager round-tripped optimizer state through raw optimizer.state_dict() / optimizer.load_state_dict(). On resume the optimizer is freshly constructed, so its state_dict() has no exp_avg / exp_avg_sq tensors yet (AdamW creates per-parameter state lazily on the first .step()). dcp.load therefore had no moment tensors to fill, the saved moments were silently dropped, and AdamW momentum reset to zero at every resume. Model weights, scheduler, dataloader position, and RNG all restored correctly — only the optimizer moments were lost, so resumed runs were not bit-exact.
    DCP's get_model_state_dict / get_optimizer_state_dict / set_model_state_dict / set_optimizer_state_dict. The getters build a load template with the moment tensors allocated in the correct FSDP/DTensor layout, so dcp.load repopulates them; the setters write them back into the live optimizer. The exclude_keys fine-tuning path (load model, skip optimizer) is preserved.
  • Tests (fail on the pre-fix code, pass after) — single-GPU and distributed, each with a direct moment check and an end-to-end resume check:
  • tests/integration/test_checkpoint_roundtrip.py::TestCheckpointRoundtrip::test_manager_restores_optimizer_moments_single_gpu
  • tests/distributed/test_checkpoint.py::TestCheckpointRoundTrip::test_resume_restores_optimizer_moments
  • tests/e2e/test_training_e2e.py::test_resume_determinism_single_gpu
  • tests/e2e/test_training_e2e.py::test_resume_determinism_2gpu_fsdp
  • The e2e tests use a small learnable on-disk dataset on purpose: random tokens give a near-flat loss that masks the momentum reset (the test would pass even on the buggy code).
  • Docs: added a ### Fixed entry to CHANGELOG.md; corrected docs/checkpointing/dcp-model.md, which documented the old raw-state_dict() save/load and explained the empty-template behavior as if it were correct.
  • Compatibility note: optimizer state is now keyed by parameter fully-qualified name rather than positional index. Checkpoints written before this fix will not restore optimizer state on resume (training continues with a fresh optimizer); model state is unaffected.

Testing

  • uv run ruff check kempnerforge/ tests/ passes
  • uv run ruff format --check kempnerforge/ tests/ scripts/ passes
  • uv run pyright kempnerforge/ passes (0 errors)
  • uv run pytest tests/unit/ -v --timeout=60 passes (unit suite unaffected)
  • uv run torchrun --nproc_per_node=4 -m pytest tests/distributed/test_checkpoint.py -v — new moment test passes; existing 8 checkpoint tests still pass
  • uv run pytest tests/e2e/test_training_e2e.py -k resume_determinism --e2e -v — single-GPU and 2-GPU FSDP resume bit-exact
  • Confirmed fail-on-bug: all four new tests fail on main (pre-fix), pass on this branch
  • Tested on 1 and 2 GPUs (single-GPU + FSDP-2 resume paths)

Closes #114

@amazloumi amazloumi requested review from Naeemkh and mmshad May 27, 2026 16:24
@codecov
Copy link
Copy Markdown

codecov Bot commented May 27, 2026

Codecov Report

❌ Patch coverage is 81.81818% with 2 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
kempnerforge/checkpoint/manager.py 81.81% 0 Missing and 2 partials ⚠️
Files with missing lines Coverage Δ
kempnerforge/checkpoint/manager.py 89.65% <81.81%> (+0.13%) ⬆️

... and 1 file with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Resume is not bit-exact: AdamW optimizer state not restored to full precision

1 participant