fix: make global DOS inference work on the dpmodel and JAX backends #5722

Open
wanghan-iapcm wants to merge 5 commits into deepmodeling:master from wanghan-iapcm:fix-deepdos-redu

@wanghan-iapcm wanghan-iapcm commented Jul 3, 2026

Collaborator

Problem

Fixes #5674. This PR addresses three related defects in global DOS inference, uncovered in sequence:

  1. DeepDOS.eval unconditionally read the atomic dos output and summed it, even for atomic=False. The dpmodel and JAX backends only return the atomic output when atomic=True; on the global-DOS-only path (e.g. dp test without atomic DOS labels) results["dos"] raised KeyError. TF and PyTorch masked this because they always include the atomic output.

  2. Fixing the KeyError exposed that dpmodel and JAX DOS inference was broken more deeply: both DeepEval.get_numb_dos implementations hard-returned 0, so the DOS reshape target was (nframes, 0) and inference failed on every path, not just the missing-key case.

  3. On TF, the global DOS did not equal the sum of the atomic DOS for multi-frame inputs, even though it must by construction of the model. deepmd/tf/model/dos.py reduced the atomic DOS with reshape([natoms[0], -1]) + reduce_sum(axis=0). With more than one frame, that reshape produces natoms[0] rows of length nframes * numb_dos, so the rows no longer correspond to atoms and the sum mixes atoms from different frames. Single-frame inputs happened to give the right answer, so no test caught it.
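
The second defect can be reproduced in isolation with NumPy. The sketch below uses made-up sizes and a hypothetical helper, reshape_global_dos; it calls no DeePMD-kit code and only shows why a reshape target of (nframes, 0) cannot hold a non-empty DOS array.

```python
import numpy as np

# Illustrative sizes only; they are not taken from the PR.
nframes, numb_dos = 2, 4

# Stand-in for a reduced DOS array as a backend might return it.
dos_redu = np.ones((nframes, numb_dos))


def reshape_global_dos(values, nframes, numb_dos):
    """Hypothetical helper: reshape a DOS array to (nframes, numb_dos)."""
    return values.reshape(nframes, numb_dos)


# With the true channel count the reshape succeeds.
ok = reshape_global_dos(dos_redu, nframes, numb_dos)

# With a channel count of 0 (what the old get_numb_dos reported) it cannot
# succeed: 8 elements do not fit into an array of shape (2, 0).
try:
    reshape_global_dos(dos_redu, nframes, 0)
    failed = False
except ValueError:
    failed = True
```

Because the failure comes from the reshape itself, it occurs regardless of whether the atomic or the reduced output is read, which matches the description of inference failing on every path.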

Fix

Backend-agnostic (deep_dos.py): prefer the atomic dos output and sum it whenever the backend returns it (this is the exact global DOS on TF/PT, whose reduced output is not necessarily the plain sum), and fall back to the reduced dos_redu only when the atomic output is absent (dpmodel/JAX at atomic=False). Reading dos unconditionally is what raised the original KeyError.
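
The selection rule described above can be sketched as a standalone function. This is a minimal illustration, not the code in deep_dos.py: the function name select_global_dos, its signature, and the KeyError raised when atomic output is requested but missing are all assumptions made for the example.

```python
import numpy as np


def select_global_dos(results, nframes, natoms, numb_dos, atomic=False):
    """Hypothetical helper returning (global,) or (global, atomic) DOS."""
    if "dos" in results:
        # Atomic output is available: the global DOS is its sum over atoms.
        atomic_dos = results["dos"].reshape(nframes, natoms, numb_dos)
        global_dos = atomic_dos.sum(axis=1)
        if atomic:
            return global_dos, atomic_dos
        return (global_dos,)
    if atomic:
        # Assumption for this sketch: fail loudly rather than guess.
        raise KeyError("atomic DOS requested but no 'dos' output is present")
    # Only the reduced output exists (the dpmodel/JAX atomic=False case).
    return (results["dos_redu"].reshape(nframes, numb_dos),)


# Illustrative data: one backend returns atomic output, one only the reduced.
nframes, natoms, numb_dos = 2, 3, 4
flat_atomic = np.arange(nframes * natoms * numb_dos, dtype=float)
with_atomic = {"dos": flat_atomic}
reduced_only = {
    "dos_redu": flat_atomic.reshape(nframes, natoms, numb_dos).sum(axis=1)
}

(from_atomic,) = select_global_dos(with_atomic, nframes, natoms, numb_dos)
(from_reduced,) = select_global_dos(reduced_only, nframes, natoms, numb_dos)
```

In this toy setup both branches agree because dos_redu was built as the atomic sum; the fallback only matters when a backend omits the atomic output.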

dpmodel: add get_numb_dos to the dpmodel DOSModel (mirroring the PyTorch model), add a default get_numb_dos returning 0 on the shared base model so non-DOS models can still be serialized, and delegate dpmodel/infer/deep_eval.py:get_numb_dos to the model.
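
The delegation pattern can be illustrated with three small stand-in classes. The class names below (SketchBaseModel, SketchDOSModel, SketchDeepEval) are hypothetical and do not reproduce the real dpmodel classes; only the method name get_numb_dos comes from the PR.

```python
class SketchBaseModel:
    """Hypothetical stand-in for the shared base model."""

    def get_numb_dos(self) -> int:
        # Default for models without a DOS output, so generic code that
        # asks every model for this value still works.
        return 0


class SketchDOSModel(SketchBaseModel):
    """Hypothetical stand-in for a DOS model with a known channel count."""

    def __init__(self, numb_dos: int) -> None:
        self.numb_dos = numb_dos

    def get_numb_dos(self) -> int:
        return self.numb_dos


class SketchDeepEval:
    """Hypothetical evaluator that delegates instead of hard-coding 0."""

    def __init__(self, model: SketchBaseModel) -> None:
        self.model = model

    def get_numb_dos(self) -> int:
        return self.model.get_numb_dos()


dos_eval = SketchDeepEval(SketchDOSModel(numb_dos=250))
plain_eval = SketchDeepEval(SketchBaseModel())
```

Putting the default on the base class keeps the evaluator free of type checks: it asks the model and gets 0 for anything that is not a DOS model.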

JAX: the evaluator wraps a deserialized HLO object with no live model, so numb_dos is now persisted into the StableHLO export constants and exposed via HLO.get_numb_dos; the dos output is registered in the HLO OUTPUT_DEFS table; and jax/infer/deep_eval.py:get_numb_dos delegates to the model. With these, JAX DOS inference works end to end.
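
The idea of persisting numb_dos alongside the exported graph can be sketched with a plain dictionary and a JSON round trip. Everything here is an assumption for illustration: the real export uses StableHLO and DeePMD-kit's own serialization layout, which this sketch does not reproduce, and the names export_constants, SketchLiveModel, and SketchFrozenModel are hypothetical.

```python
import json


class SketchLiveModel:
    """Hypothetical live model that knows its DOS channel count."""

    def get_numb_dos(self) -> int:
        return 250


def export_constants(model) -> dict:
    """Collect scalar metadata that must survive without the live model."""
    return {"numb_dos": model.get_numb_dos()}


class SketchFrozenModel:
    """Hypothetical deserialized wrapper holding only stored constants."""

    def __init__(self, constants: dict) -> None:
        self._constants = constants

    def get_numb_dos(self) -> int:
        # Read the persisted value instead of returning a hard-coded 0.
        return int(self._constants["numb_dos"])


# Round trip: export, serialize, deserialize, and query the frozen wrapper.
serialized = json.dumps(export_constants(SketchLiveModel()))
frozen = SketchFrozenModel(json.loads(serialized))
```

Once the value is stored with the export, the evaluator can delegate to the deserialized object exactly as it would to a live model.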

TF: reduce the atomic DOS per frame — reshape([-1, natoms[0], numb_dos]) + reduce_sum(axis=1), mirroring the energy model — so the global DOS equals the atomic sum for multi-frame inputs.
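
The difference between the two reductions can be checked with NumPy, whose reshape uses the same row-major ordering as tf.reshape. The sizes and random data below are illustrative assumptions, and the arrays stand in for the TF tensors rather than reproducing deepmd/tf/model/dos.py.

```python
import numpy as np

# Illustrative sizes; natoms plays the role of natoms[0] in the TF code.
nframes, natoms, numb_dos = 2, 3, 4
rng = np.random.default_rng(0)
atomic_dos = rng.random((nframes, natoms, numb_dos))
flat = atomic_dos.reshape(-1)

# Reference: sum over the atom axis, separately for each frame.
expected = atomic_dos.sum(axis=1)

# Old reduction: reshape to [natoms, -1] and sum over axis 0. With two
# frames each row has nframes * numb_dos entries, so rows are not atoms.
old = flat.reshape(natoms, -1).sum(axis=0).reshape(nframes, numb_dos)

# New reduction: keep the frame axis and sum over the atom axis.
new = flat.reshape(-1, natoms, numb_dos).sum(axis=1)

# Single-frame case, where the old reduction happens to be correct.
single = atomic_dos[:1].reshape(-1)
old_single = single.reshape(natoms, -1).sum(axis=0)
```

With these inputs, new matches expected, old does not, and old_single matches the first frame's atomic sum, which is why single-frame tests did not expose the problem.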

Test

  • source/tests/common/dpmodel/test_deep_dos.py: builds a dpmodel DOS model and evaluates it — atomic=False returns the global DOS (KeyError on master), and the global DOS equals the sum of the atomic DOS (guarding the dos_redu == sum(dos) invariant relied on by all backends).
  • source/tests/jax/test_deep_dos.py: exports a DOS model to .hlo, checks numb_dos survives the round trip, and evaluates the global DOS.
  • source/tests/tf/test_model_dos.py: adds test_multiframe_global_equals_atomic_sum, which builds a two-frame DOS graph and asserts the global DOS equals the per-frame atomic sum — this fails on the old axis-0 reduction and passes after the per-frame fix. The existing single-frame assertions were updated to the corrected output shapes.

dpmodel and JAX DOS inference previously had no test, and the TF path had only single-frame coverage; DOS was effectively exercised only where the bugs were masked.

Han Wang added 2 commits July 3, 2026 15:39
DeepDOS.eval always read the atomic "dos" output and summed it to obtain the
global DOS, even for atomic=False. The dpmodel and JAX backends only return the
atomic OUT variables when atomic=True; for atomic=False they return the reduced
"dos_redu" instead, so reading results["dos"] raised KeyError on a global-DOS-
only path (e.g. dp test without atomic DOS labels).

Read "dos_redu" for the global DOS and only read/sum the atomic "dos" when
atomic=True. This also removes a stale "not same as dos_redu" comment: the
reduced output is the sum of the atomic DOS by construction.

Fixing that exposed a second, compounding bug: dpmodel DeepEval.get_numb_dos
hard-returned 0, and the dpmodel DOS model did not expose get_numb_dos at all,
so the reshape target was (nframes, 0) and dpmodel DOS inference was broken for
both the atomic and non-atomic paths. Add get_numb_dos to the dpmodel DOSModel
(mirroring the PyTorch model) and delegate to it from the dpmodel DeepEval.

The JAX DeepEval also hard-returns 0, but its evaluator wraps an HLO object with
no live model, so numb_dos must be persisted through the StableHLO export to fix
it. That is a separate JAX-export change and is left as a follow-up; the JAX
get_numb_dos is unchanged here.

Adds source/tests/common/dpmodel/test_deep_dos.py, which builds a dpmodel DOS
model and evaluates it: atomic=False returns the global DOS (KeyError on master)
and the global DOS equals the sum of the atomic DOS (guarding the dos_redu ==
sum(dos) invariant for all backends). dpmodel/JAX DOS inference had no test.

Fix deepmodeling#5674

Follows up the dpmodel DOS fix: the JAX evaluator wraps a deserialized HLO
object with no live model, so DeepDOS.eval could not obtain numb_dos and JAX
DeepEval.get_numb_dos hard-returned 0, breaking the reduced-DOS reshape.

Persist numb_dos into the HLO export constants and expose HLO.get_numb_dos, add
a default get_numb_dos (0) on the dpmodel base model so non-DOS models still
export, register the "dos" output in the HLO OUTPUT_DEFS table, and delegate
JAX DeepEval.get_numb_dos to the model. With these, JAX DOS inference works
end to end.

Adds source/tests/jax/test_deep_dos.py: exports a DOS model to .hlo, checks
numb_dos survives the round trip, and evaluates global DOS (atomic=False).

Fix deepmodeling#5674

@dosubot dosubot Bot added the bug label Jul 3, 2026
@github-actions github-actions Bot added the Python label Jul 3, 2026
@wanghan-iapcm wanghan-iapcm requested a review from njzjz July 3, 2026 07:48
coderabbitai Bot commented Jul 3, 2026

Contributor

Walkthrough

DeepDOS now uses reduced DOS output for non-atomic inference, while DOS count is propagated through dpmodel, JAX, HLO export, TensorFlow reduction, and new tests.

Changes

DeepDOS DOS output flow

| Layer / File(s) | Summary |
| --- | --- |
| Reduced DOS inference: deepmd/infer/deep_dos.py | eval branches between atomic dos and reduced dos_redu when producing global DOS values. |
| DOS count accessors: deepmd/dpmodel/model/dos_model.py, deepmd/dpmodel/model/make_model.py, deepmd/dpmodel/infer/deep_eval.py, deepmd/jax/infer/deep_eval.py | DOSModel, CM, and both DeepEval implementations expose get_numb_dos(), and make_model.py adds the base method plus docstring formatting changes. |
| HLO DOS export: deepmd/jax/model/hlo.py, deepmd/jax/utils/serialization.py | HLO adds a dos output, stores numb_dos, exposes get_numb_dos(), and exports numb_dos in serialized constants. |
| TensorFlow DOS reduction and tests: deepmd/tf/model/dos.py, source/tests/tf/test_model_dos.py, source/tests/tf/test_model_dos_multiframe.py | TensorFlow DOS reduction preserves frame boundaries, and tests reshape outputs for single-frame and multi-frame comparisons. |
| dpmodel and JAX DeepDOS tests: source/tests/common/dpmodel/test_deep_dos.py, source/tests/jax/test_deep_dos.py | New tests cover reduced-output inference, DOS shape checks, atomic/global consistency, and DOS count preservation through export. |

Estimated code review effort: 3 (Moderate) | ~30 minutes

Suggested reviewers: njzjz, iProzd

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Linked Issues check | ✅ Passed | The changes use dos_redu for non-atomic inference and preserve atomic DOS, matching #5674's requirements. |
| Out of Scope Changes check | ✅ Passed | The TensorFlow and test updates are supporting DOS behavior changes, not unrelated scope creep. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly summarizes the main change: fixing global DOS inference on the dpmodel and JAX backends. |

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
source/tests/jax/test_deep_dos.py (1)

53-81: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Consider adding an atomic-sum consistency check.

The companion dpmodel test (source/tests/common/dpmodel/test_deep_dos.py) has test_global_matches_atomic_sum, verifying the reduced global DOS equals the sum of per-atom DOS. This JAX test only checks shape and export survival, not that the reduced/atomic outputs stay consistent through HLO export.

♻️ Suggested additional test
     def test_global_dos_only(self) -> None:
         (dos,) = self.dp.eval(self.coords, self.cells, self.atypes, atomic=False)
         self.assertEqual(dos.shape, (1, 2))
+
+    def test_global_matches_atomic_sum(self) -> None:
+        (dos,) = self.dp.eval(self.coords, self.cells, self.atypes, atomic=False)
+        _, atomic_dos = self.dp.eval(self.coords, self.cells, self.atypes, atomic=True)
+        np.testing.assert_allclose(dos, np.sum(atomic_dos, axis=1))

📥 Commits

Reviewing files that changed from the base of the PR and between dd38b35 and a7da1dd.

📒 Files selected for processing (9)
  • deepmd/dpmodel/infer/deep_eval.py
  • deepmd/dpmodel/model/dos_model.py
  • deepmd/dpmodel/model/make_model.py
  • deepmd/infer/deep_dos.py
  • deepmd/jax/infer/deep_eval.py
  • deepmd/jax/model/hlo.py
  • deepmd/jax/utils/serialization.py
  • source/tests/common/dpmodel/test_deep_dos.py
  • source/tests/jax/test_deep_dos.py

codecov Bot commented Jul 3, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 81.14%. Comparing base (dd38b35) to head (d1e8cd7).
⚠️ Report is 3 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5722      +/-   ##
==========================================
- Coverage   81.26%   81.14%   -0.12%     
==========================================
  Files         988      988              
  Lines      110877   110894      +17     
  Branches     4234     4236       +2     
==========================================
- Hits        90103    89987     -116     
- Misses      19249    19382     +133     
  Partials     1525     1525              

Han Wang added 3 commits July 4, 2026 10:23
The previous change made DeepDOS.eval always use the reduced dos_redu for the
global DOS. That regressed the TensorFlow (and PyTorch) backends, whose reduced
output is not the plain sum of the atomic DOS, so test_deepdos.py's global DOS
no longer matched the atomic sum.

Prefer summing the atomic dos output when it is present, which restores the
original global DOS for backends that always return it, and fall back to
dos_redu only when the atomic dos is absent (the dpmodel and JAX atomic=False
path that raised KeyError). This keeps the deepmodeling#5674 fix for dpmodel/JAX while
leaving TF/PT behavior unchanged.

Fix deepmodeling#5674

The TensorFlow DOS model built the global DOS by reshaping the atomic DOS to
[nloc, -1] and summing over axis 0. For a single frame this equals the sum of
the atomic DOS, but for multiple frames the reshape mixes frames together, so
the global DOS (o_dos / dos_redu) no longer equaled the per-frame sum of the
atomic DOS. dpmodel and the PyTorch backend already reduce per frame correctly.

Reshape the atomic DOS to [nframes, nloc, numb_dos] and reduce over the atom
axis, mirroring the energy model. o_dos is now [nframes, numb_dos] and
o_atom_dos [nframes, natoms, numb_dos]; the backend reshapes these to the same
final shapes, and only freshly built models are affected (frozen models are
unchanged).

Adds source/tests/tf/test_model_dos_multiframe.py, which builds the DOS model
graph, feeds two frames, and asserts the global DOS equals the per-frame sum of
the atomic DOS (fails before the fix). Updates test_model_dos.py's single-frame
assertions for the new leading frame dimension.

Fix deepmodeling#5674

Move the multi-frame global-vs-atomic-sum DOS test into TestModel in
test_model_dos.py as a second method and remove the standalone
test_model_dos_multiframe.py, so both DOS model graph tests share the same
gen_data/del_data fixture and live together.

Comment thread deepmd/infer/deep_dos.py
Comment on lines +127 to +137
# Prefer summing the atomic `dos` output when it is present, preserving
# the original global DOS for backends that always return it (TF, PT),
# whose reduced output is not necessarily the plain sum of the atomic
# DOS. The dpmodel and JAX backends omit the atomic `dos` when
# atomic=False, so fall back to the reduced `dos_redu` there (reading
# `dos` unconditionally would raise KeyError).
if "dos" in results:
    atomic_energy = results["dos"].reshape(nframes, natoms, self.get_numb_dos())
    energy = np.sum(atomic_energy, axis=1)
    if atomic:
        return (energy, atomic_energy)

Member

We may need to review the implementation of dpmodel and fix it, instead of adding a workaround here.

Successfully merging this pull request may close these issues.

[Code scan] Use reduced DOS outputs for non-atomic DeepDOS inference
