fix: make global DOS inference work on the dpmodel and JAX backends #5722
wanghan-iapcm wants to merge 5 commits into
DeepDOS.eval always read the atomic "dos" output and summed it to obtain the global DOS, even for atomic=False. The dpmodel and JAX backends only return the atomic OUT variables when atomic=True; for atomic=False they return the reduced "dos_redu" instead, so reading results["dos"] raised KeyError on a global-DOS-only path (e.g. dp test without atomic DOS labels). Read "dos_redu" for the global DOS and only read/sum the atomic "dos" when atomic=True. This also removes a stale "not same as dos_redu" comment: the reduced output is the sum of the atomic DOS by construction.

Fixing that exposed a second, compounding bug: dpmodel DeepEval.get_numb_dos hard-returned 0, and the dpmodel DOS model did not expose get_numb_dos at all, so the reshape target was (nframes, 0) and dpmodel DOS inference was broken for both the atomic and non-atomic paths. Add get_numb_dos to the dpmodel DOSModel (mirroring the PyTorch model) and delegate to it from the dpmodel DeepEval.

The JAX DeepEval also hard-returns 0, but its evaluator wraps an HLO object with no live model, so numb_dos must be persisted through the StableHLO export to fix it. That is a separate JAX-export change and is left as a follow-up; the JAX get_numb_dos is unchanged here.

Adds source/tests/common/dpmodel/test_deep_dos.py, which builds a dpmodel DOS model and evaluates it: atomic=False returns the global DOS (KeyError on master) and the global DOS equals the sum of the atomic DOS (guarding the dos_redu == sum(dos) invariant for all backends). dpmodel/JAX DOS inference had no test.

Fix deepmodeling#5674
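A minimal NumPy sketch of why a hard-coded `get_numb_dos() == 0` breaks the reshape described above. The helper name `reshape_global_dos` and the toy sizes are assumptions made for illustration; this is not the deepmd-kit code.

```python
import numpy as np

nframes, numb_dos = 2, 3
# Stand-in for a backend's flat reduced DOS output: nframes * numb_dos values.
dos_redu = np.arange(nframes * numb_dos, dtype=float)


def reshape_global_dos(flat, nframes, numb_dos):
    """Reshape a flat reduced-DOS buffer to (nframes, numb_dos)."""
    return flat.reshape(nframes, numb_dos)


# With the correct DOS count the reshape succeeds.
good = reshape_global_dos(dos_redu, nframes, numb_dos)

# With a count of 0 the target shape is (nframes, 0), which cannot hold
# the nframes * numb_dos values, so NumPy raises ValueError.
try:
    reshape_global_dos(dos_redu, nframes, 0)
    failed = False
except ValueError:
    failed = True
```

The same failure occurs regardless of whether the atomic or the reduced output is being reshaped, which is consistent with the message's statement that both paths were broken.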
Follows up the dpmodel DOS fix: the JAX evaluator wraps a deserialized HLO object with no live model, so DeepDOS.eval could not obtain numb_dos and JAX DeepEval.get_numb_dos hard-returned 0, breaking the reduced-DOS reshape.

Persist numb_dos into the HLO export constants and expose HLO.get_numb_dos, add a default get_numb_dos (0) on the dpmodel base model so non-DOS models still export, register the "dos" output in the HLO OUTPUT_DEFS table, and delegate JAX DeepEval.get_numb_dos to the model. With these, JAX DOS inference works end to end.

Adds source/tests/jax/test_deep_dos.py: exports a DOS model to .hlo, checks numb_dos survives the round trip, and evaluates global DOS (atomic=False).

Fix deepmodeling#5674
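The idea of persisting a scalar through an export so that a wrapper with no live model can still report it can be sketched in a few lines. Everything here is hypothetical: `ToyExportedModel`, `export_constants`, and `load_exported` are illustrative names, and JSON stands in for the real StableHLO export format, which this sketch does not attempt to reproduce.

```python
import json


class ToyExportedModel:
    """Hypothetical stand-in for a deserialized export with no live model."""

    def __init__(self, constants):
        self._constants = constants

    def get_numb_dos(self):
        # Fall back to 0 when the constant is absent, as for a non-DOS model.
        return self._constants.get("numb_dos", 0)


def export_constants(numb_dos):
    """Serialize the constants that must survive the export."""
    return json.dumps({"numb_dos": numb_dos})


def load_exported(blob):
    """Rebuild the wrapper from the serialized constants alone."""
    return ToyExportedModel(json.loads(blob))


restored = load_exported(export_constants(2))
legacy = load_exported(json.dumps({}))
```

Here `restored.get_numb_dos()` returns the persisted value, while `legacy` (an export without the constant) reports 0, mirroring the default described for non-DOS models.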
Walkthrough
DeepDOS now uses reduced DOS output for non-atomic inference, while DOS count is propagated through dpmodel, JAX, HLO export, TensorFlow reduction, and new tests.
Changes: DeepDOS DOS output flow
Estimated code review effort: 3 (Moderate) | ~30 minutes
Pre-merge checks: 5 passed
Nitpick comments (1)

source/tests/jax/test_deep_dos.py (1)

53-81: Maintainability & Code Quality | Trivial | Quick win

Consider adding an atomic-sum consistency check.

The companion dpmodel test (`source/tests/common/dpmodel/test_deep_dos.py`) has `test_global_matches_atomic_sum`, verifying the reduced global DOS equals the sum of per-atom DOS. This JAX test only checks shape and export survival, not that the reduced/atomic outputs stay consistent through HLO export.

Suggested additional test:

```diff
     def test_global_dos_only(self) -> None:
         (dos,) = self.dp.eval(self.coords, self.cells, self.atypes, atomic=False)
         self.assertEqual(dos.shape, (1, 2))
+
+    def test_global_matches_atomic_sum(self) -> None:
+        (dos,) = self.dp.eval(self.coords, self.cells, self.atypes, atomic=False)
+        _, atomic_dos = self.dp.eval(self.coords, self.cells, self.atypes, atomic=True)
+        np.testing.assert_allclose(dos, np.sum(atomic_dos, axis=1))
```

Prompt for AI agents:
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/jax/test_deep_dos.py` around lines 53 - 81, Add an atomic-sum consistency test to TestDeepDOSJAX so the exported HLO path is verified beyond shape/export survival. Use the existing DeepDOS setup in setUp and compare the global result from DeepDOS.eval(..., atomic=False) against the sum of the atomic per-atom DOS returned by the atomic=True path, similar to test_global_matches_atomic_sum in the dpmodel DeepDOS tests. Keep the assertions in the same test class so this JAX-specific export path is checked for numerical consistency.
Files selected for processing (9)
- deepmd/dpmodel/infer/deep_eval.py
- deepmd/dpmodel/model/dos_model.py
- deepmd/dpmodel/model/make_model.py
- deepmd/infer/deep_dos.py
- deepmd/jax/infer/deep_eval.py
- deepmd/jax/model/hlo.py
- deepmd/jax/utils/serialization.py
- source/tests/common/dpmodel/test_deep_dos.py
- source/tests/jax/test_deep_dos.py
Codecov Report
All modified and coverable lines are covered by tests.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #5722      +/-   ##
==========================================
- Coverage   81.26%   81.14%   -0.12%
==========================================
  Files         988      988
  Lines      110877   110894     +17
  Branches     4234     4236      +2
==========================================
- Hits        90103    89987    -116
- Misses      19249    19382    +133
  Partials     1525     1525
```
The previous change made DeepDOS.eval always use the reduced dos_redu for the global DOS. That regressed the TensorFlow (and PyTorch) backends, whose reduced output is not the plain sum of the atomic DOS, so test_deepdos.py's global DOS no longer matched the atomic sum. Prefer summing the atomic dos output when it is present, which restores the original global DOS for backends that always return it, and fall back to dos_redu only when the atomic dos is absent (the dpmodel and JAX atomic=False path that raised KeyError). This keeps the deepmodeling#5674 fix for dpmodel/JAX while leaving TF/PT behavior unchanged. Fix deepmodeling#5674
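The selection rule described in this commit message can be sketched as a standalone function. This is an illustration under assumptions, not the body of `DeepDOS.eval`: the name `select_global_dos`, its signature, and the choice to raise `KeyError` when atomic output is requested but absent are all hypothetical.

```python
import numpy as np


def select_global_dos(results, nframes, natoms, numb_dos, atomic=False):
    """Prefer the summed atomic DOS; fall back to the reduced output."""
    if "dos" in results:
        # Backends that always return the atomic output (TF, PT in the text).
        atomic_dos = results["dos"].reshape(nframes, natoms, numb_dos)
        dos = atomic_dos.sum(axis=1)
        return (dos, atomic_dos) if atomic else (dos,)
    if atomic:
        # Assumption: atomic output was requested but the backend omitted it.
        raise KeyError("dos")
    # Backends that omit the atomic output when atomic=False (dpmodel, JAX).
    return (results["dos_redu"].reshape(nframes, numb_dos),)


# Atomic output present: 2 frames, 3 atoms, 2 DOS values per atom.
with_atomic = {"dos": np.arange(12.0)}
dos_a, atomic_a = select_global_dos(with_atomic, 2, 3, 2, atomic=True)

# Only the reduced output present: the fallback path.
reduced_only = {"dos_redu": np.ones(4)}
(dos_b,) = select_global_dos(reduced_only, 2, 3, 2)
```

Checking for the key rather than branching on the backend keeps the caller backend-agnostic, which is the design the message argues for.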
The TensorFlow DOS model built the global DOS by reshaping the atomic DOS to [nloc, -1] and summing over axis 0. For a single frame this equals the sum of the atomic DOS, but for multiple frames the reshape mixes frames together, so the global DOS (o_dos / dos_redu) no longer equaled the per-frame sum of the atomic DOS. dpmodel and the PyTorch backend already reduce per frame correctly.

Reshape the atomic DOS to [nframes, nloc, numb_dos] and reduce over the atom axis, mirroring the energy model. o_dos is now [nframes, numb_dos] and o_atom_dos [nframes, natoms, numb_dos]; the backend reshapes these to the same final shapes, and only freshly built models are affected (frozen models are unchanged).

Adds source/tests/tf/test_model_dos_multiframe.py, which builds the DOS model graph, feeds two frames, and asserts the global DOS equals the per-frame sum of the atomic DOS (fails before the fix). Updates test_model_dos.py's single-frame assertions for the new leading frame dimension.

Fix deepmodeling#5674
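The two reductions described above can be compared in NumPy rather than TensorFlow. The sketch assumes the flat atomic DOS is laid out frame-major (frame, then atom, then DOS channel); the helper names and toy sizes are illustrative and not taken from `deepmd/tf/model/dos.py`.

```python
import numpy as np

nloc, numb_dos = 3, 2


def old_reduction(flat, nloc):
    """Reshape to [nloc, -1] and sum over axis 0 (the pre-fix behaviour)."""
    return flat.reshape(nloc, -1).sum(axis=0)


def per_frame_reduction(flat, nloc, numb_dos):
    """Reshape to [nframes, nloc, numb_dos] and sum over the atom axis."""
    return flat.reshape(-1, nloc, numb_dos).sum(axis=1)


# Single frame: rows of the [nloc, -1] view are atoms, so both agree.
one_frame = np.arange(1 * nloc * numb_dos, dtype=float)
old_1 = old_reduction(one_frame, nloc).reshape(1, numb_dos)
new_1 = per_frame_reduction(one_frame, nloc, numb_dos)

# Two frames: rows of the [nloc, -1] view now straddle frame boundaries,
# so the axis-0 sum mixes atoms from different frames.
two_frames = np.arange(2 * nloc * numb_dos, dtype=float)
old_2 = old_reduction(two_frames, nloc).reshape(2, numb_dos)
new_2 = per_frame_reduction(two_frames, nloc, numb_dos)
```

With these inputs the single-frame results coincide while the two-frame results differ, which matches the message's explanation of why single-frame tests did not catch the bug.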
Move the multi-frame global-vs-atomic-sum DOS test into TestModel in test_model_dos.py as a second method and remove the standalone test_model_dos_multiframe.py, so both DOS model graph tests share the same gen_data/del_data fixture and live together.
```python
# Prefer summing the atomic `dos` output when it is present, preserving
# the original global DOS for backends that always return it (TF, PT),
# whose reduced output is not necessarily the plain sum of the atomic
# DOS. The dpmodel and JAX backends omit the atomic `dos` when
# atomic=False, so fall back to the reduced `dos_redu` there (reading
# `dos` unconditionally would raise KeyError).
if "dos" in results:
    atomic_energy = results["dos"].reshape(nframes, natoms, self.get_numb_dos())
    energy = np.sum(atomic_energy, axis=1)
    if atomic:
        return (energy, atomic_energy)
```
We may need to review the implementation of dpmodel and fix it, instead of adding a workaround here.
Problem

Fixes #5674. Three related defects in global DOS inference, uncovered in sequence:

1. `DeepDOS.eval` unconditionally read the atomic `dos` output and summed it, even for `atomic=False`. The dpmodel and JAX backends only return the atomic output when `atomic=True`; on the global-DOS-only path (e.g. `dp test` without atomic DOS labels) `results["dos"]` raised `KeyError`. TF and PyTorch masked this because they always include the atomic output.
2. Fixing the `KeyError` exposed that dpmodel and JAX DOS inference was broken more deeply: both `DeepEval.get_numb_dos` implementations hard-returned `0`, so the DOS reshape target was `(nframes, 0)` and inference failed on every path, not just the missing-key case.
3. On TF, the global DOS did not equal the sum of the atomic DOS for multi-frame inputs, even though, by construction of the model, it must. `deepmd/tf/model/dos.py` reduced the atomic DOS with `reshape([natoms[0], -1])` + `reduce_sum(axis=0)`, which sums across the wrong axis and mixes atoms from different frames together. Single-frame inputs happened to give the right answer, so no test caught it.

Fix

- Backend-agnostic (`deep_dos.py`): prefer the atomic `dos` output and sum it whenever the backend returns it (this is the exact global DOS on TF/PT, whose reduced output is not necessarily the plain sum), and fall back to the reduced `dos_redu` only when the atomic output is absent (dpmodel/JAX at `atomic=False`). Reading `dos` unconditionally is what raised the original `KeyError`.
- dpmodel: add `get_numb_dos` to the dpmodel `DOSModel` (mirroring the PyTorch model), add a default `get_numb_dos` returning 0 on the shared base model so non-DOS models can still be serialized, and delegate `dpmodel/infer/deep_eval.py:get_numb_dos` to the model.
- JAX: the evaluator wraps a deserialized `HLO` object with no live model, so `numb_dos` is now persisted into the StableHLO export constants and exposed via `HLO.get_numb_dos`; the `dos` output is registered in the HLO `OUTPUT_DEFS` table; and `jax/infer/deep_eval.py:get_numb_dos` delegates to the model. With these, JAX DOS inference works end to end.
- TF: reduce the atomic DOS per frame with `reshape([-1, natoms[0], numb_dos])` + `reduce_sum(axis=1)`, mirroring the energy model, so the global DOS equals the atomic sum for multi-frame inputs.

Test

- `source/tests/common/dpmodel/test_deep_dos.py`: builds a dpmodel DOS model and evaluates it. `atomic=False` returns the global DOS (`KeyError` on master), and the global DOS equals the sum of the atomic DOS (guarding the `dos_redu == sum(dos)` invariant relied on by all backends).
- `source/tests/jax/test_deep_dos.py`: exports a DOS model to `.hlo`, checks `numb_dos` survives the round trip, and evaluates the global DOS.
- `source/tests/tf/test_model_dos.py`: adds `test_multiframe_global_equals_atomic_sum`, which builds a two-frame DOS graph and asserts the global DOS equals the per-frame atomic sum. This fails on the old axis-0 reduction and passes after the per-frame fix. The existing single-frame assertions were updated to the corrected output shapes.

dpmodel and JAX DOS inference previously had no test, and the TF path had only single-frame coverage; DOS was effectively exercised only where the bugs were masked.