feat(tf2): add training workflow#5735

Open
njzjz wants to merge 12 commits into deepmodeling:master from njzjz:feat/tf2-train-current

@njzjz njzjz commented Jul 5, 2026

Member

Summary

  • add TensorFlow 2 training/freeze/compress entrypoints and trainer support
  • wire multi-task, validation, finetune/pretrain, TensorBoard, checkpointing, and bias adjustment paths
  • keep TF2 training atomic virial disabled and add performance optimizations, including DP_JIT-controlled lower-forward XLA compilation

Validation

  • python -m pytest source/tests/tf2/test_training.py -q
  • ruff check .
  • ruff format --check deepmd/tf2 source/tests/tf2
  • DP_JIT=1 srun --gres=gpu:1 dp --tf2 train input.json --skip-neighbor-stat, run on a temporary copy of examples/water/se_e2_a: the lower forward compiled under XLA; the first step took ~140 s, and steady-state windows ran at ~0.0305 s/step through step 900, before the run timed out
  • DP_JIT=1 srun --gres=gpu:1 dp --tf2 train input.json --skip-neighbor-stat, run on a temporary fixed-selection dpa3 input: the lower forward compiled under XLA; the first step took ~227 s, and steady-state windows ran at ~0.119 s/step through step 500, before the run timed out

Notes

  • DP_JIT is opt-in for training lower forward only; the whole train/eval step is intentionally not XLA compiled because neighbor/outer training logic is too broad and previously hit unsupported XLA ops.
  • dpa3 lower JIT has high first-compile cost and high memory pressure, so benchmark numbers above separate warm steady-state from first-step compile cost.
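
The opt-in behaviour described in these notes could look roughly like the sketch below. It is not the PR's code: the helper names `jit_enabled`, `maybe_jit`, and `time_steps`, and the set of truthy values accepted for `DP_JIT`, are assumptions made for illustration. `tf.function(..., jit_compile=True)` is the standard TensorFlow way to request XLA compilation, and TensorFlow is imported only when the flag is on.

```python
import os
import time
from typing import Callable, Tuple


def jit_enabled(env=None) -> bool:
    """Return True when DP_JIT holds a truthy value (assumed parsing rule)."""
    env = os.environ if env is None else env
    return env.get("DP_JIT", "0").strip().lower() in {"1", "true", "yes", "on"}


def maybe_jit(lower_forward: Callable, env=None) -> Callable:
    """Wrap only the lower forward in XLA; return it unchanged when opted out."""
    if not jit_enabled(env):
        return lower_forward
    import tensorflow as tf  # lazy import: the opt-out path needs no TensorFlow

    return tf.function(lower_forward, jit_compile=True)


def time_steps(step: Callable[[], None], nsteps: int) -> Tuple[float, float]:
    """Return (first_step_seconds, mean_seconds_of_the_remaining_steps)."""
    if nsteps < 2:
        raise ValueError("need at least two steps to separate warm-up from steady state")
    t0 = time.perf_counter()
    step()  # includes any one-off compile cost
    first = time.perf_counter() - t0
    t1 = time.perf_counter()
    for _ in range(nsteps - 1):
        step()
    steady = (time.perf_counter() - t1) / (nsteps - 1)
    return first, steady
```

Timing the first step separately from the rest is what lets a report quote compile cost and steady-state cost as two numbers, as the validation list does.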

Summary by CodeRabbit

  • New Features
    • Expanded TensorFlow 2 backend with new train, freeze, and checkpoint compression entry points, including full validation support.
    • Added multi-task model-branch selection for freeze/compress via --head / --model-branch.
    • Introduced optional TF2 JIT acceleration and TF2-specific automatic batch sizing.
  • Bug Fixes
    • Improved compression dtype handling for more consistent outputs.
    • Enabled additional TF2 runtime neighbor-statistics and corrected TF2 backend wiring so more workflows complete end-to-end.

Copilot AI review requested due to automatic review settings July 5, 2026 05:18
@dosubot dosubot Bot added the new feature label Jul 5, 2026

Copilot AI left a comment

Contributor

Copilot was unable to review this pull request because the user who requested the review has reached their quota limit.

@github-actions github-actions Bot added the Python label Jul 5, 2026
Comment thread deepmd/tf2/train/trainer.py Fixed
Comment thread source/tests/tf2/test_training.py Fixed
@coderabbitai

coderabbitai Bot commented Jul 5, 2026

Contributor

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


TF2 backend hooks, command entry points, eager training/validation, checkpoint serialization, shared-parameter utilities, and TF2 test coverage were added. Descriptor compression now normalizes dtypes with a NumPy conversion helper.

Changes

TensorFlow2 Backend and Training Pipeline

| Layer / File(s) | Summary |
|---|---|
| Backend registration and CLI wiring: deepmd/backend/tf2.py, deepmd/main.py, deepmd/tf2/entrypoints/__init__.py, deepmd/utils/argcheck.py | TF2 backend advertises ENTRY_POINT/NEIGHBOR_STAT, CLI help and args include TF2 output/input forms and --head/--model-branch, entrypoints are exported, and backend docs list TensorFlow2 support. |
| Freeze, compress, and dispatcher entry points: deepmd/tf2/entrypoints/compress.py, deepmd/tf2/entrypoints/freeze.py, deepmd/tf2/entrypoints/main.py | New TF2 freeze, enable_compression, and main dispatcher flows load serialized checkpoints, select branches, and write TF2 outputs. |
| Train entrypoint and neighbor-stat update: deepmd/tf2/entrypoints/train.py | TF2 training entrypoint preprocessing, stat-file setup, train(...), and update_sel(...) are added. |
| Model call and derivative computation refactor: deepmd/tf2/make_model.py, deepmd/tf2/model/base_model.py, deepmd/tf2/model/dp_model.py, deepmd/tf2/atomic_model/dp_atomic_model.py | Neighbor-list lowering, derivative toggles, formatted/JIT call routing, atomic derivative refactoring, and TF2 atom-mask behavior are added. |
| Trainer initialization, step, checkpointing: deepmd/tf2/train/__init__.py, deepmd/tf2/train/trainer.py | TF2 training driver, compiled steps, checkpoint persistence, finetune support, and training-state handling are added. |
| Full validation support: deepmd/tf2/train/validation.py, deepmd/tf2/utils/auto_batch_size.py | TF2 full validation batching and metric aggregation are added, with TF2-specific GPU/OOM detection. |
| Multi-task sharing, finetune, and JIT utilities: deepmd/tf2/utils/multi_task.py, deepmd/tf2/utils/finetune.py, deepmd/tf2/utils/jit.py | Shared-parameter preprocessing/application, TF2 finetune rule loading, and DP_JIT helpers are added. |
| Checkpoint-to-dict serialization: deepmd/tf2/utils/serialization.py | TF2 checkpoints are serialized back to model dicts with restored metadata and shared-link handling. |
| TF2 training tests: source/tests/tf2/test_training.py | Tests cover TF2 derivative handling, call routing, trainer behavior, JIT wrapping, checkpointing, serialization, and entrypoint integration. |

Compression Dtype Normalization

| Layer / File(s) | Summary |
|---|---|
| to_numpy_dtype helper and descriptor usages: deepmd/dpmodel/common.py, deepmd/dpmodel/descriptor/dpa1.py, deepmd/dpmodel/descriptor/se_e2_a.py, deepmd/dpmodel/descriptor/se_r.py | to_numpy_dtype is added and used when writing descriptor compression data. |

Estimated code review effort: 5 (Critical) | ~120 minutes

Sequence Diagram(s)

sequenceDiagram
  participant deepmd.main
  participant TF2TrainEntrypoint
  participant Trainer
  participant TF2FullValidator
  participant serialization

  deepmd.main->>TF2TrainEntrypoint: run_training(options)
  TF2TrainEntrypoint->>TF2TrainEntrypoint: update_sel(jdata)
  TF2TrainEntrypoint->>Trainer: construct DPTrainer(...)
  loop training steps
    Trainer->>Trainer: train_step / evaluate_training
    Trainer->>TF2FullValidator: run_full_validation
    Trainer->>serialization: save_checkpoint(training_state.json)
  end
sequenceDiagram
  participant deepmd.tf2.entrypoints.main
  participant freeze
  participant compress
  participant serialization

  deepmd.tf2.entrypoints.main->>freeze: freeze(checkpoint_folder, output, head)
  freeze->>serialization: serialize_from_file(checkpoint_folder)
  deepmd.tf2.entrypoints.main->>compress: enable_compression(input_file, output, head)
  compress->>serialization: serialize_from_file(input_file)

Possibly related PRs

Suggested labels: enhancement

Suggested reviewers: wanghan-iapcm

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title is concise and clearly reflects the main TF2 training workflow additions, even though it does not mention freeze/compress entrypoints. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/tf2/model/dp_model.py (1)

119-152: 🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Forward do_deriv_c on the non-formatted lower path (deepmd/tf2/model/dp_model.py:141-152). neighbor_list callers can hit this branch with do_deriv_c=False, but super().call_common_lower(...) doesn’t take the flag, so the virial-derivative suppression is ignored here.
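
The pattern this comment asks for can be sketched with plain Python classes. Everything here is hypothetical: the class names, the toy outputs, and the `formatted` switch are illustrative stand-ins, and the sketch assumes the base method is extended to accept `do_deriv_c`, which the comment says it does not do today.

```python
class BaseModel:
    """Toy stand-in for the parent class (hypothetical)."""

    def call_common_lower(self, coord, do_deriv_c=True):
        out = {"energy": float(sum(coord))}
        if do_deriv_c:
            # The virial-like output is produced only when requested.
            out["virial"] = [0.0] * 9
        return out


class DPModel(BaseModel):
    """Toy subclass with a formatted path and a non-formatted path."""

    def call_common_lower(self, coord, do_deriv_c=True, formatted=False):
        if formatted:
            return self._call_formatted(coord, do_deriv_c=do_deriv_c)
        # Forward the flag here as well; dropping it would silently fall
        # back to the parent's default (do_deriv_c=True).
        return super().call_common_lower(coord, do_deriv_c=do_deriv_c)

    def _call_formatted(self, coord, do_deriv_c=True):
        return super().call_common_lower(coord, do_deriv_c=do_deriv_c)


no_virial = DPModel().call_common_lower([1.0, 2.0], do_deriv_c=False)
with_virial = DPModel().call_common_lower([1.0, 2.0])
```

With the flag forwarded on both branches, a caller that passes `do_deriv_c=False` gets the same suppression regardless of which path is taken.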

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/tf2/model/dp_model.py` around lines 119 - 152, The non-formatted lower
path in call_common_lower is dropping the do_deriv_c argument, so requests that
disable virial derivatives still behave as if it were enabled. Update the
dp_model.py implementation in the call_common_lower branch to forward do_deriv_c
through the super().call_common_lower call, matching the formatted path and
preserving neighbor_list behavior for callers that pass do_deriv_c=False.
🧹 Nitpick comments (3)
deepmd/tf2/entrypoints/freeze.py (1)

8-10: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick win

Avoid deep-copying the entire multi-task payload just to swap two keys.

deepcopy(data) duplicates every other branch's weights in memory (and this cost is inherited by compress.py too, since it calls the same function). Since model/model_def_script are fully reassigned rather than mutated, a shallow copy is sufficient.

♻️ Proposed fix to avoid unnecessary deep copy
-from copy import (
-    deepcopy,
-)
 from typing import (
     Any,
 )
@@
-    selected = deepcopy(data)
+    selected = data.copy()
     selected["model"] = data["model"]["model_dict"][resolved_head]
     selected["model_def_script"] = model_def_script["model_dict"][resolved_head]

Also applies to: 49-79
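
Why a shallow copy is enough here can be checked with a small, self-contained sketch. The function name `select_head` and the toy payload are assumptions for illustration; only the three assignment lines mirror the proposed diff.

```python
def select_head(data: dict, model_def_script: dict, head: str) -> dict:
    """Return a payload whose model keys point at a single branch."""
    selected = data.copy()  # shallow: new top-level dict, nested values shared
    selected["model"] = data["model"]["model_dict"][head]
    selected["model_def_script"] = model_def_script["model_dict"][head]
    return selected


# Toy multi-task payload (hypothetical contents).
data = {
    "model": {"model_dict": {"a": {"w": [1.0, 2.0]}, "b": {"w": [3.0]}}},
    "backend": "tf2",
}
script = {"model_dict": {"a": {"type": "x"}, "b": {"type": "y"}}}

selected = select_head(data, script, "a")
```

Because the two keys are rebound on the copy rather than mutated in place, the original `data` keeps its full `model_dict`, and no branch weights are duplicated. The approach stops being safe if later code mutates the shared nested values in place.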

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/tf2/entrypoints/freeze.py` around lines 8 - 10, The multi-task payload
handling in the freeze path is doing an unnecessary deepcopy when only swapping
the model-related keys. Update the shared payload-copy logic in freeze.py (and
the helper it is reused by compress.py) to use a shallow copy instead of
deepcopy(data), since model and model_def_script are fully reassigned rather
than mutated. Keep the change localized to the function that prepares the data
for freezing so it preserves the existing behavior without duplicating the rest
of the task branches in memory.
deepmd/tf2/utils/multi_task.py (1)

396-411: 📐 Maintainability & Code Quality | 🔵 Trivial | 🏗️ Heavy lift

Attribute-sharing via automatic isinstance discovery with an exclude-list is fragile.

_share_tf2_state_attrs shares any top-level, non-underscore, non-excluded attribute on link_class that happens to be a tf.Module/tf.Variable/tf.Tensor/xp.Array. This is an "opt-out" design: any new state attribute added to a future fitting-net implementation that shouldn't be shared across multi-task branches (e.g. a per-branch scale or embedding) will be silently and incorrectly shared unless someone remembers to add it to excluded. An explicit allowlist of shareable attribute names (mirroring the current _merge_and_share_param_stats calls) would be safer against silent regressions as fitting-net types evolve.
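
The difference between the two designs can be illustrated with a minimal sketch. The class `FittingNet`, its attribute names, and the helper `share_state_attrs` are hypothetical and are not taken from the PR; the point is only that an allowlist shares nothing it was not told to share.

```python
class FittingNet:
    """Toy stand-in for a fitting net with shared and per-branch state."""

    def __init__(self, weights, branch_scale):
        self.weights = weights            # meant to be shared across branches
        self.branch_scale = branch_scale  # must stay private to each branch


# Explicit allowlist: adding a new attribute to FittingNet does not
# change what gets shared until someone lists it here.
SHAREABLE_ATTRS = ("weights",)


def share_state_attrs(base, link, allowlist=SHAREABLE_ATTRS):
    """Point the allowlisted attributes of `link` at the objects held by `base`."""
    for name in allowlist:
        if not hasattr(base, name):
            raise AttributeError(f"base object has no shareable attribute {name!r}")
        setattr(link, name, getattr(base, name))


base = FittingNet(weights=[0.1, 0.2], branch_scale=1.0)
link = FittingNet(weights=[9.9, 9.9], branch_scale=2.5)
share_state_attrs(base, link)
```

After the call, `link.weights` is the same object as `base.weights`, while `link.branch_scale` keeps its own value.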

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/tf2/utils/multi_task.py` around lines 396 - 411, The automatic sharing
in _share_tf2_state_attrs is too broad because it discovers shareable attributes
by isinstance and excludes only a small denylist. Replace this with an explicit
allowlist of attribute names that are safe to share, similar to the existing
_merge_and_share_param_stats usage, and have the sharing loop only copy those
named attributes from base_class to link_class. Keep _is_shareable_tf2_state as
a guard if needed, but do not rely on it for discovery.
deepmd/tf2/train/trainer.py (1)

995-1015: 🩺 Stability & Availability | 🔵 Trivial | ⚖️ Poor tradeoff

_write_checkpoint_directory mutates the shared self.step variable as a side effect.

self.step.assign(step) here reassigns the same tf.Variable that is part of the primary self.checkpoint/self.checkpoint_manager (used by save_checkpoint), purely to persist a value into the separate "best" validation-checkpoint directory. Since self.step is shared state, if training exits or another checkpoint-related read occurs between a run_full_validation() call and the next regular train_step/save_checkpoint() call, self.step will transiently hold the full-validation display_step value rather than the actual global training step. It self-heals on the next train_step() (Line 847), so the practical window is narrow, but it's fragile shared-mutable-state design that could bite if display_step and the true global step ever diverge (e.g. custom full-validation cadence).

Consider using a throwaway tf.Variable/separate Checkpoint(step=..., model=..., optimizer=...) object for the "best" checkpoint save path instead of mutating the primary training self.step.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/tf2/train/trainer.py` around lines 995 - 1015, The
_write_checkpoint_directory helper is mutating the shared training step via
self.step.assign(step) just to write the separate validation checkpoint. Update
_save_full_validation_checkpoint/_write_checkpoint_directory so the “best”
checkpoint path uses its own temporary step variable or separate Checkpoint
object instead of changing self.step, and keep the primary checkpoint state used
by save_checkpoint and train_step untouched.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/tf2/make_model.py`:
- Around line 179-184: The coord correction remapping in make_model should
reshape coord_corr_for_virial before calling xp.take_along_axis, since the
current flattened nframes x (nloc x 3) tensor is invalid for axis-based lookup.
Update the coord_corr handling block to convert the correction tensor to
(nframes, nloc, 3) first, then apply the existing mapping_idx remap logic so
extended_coord_corr is built from a 3D input.
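
A NumPy sketch of the reshape-then-gather step described in this finding follows. The function name `extend_coord_corr` and the toy shapes are assumptions; the PR uses an array-API `xp` namespace, for which NumPy stands in here.

```python
import numpy as np


def extend_coord_corr(coord_corr, mapping):
    """Gather per-atom corrections for every extended atom.

    coord_corr: (nframes, nloc * 3) or (nframes, nloc, 3)
    mapping:    (nframes, nall) integer indices into the local atoms
    returns:    (nframes, nall, 3)
    """
    nframes = mapping.shape[0]
    # Restore the per-atom axis first so axis=1 really indexes atoms.
    corr = np.reshape(coord_corr, (nframes, -1, 3))
    # Repeat each atom index across the 3 Cartesian components.
    idx = np.broadcast_to(mapping[:, :, None], mapping.shape + (3,))
    return np.take_along_axis(corr, idx, axis=1)


coord_corr = np.arange(12.0).reshape(2, 6)        # 2 frames, 2 local atoms
mapping = np.array([[0, 1, 1, 0], [1, 0, 0, 1]])  # 4 extended atoms per frame
extended = extend_coord_corr(coord_corr, mapping)
```

Gathering on the flattened `(nframes, nloc * 3)` layout would treat each Cartesian component as a separate "atom", which is why the reshape has to come first.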

In `@deepmd/tf2/train/trainer.py`:
- Around line 384-403: The checkpoint restore flow in trainer initialization is
dropping the `CheckpointLoadStatus` too early, which can hide deferred
mismatches for `restart_model`. Update the `Trainer` restore path used by
`_restore_checkpoint` so it retains the status object from
`self.checkpoint.restore(...).expect_partial()`, then defer the final
consumption check until after `_build_optimizer_slots()` has run. Ensure the
`restart_model` branch still sets `self.start_step` from `self.step` only after
the restore status has been preserved and validated.

In `@deepmd/tf2/utils/serialization.py`:
- Around line 606-616: The checkpoint resolution logic in
_resolve_checkpoint_path is mistakenly treating existing directories as prefix
files because it uses path.exists(), which can send metadata lookup to the wrong
parent directory. Update the fallback in _resolve_checkpoint_path so only real
checkpoint prefixes are accepted there, and ensure directory paths are handled
exclusively through the is_dir()/latest_checkpoint branch with the correct
state_dir. Keep the fix localized to _resolve_checkpoint_path and the returned
(path, state_dir) values.
- Around line 623-634: The restore flow in _restore_models_from_checkpoint
currently calls checkpoint.restore(...).expect_partial(), which suppresses
warnings but does not verify that the restored weights were actually matched to
the built models. Update this path to keep the restore status from
tf.train.Checkpoint.restore in _restore_models_from_checkpoint, ensure any lazy
variables in _build_models/_TaskModelContainer are materialized before
validation, and then call assert_existing_objects_matched() so mismatches
between the checkpoint and model object graph fail fast.
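
For the path-resolution finding above, the intended branching can be sketched with the standard library alone. This is an illustration under assumptions, not the PR's `_resolve_checkpoint_path`: the function name, the injected `latest` callable (a stand-in for a lookup such as `tf.train.latest_checkpoint`), and the use of the prefix's `.index` file as the existence test are all choices made for this sketch.

```python
import tempfile
from pathlib import Path
from typing import Callable, Optional, Tuple


def resolve_checkpoint_path(
    path, latest: Callable[[str], Optional[str]]
) -> Tuple[Path, Path]:
    """Return (checkpoint_prefix, state_dir) for a directory or a prefix."""
    path = Path(path)
    if path.is_dir():
        # Directories are resolved only through the latest-checkpoint lookup.
        prefix = latest(str(path))
        if prefix is None:
            raise FileNotFoundError(f"no checkpoint found in {path}")
        return Path(prefix), path
    # A prefix such as "ckpt-1" is not a file itself; its ".index" file is.
    if Path(f"{path}.index").is_file():
        return path, path.parent
    raise FileNotFoundError(f"{path} is neither a directory nor a checkpoint prefix")


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    (ckpt_dir / "ckpt-1.index").touch()
    fake_latest = lambda d: str(Path(d) / "ckpt-1")  # stand-in for the real lookup
    from_dir = resolve_checkpoint_path(ckpt_dir, fake_latest)
    from_prefix = resolve_checkpoint_path(ckpt_dir / "ckpt-1", fake_latest)
    expected = (ckpt_dir / "ckpt-1", ckpt_dir)
```

Both call styles resolve to the same prefix and the same state directory, and a directory can never be mistaken for a prefix because the fallback tests for a file.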

In `@source/tests/tf2/test_training.py`:
- Around line 49-51: Add the missing module-level timeout marker in the training
test module so it follows the same convention as the other training tests:
update the existing pytestmark definition in test_training to include
pytest.mark.timeout(60) alongside the current warning filter. Keep the marker
applied at module scope so all tests in this file are bounded consistently.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 6a6ad7b4-539e-402a-be1a-2f04c7ec0928

📥 Commits

Reviewing files that changed from the base of the PR and between ffe57a3 and 8defd7c.

📒 Files selected for processing (25)
  • deepmd/backend/tf2.py
  • deepmd/dpmodel/common.py
  • deepmd/dpmodel/descriptor/dpa1.py
  • deepmd/dpmodel/descriptor/se_e2_a.py
  • deepmd/dpmodel/descriptor/se_r.py
  • deepmd/main.py
  • deepmd/tf2/atomic_model/dp_atomic_model.py
  • deepmd/tf2/entrypoints/__init__.py
  • deepmd/tf2/entrypoints/compress.py
  • deepmd/tf2/entrypoints/freeze.py
  • deepmd/tf2/entrypoints/main.py
  • deepmd/tf2/entrypoints/train.py
  • deepmd/tf2/make_model.py
  • deepmd/tf2/model/base_model.py
  • deepmd/tf2/model/dp_model.py
  • deepmd/tf2/train/__init__.py
  • deepmd/tf2/train/trainer.py
  • deepmd/tf2/train/validation.py
  • deepmd/tf2/utils/auto_batch_size.py
  • deepmd/tf2/utils/finetune.py
  • deepmd/tf2/utils/jit.py
  • deepmd/tf2/utils/multi_task.py
  • deepmd/tf2/utils/serialization.py
  • deepmd/utils/argcheck.py
  • source/tests/tf2/test_training.py

Comment thread deepmd/tf2/make_model.py
Comment thread deepmd/tf2/train/trainer.py
Comment thread deepmd/tf2/utils/serialization.py
Comment thread deepmd/tf2/utils/serialization.py
Comment thread source/tests/tf2/test_training.py Outdated
Comment thread deepmd/jax/utils/type_embed.py Fixed
Comment thread source/tests/tf2/test_training.py Fixed
Comment thread source/tests/tf2/test_training.py Fixed

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/tf2/train/trainer.py`:
- Around line 1421-1428: The fallback in the prediction assembly helper is
incorrectly copying label virial into model_pred when virial was requested,
which can mask missing model outputs. Update the helper that returns model_pred
in trainer.py to accept do_virial, and only inject label_dict["virial"] when
virial was not requested intentionally; if do_virial is true and "virial" is
absent from model_pred, leave it unset so the missing prediction is exposed. Use
the existing virial checks in the return path to keep the behavior aligned with
output_defs and model_pred.
- Around line 834-835: Gate prepared-step selection by task capability instead
of only the global enable_compile flag. Update _use_prepared_step in trainer.py
to consult per-task compile/prepared state for the given task_key, using the
task tracking established by _configure_model_compile() or an equivalent
capability check, so tasks that were skipped or lack compile support do not get
routed into _prepare_lower_batch().
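
Both findings above reduce to small guard conditions, sketched here in plain Python. The function names, argument lists, and the `compiled_tasks` set are hypothetical; the trainer's real helpers may be shaped differently.

```python
def assemble_model_pred(model_pred: dict, label_dict: dict, do_virial: bool) -> dict:
    """Fill in a label virial only when virial was intentionally not requested."""
    pred = dict(model_pred)  # do not mutate the caller's dict
    if "virial" not in pred and not do_virial and "virial" in label_dict:
        pred["virial"] = label_dict["virial"]
    # When do_virial is True and the model produced no virial, the key stays
    # absent so the missing prediction is visible downstream.
    return pred


def use_prepared_step(task_key: str, enable_compile: bool, compiled_tasks: set) -> bool:
    """Route a task to the prepared step only if that task was actually compiled."""
    return enable_compile and task_key in compiled_tasks


labels = {"virial": [0.0] * 9}
skipped = assemble_model_pred({"energy": 1.0}, labels, do_virial=False)
missing = assemble_model_pred({"energy": 1.0}, labels, do_virial=True)
```

In the first case the label is copied in as a placeholder; in the second the absence of `"virial"` is preserved rather than masked by the label.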

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 3151d253-36c6-4125-b75a-e1ea5b378e9a

📥 Commits

Reviewing files that changed from the base of the PR and between f23d383 and 51679c2.

📒 Files selected for processing (4)
  • .github/workflows/test_python.yml
  • deepmd/tf2/train/trainer.py
  • dpa_adapt/cli.py
  • source/tests/tf2/test_training.py
💤 Files with no reviewable changes (1)
  • dpa_adapt/cli.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/tests/tf2/test_training.py

Comment thread deepmd/tf2/train/trainer.py Outdated
Comment thread deepmd/tf2/train/trainer.py
3 participants