feat(tf2): add training workflow #5735
📝 Walkthrough
TF2 backend hooks, command entry points, eager training/validation, checkpoint serialization, shared-parameter utilities, and TF2 test coverage were added. Descriptor compression now normalizes dtypes with a NumPy conversion helper.
Changes
TensorFlow2 Backend and Training Pipeline
Compression Dtype Normalization
Estimated code review effort: 5 (Critical) | ~120 minutes
Sequence Diagram(s)
sequenceDiagram
participant deepmd.main
participant TF2TrainEntrypoint
participant Trainer
participant TF2FullValidator
participant serialization
deepmd.main->>TF2TrainEntrypoint: run_training(options)
TF2TrainEntrypoint->>TF2TrainEntrypoint: update_sel(jdata)
TF2TrainEntrypoint->>Trainer: construct DPTrainer(...)
loop training steps
Trainer->>Trainer: train_step / evaluate_training
Trainer->>TF2FullValidator: run_full_validation
Trainer->>serialization: save_checkpoint(training_state.json)
end
sequenceDiagram
participant deepmd.tf2.entrypoints.main
participant freeze
participant compress
participant serialization
deepmd.tf2.entrypoints.main->>freeze: freeze(checkpoint_folder, output, head)
freeze->>serialization: serialize_from_file(checkpoint_folder)
deepmd.tf2.entrypoints.main->>compress: enable_compression(input_file, output, head)
compress->>serialization: serialize_from_file(input_file)
🚥 Pre-merge checks | ✅ Passed checks (5 passed)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/tf2/model/dp_model.py (1)
119-152: 🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Forward `do_deriv_c` on the non-formatted lower path (deepmd/tf2/model/dp_model.py:141-152).
`neighbor_list` callers can hit this branch with `do_deriv_c=False`, but the `super().call_common_lower(...)` call is not passed the flag, so the virial-derivative suppression is ignored here.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/tf2/model/dp_model.py` around lines 119 - 152, The non-formatted lower path in call_common_lower is dropping the do_deriv_c argument, so requests that disable virial derivatives still behave as if it were enabled. Update the dp_model.py implementation in the call_common_lower branch to forward do_deriv_c through the super().call_common_lower call, matching the formatted path and preserving neighbor_list behavior for callers that pass do_deriv_c=False.
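The dropped-flag pattern described above can be illustrated with a minimal sketch. The classes below are hypothetical stand-ins, not the deepmd implementation; only the method name `call_common_lower` and the `do_deriv_c` flag come from the review comment, and the parent's signature and default are assumptions.

```python
class BaseLowerModel:
    """Hypothetical stand-in for the parent class; not the deepmd code."""

    def call_common_lower(self, coords, do_atomic_virial=False, do_deriv_c=True):
        outputs = {"energy": float(sum(coords))}
        if do_deriv_c:
            # Placeholder for the cell-derivative (virial) output.
            outputs["virial"] = [0.0 for _ in coords]
        return outputs


class DroppingModel(BaseLowerModel):
    def call_common_lower(self, coords, do_atomic_virial=False, do_deriv_c=True):
        # Flag is accepted but not forwarded, so the parent uses its default.
        return super().call_common_lower(coords, do_atomic_virial=do_atomic_virial)


class ForwardingModel(BaseLowerModel):
    def call_common_lower(self, coords, do_atomic_virial=False, do_deriv_c=True):
        # Flag is forwarded, so do_deriv_c=False reaches the parent.
        return super().call_common_lower(
            coords, do_atomic_virial=do_atomic_virial, do_deriv_c=do_deriv_c
        )


dropped = DroppingModel().call_common_lower([1.0, 2.0], do_deriv_c=False)
forwarded = ForwardingModel().call_common_lower([1.0, 2.0], do_deriv_c=False)
print("virial" in dropped, "virial" in forwarded)  # prints: True False
```

In the sketch, the caller asks for no virial in both cases, but only the forwarding subclass honors the request.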
🧹 Nitpick comments (3)
deepmd/tf2/entrypoints/freeze.py (1)
8-10: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick win
Avoid deep-copying the entire multi-task payload just to swap two keys.
`deepcopy(data)` duplicates every other branch's weights in memory (and this cost is inherited by `compress.py` too, since it calls the same function). Since `model`/`model_def_script` are fully reassigned rather than mutated, a shallow copy is sufficient.
♻️ Proposed fix to avoid unnecessary deep copy
-from copy import (
-    deepcopy,
-)
 from typing import (
     Any,
 )
@@
-    selected = deepcopy(data)
+    selected = data.copy()
     selected["model"] = data["model"]["model_dict"][resolved_head]
     selected["model_def_script"] = model_def_script["model_dict"][resolved_head]
Also applies to: 49-79
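As a rough illustration of why a shallow copy is enough when top-level keys are reassigned, here is a small sketch. The payload layout and the function name `select_head_shallow` are hypothetical, loosely modeled on the keys quoted in the proposed fix.

```python
import copy


def select_head_shallow(data, head):
    """Pick one branch out of a multi-task payload without deep-copying it."""
    selected = data.copy()  # copies only the top-level dict
    # Reassigning a key on the copy does not touch the original dict.
    selected["model"] = data["model"]["model_dict"][head]
    return selected


# Hypothetical payload: two branches plus some large shared entry.
data = {
    "model": {"model_dict": {"a": {"w": [1, 2]}, "b": {"w": [3, 4]}}},
    "extra": {"big": list(range(5))},
}

selected = select_head_shallow(data, "a")
deep = copy.deepcopy(data)

print("model_dict" in data["model"])      # original is unchanged
print(selected["extra"] is data["extra"])  # shared, not duplicated
print(deep["extra"] is data["extra"])      # deepcopy duplicates it
```

The shallow copy stays safe only as long as nothing mutates the shared nested objects in place afterward; that is the condition the review comment relies on.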
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/tf2/entrypoints/freeze.py` around lines 8 - 10, The multi-task payload handling in the freeze path is doing an unnecessary deepcopy when only swapping the model-related keys. Update the shared payload-copy logic in freeze.py (and the helper that compress.py reuses) to use a shallow copy instead of deepcopy(data), since model and model_def_script are fully reassigned rather than mutated. Keep the change localized to the function that prepares the data for freezing so it preserves the existing behavior without duplicating the rest of the task branches in memory.
deepmd/tf2/utils/multi_task.py (1)
396-411: 📐 Maintainability & Code Quality | 🔵 Trivial | 🏗️ Heavy lift
Attribute-sharing via automatic `isinstance` discovery with an exclude-list is fragile.
`_share_tf2_state_attrs` shares any top-level, non-underscore, non-excluded attribute on `link_class` that happens to be a `tf.Module`/`tf.Variable`/`tf.Tensor`/`xp.Array`. This is an "opt-out" design: any new state attribute added to a future fitting-net implementation that shouldn't be shared across multi-task branches (e.g. a per-branch scale or embedding) will be silently and incorrectly shared unless someone remembers to add it to `excluded`. An explicit allowlist of shareable attribute names (mirroring the current `_merge_and_share_param_stats` calls) would be safer against silent regressions as fitting-net types evolve.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/tf2/utils/multi_task.py` around lines 396 - 411, The automatic sharing in _share_tf2_state_attrs is too broad because it discovers shareable attributes by isinstance and excludes only a small denylist. Replace this with an explicit allowlist of attribute names that are safe to share, similar to the existing _merge_and_share_param_stats usage, and have the sharing loop only copy those named attributes from base_class to link_class. Keep _is_shareable_tf2_state as a guard if needed, but do not rely on it for discovery.
deepmd/tf2/train/trainer.py (1)
995-1015: 🩺 Stability & Availability | 🔵 Trivial | ⚖️ Poor tradeoff
`_write_checkpoint_directory` mutates the shared `self.step` variable as a side effect.
`self.step.assign(step)` here reassigns the same `tf.Variable` that is part of the primary `self.checkpoint`/`self.checkpoint_manager` (used by `save_checkpoint`), purely to persist a value into the separate "best" validation-checkpoint directory. Since `self.step` is shared state, if training exits or another checkpoint-related read occurs between a `run_full_validation()` call and the next regular `train_step`/`save_checkpoint()` call, `self.step` will transiently hold the full-validation `display_step` value rather than the actual global training step. It self-heals on the next `train_step()` (Line 847), so the practical window is narrow, but it's fragile shared-mutable-state design that could bite if `display_step` and the true global step ever diverge (e.g. custom full-validation cadence).
Consider using a throwaway `tf.Variable`/separate `Checkpoint(step=..., model=..., optimizer=...)` object for the "best" checkpoint save path instead of mutating the primary training `self.step`.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/tf2/train/trainer.py` around lines 995 - 1015, The _write_checkpoint_directory helper is mutating the shared training step via self.step.assign(step) just to write the separate validation checkpoint. Update _save_full_validation_checkpoint/_write_checkpoint_directory so the “best” checkpoint path uses its own temporary step variable or separate Checkpoint object instead of changing self.step, and keep the primary checkpoint state used by save_checkpoint and train_step untouched.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/tf2/make_model.py`:
- Around line 179-184: The coord correction remapping in make_model should
reshape coord_corr_for_virial before calling xp.take_along_axis, since the
current flattened nframes x (nloc x 3) tensor is invalid for axis-based lookup.
Update the coord_corr handling block to convert the correction tensor to
(nframes, nloc, 3) first, then apply the existing mapping_idx remap logic so
extended_coord_corr is built from a 3D input.
In `@deepmd/tf2/train/trainer.py`:
- Around line 384-403: The checkpoint restore flow in trainer initialization is
dropping the `CheckpointLoadStatus` too early, which can hide deferred
mismatches for `restart_model`. Update the `Trainer` restore path used by
`_restore_checkpoint` so it retains the status object from
`self.checkpoint.restore(...).expect_partial()`, then defer the final
consumption check until after `_build_optimizer_slots()` has run. Ensure the
`restart_model` branch still sets `self.start_step` from `self.step` only after
the restore status has been preserved and validated.
In `@deepmd/tf2/utils/serialization.py`:
- Around line 606-616: The checkpoint resolution logic in
_resolve_checkpoint_path is mistakenly treating existing directories as prefix
files because it uses path.exists(), which can send metadata lookup to the wrong
parent directory. Update the fallback in _resolve_checkpoint_path so only real
checkpoint prefixes are accepted there, and ensure directory paths are handled
exclusively through the is_dir()/latest_checkpoint branch with the correct
state_dir. Keep the fix localized to _resolve_checkpoint_path and the returned
(path, state_dir) values.
- Around line 623-634: The restore flow in _restore_models_from_checkpoint
currently calls checkpoint.restore(...).expect_partial(), which suppresses
warnings but does not verify that the restored weights were actually matched to
the built models. Update this path to keep the restore status from
tf.train.Checkpoint.restore in _restore_models_from_checkpoint, ensure any lazy
variables in _build_models/_TaskModelContainer are materialized before
validation, and then call assert_existing_objects_matched() so mismatches
between the checkpoint and model object graph fail fast.
In `@source/tests/tf2/test_training.py`:
- Around line 49-51: Add the missing module-level timeout marker in the training
test module so it follows the same convention as the other training tests:
update the existing pytestmark definition in test_training to include
pytest.mark.timeout(60) alongside the current warning filter. Keep the marker
applied at module scope so all tests in this file are bounded consistently.
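The checkpoint-resolution item in the list above (directories handled only through the `is_dir()` branch, prefixes accepted only when they really are prefixes) might look roughly like the sketch below. It is not the code in `serialization.py`: the function name without the leading underscore, the injected `latest_checkpoint` callable, and the use of the `<prefix>.index` file as the prefix marker are all assumptions made for illustration.

```python
import tempfile
from pathlib import Path


def resolve_checkpoint_path(path, latest_checkpoint):
    """Return (checkpoint_prefix, state_dir) for a directory or a prefix."""
    path = Path(path)
    if path.is_dir():
        # Directories are resolved only here and are their own state_dir.
        prefix = latest_checkpoint(path)
        if prefix is None:
            raise FileNotFoundError(f"no checkpoint found in {path}")
        return Path(prefix), path
    # A prefix is not a file itself; assume "<prefix>.index" marks it.
    if Path(f"{path}.index").is_file():
        return path, path.parent
    raise FileNotFoundError(f"not a checkpoint directory or prefix: {path}")


def demo():
    """Resolve the same checkpoint from its directory and from its prefix."""
    with tempfile.TemporaryDirectory() as tmp:
        ckpt_dir = Path(tmp) / "ckpt"
        ckpt_dir.mkdir()
        (ckpt_dir / "model-1.index").touch()
        latest = lambda d: d / "model-1"  # stand-in for a latest-checkpoint lookup
        from_dir = resolve_checkpoint_path(ckpt_dir, latest)
        from_prefix = resolve_checkpoint_path(ckpt_dir / "model-1", latest)
        return from_dir == from_prefix == (ckpt_dir / "model-1", ckpt_dir)
```

Because the fallback checks for a concrete file rather than `path.exists()`, an existing directory can never be mistaken for a prefix, and both entry forms agree on the same `state_dir`.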
---
Outside diff comments:
In `@deepmd/tf2/model/dp_model.py`:
- Around line 119-152: The non-formatted lower path in call_common_lower is
dropping the do_deriv_c argument, so requests that disable virial derivatives
still behave as if it were enabled. Update the dp_model.py implementation in the
call_common_lower branch to forward do_deriv_c through the
super().call_common_lower call, matching the formatted path and preserving
neighbor_list behavior for callers that pass do_deriv_c=False.
---
Nitpick comments:
In `@deepmd/tf2/entrypoints/freeze.py`:
- Around line 8-10: The multi-task payload handling in the freeze path is doing
an unnecessary deepcopy when only swapping the model-related keys. Update the
shared payload-copy logic in freeze.py (and the helper that compress.py
reuses) to use a shallow copy instead of deepcopy(data), since model and
model_def_script are fully reassigned rather than mutated. Keep the change
localized to the function that prepares the data for freezing so it preserves
the existing behavior without duplicating the rest of the task branches in
memory.
In `@deepmd/tf2/train/trainer.py`:
- Around line 995-1015: The _write_checkpoint_directory helper is mutating the
shared training step via self.step.assign(step) just to write the separate
validation checkpoint. Update
_save_full_validation_checkpoint/_write_checkpoint_directory so the “best”
checkpoint path uses its own temporary step variable or separate Checkpoint
object instead of changing self.step, and keep the primary checkpoint state used
by save_checkpoint and train_step untouched.
In `@deepmd/tf2/utils/multi_task.py`:
- Around line 396-411: The automatic sharing in _share_tf2_state_attrs is too
broad because it discovers shareable attributes by isinstance and excludes only
a small denylist. Replace this with an explicit allowlist of attribute names
that are safe to share, similar to the existing _merge_and_share_param_stats
usage, and have the sharing loop only copy those named attributes from
base_class to link_class. Keep _is_shareable_tf2_state as a guard if needed, but
do not rely on it for discovery.
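The allowlist approach suggested for `_share_tf2_state_attrs` could be sketched as follows. The attribute names (`stat_avg`, `stat_std`, `branch_scale`) and the function name `share_state_attrs` are hypothetical placeholders, not names taken from the deepmd fitting nets.

```python
from types import SimpleNamespace

# Hypothetical allowlist: only these names are ever shared across branches.
SHAREABLE_ATTRS = ("stat_avg", "stat_std")


def share_state_attrs(base, link, allowlist=SHAREABLE_ATTRS):
    """Point the allowlisted attributes of `link` at the objects held by `base`."""
    shared = []
    for name in allowlist:
        # Skip names that either side does not define.
        if hasattr(base, name) and hasattr(link, name):
            setattr(link, name, getattr(base, name))
            shared.append(name)
    return shared


base = SimpleNamespace(stat_avg=[0.0], stat_std=[1.0], branch_scale=[2.0])
link = SimpleNamespace(stat_avg=[9.0], stat_std=[9.0], branch_scale=[3.0])
shared = share_state_attrs(base, link)
print(shared)             # ['stat_avg', 'stat_std']
print(link.branch_scale)  # [3.0], left per-branch because it is not allowlisted
```

With an opt-in list, a newly added per-branch attribute stays unshared by default, which is the failure mode the comment is trying to avoid.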
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 6a6ad7b4-539e-402a-be1a-2f04c7ec0928
📒 Files selected for processing (25)
- deepmd/backend/tf2.py
- deepmd/dpmodel/common.py
- deepmd/dpmodel/descriptor/dpa1.py
- deepmd/dpmodel/descriptor/se_e2_a.py
- deepmd/dpmodel/descriptor/se_r.py
- deepmd/main.py
- deepmd/tf2/atomic_model/dp_atomic_model.py
- deepmd/tf2/entrypoints/__init__.py
- deepmd/tf2/entrypoints/compress.py
- deepmd/tf2/entrypoints/freeze.py
- deepmd/tf2/entrypoints/main.py
- deepmd/tf2/entrypoints/train.py
- deepmd/tf2/make_model.py
- deepmd/tf2/model/base_model.py
- deepmd/tf2/model/dp_model.py
- deepmd/tf2/train/__init__.py
- deepmd/tf2/train/trainer.py
- deepmd/tf2/train/validation.py
- deepmd/tf2/utils/auto_batch_size.py
- deepmd/tf2/utils/finetune.py
- deepmd/tf2/utils/jit.py
- deepmd/tf2/utils/multi_task.py
- deepmd/tf2/utils/serialization.py
- deepmd/utils/argcheck.py
- source/tests/tf2/test_training.py
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/tf2/train/trainer.py`:
- Around line 1421-1428: The fallback in the prediction assembly helper is
incorrectly copying label virial into model_pred when virial was requested,
which can mask missing model outputs. Update the helper that returns model_pred
in trainer.py to accept do_virial, and only inject label_dict["virial"] when
virial was not requested intentionally; if do_virial is true and "virial" is
absent from model_pred, leave it unset so the missing prediction is exposed. Use
the existing virial checks in the return path to keep the behavior aligned with
output_defs and model_pred.
- Around line 834-835: Gate prepared-step selection by task capability instead
of only the global enable_compile flag. Update _use_prepared_step in trainer.py
to consult per-task compile/prepared state for the given task_key, using the
task tracking established by _configure_model_compile() or an equivalent
capability check, so tasks that were skipped or lack compile support do not get
routed into _prepare_lower_batch().
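The virial fallback rule from the first item above can be written out as a small sketch. The helper name `assemble_model_pred` and the plain-dict inputs are assumptions; only `model_pred`, `label_dict`, and `do_virial` are taken from the comment.

```python
def assemble_model_pred(model_pred, label_dict, do_virial):
    """Copy the label virial into the prediction only when virial was not requested."""
    pred = dict(model_pred)  # avoid mutating the caller's dict
    if "virial" not in pred and not do_virial and "virial" in label_dict:
        # Virial was intentionally skipped, so the label is used as a filler.
        pred["virial"] = label_dict["virial"]
    # If do_virial is True and the model produced no virial, the key stays
    # absent so the missing prediction is visible downstream.
    return pred


labels = {"energy": -1.0, "virial": [0.1] * 9}
skipped = assemble_model_pred({"energy": -1.1}, labels, do_virial=False)
requested = assemble_model_pred({"energy": -1.1}, labels, do_virial=True)
print("virial" in skipped, "virial" in requested)  # prints: True False
```

Leaving the key unset in the requested case means a later lookup fails loudly instead of silently comparing the label against itself.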
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 3151d253-36c6-4125-b75a-e1ea5b378e9a
📒 Files selected for processing (4)
- .github/workflows/test_python.yml
- deepmd/tf2/train/trainer.py
- dpa_adapt/cli.py
- source/tests/tf2/test_training.py
💤 Files with no reviewable changes (1)
- dpa_adapt/cli.py
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/tf2/test_training.py
Summary
Validation
- `python -m pytest source/tests/tf2/test_training.py -q`
- `ruff check .`
- `ruff format --check deepmd/tf2 source/tests/tf2`
- `DP_JIT=1 srun --gres=gpu:1 dp --tf2 train input.json --skip-neighbor-stat` on `examples/water/se_e2_a` temp copy: XLA lower compiled; first step ~140s, steady windows ~0.0305 s/step through step 900 before timeout
- `DP_JIT=1 srun --gres=gpu:1 dp --tf2 train input.json --skip-neighbor-stat` on fixed-selection dpa3 temp input: XLA lower compiled; first step ~227s, steady windows ~0.119 s/step through step 500 before timeout
Notes
- `DP_JIT` is opt-in for training lower forward only; the whole train/eval step is intentionally not XLA compiled because neighbor/outer training logic is too broad and previously hit unsupported XLA ops.
Summary by CodeRabbit
--head/--model-branch.