From bb30614fc750646fcb00c88acc0ca6eb77aa1af1 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 3 Jul 2026 14:22:07 +0800 Subject: [PATCH] fix(tf): dispatch --init-model to checkpoint pre-inspection RunOptions records `dp train --init-model` as init_mode == "init_from_model", but DPTrainer.build() dispatched on the literal "init_model", so the branch never matched and _init_from_ckpt was skipped for --init-model. That pre-inspection is what imports the source checkpoint's meta graph and sets self.ckpt_meta when the checkpoint is a compressed_model, before the graph is built with ckpt_meta. With the mismatch, compressed-checkpoint --init-model builds the graph without its checkpoint metadata. Uncompressed --init-model still worked because variables are restored later in _init_session (which uses the correct "init_from_model" literal) and needs no ckpt_meta, which masked the bug. Fix the literal to "init_from_model". The 4-way init dispatch is extracted from the heavyweight build() into a small _init_from_run_opt() helper so it can be unit-tested in isolation; this is why the mismatch went uncaught. Adds a regression test that drives the dispatch with a stub trainer and mocked initializers: it fails on the old literal (init_from_model routes nowhere) and passes with the fix, and also covers restart, init_from_frz_model, finetune, and scratch. Fix #5679 --- deepmd/tf/train/trainer.py | 32 ++++++---- source/tests/tf/test_trainer_init_mode.py | 73 +++++++++++++++++++++++ 2 files changed, 95 insertions(+), 10 deletions(-) create mode 100644 source/tests/tf/test_trainer_init_mode.py diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index e1d7deb04b..8f4f840a12 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -255,16 +255,7 @@ def build( self.model.data_stat(data, stat_file_path=stat_file_path) # config the init_frz_model command - if self.run_opt.init_mode == "init_from_frz_model": - self._init_from_frz_model() - elif self.run_opt.init_mode == "init_model": - self._init_from_ckpt(self.run_opt.init_model) - elif self.run_opt.init_mode == "restart": - self._init_from_ckpt(self.run_opt.restart) - elif self.run_opt.init_mode == "finetune": - self._init_from_pretrained_model( - data=data, origin_type_map=origin_type_map - ) + self._init_from_run_opt(data=data, origin_type_map=origin_type_map) # neighbor_stat is moved to train.py as duplicated else: @@ -880,6 +871,27 @@ def _get_place_holders(self, data_dict: dict) -> None: tf.float32, name="t_find_" + kk ) + def _init_from_run_opt( + self, + data: DeepmdDataSystem, + origin_type_map: list[str] | None = None, + ) -> None: + """Dispatch checkpoint pre-inspection based on the run-option init mode. + + The mode strings must match the values produced by + :class:`deepmd.tf.train.run_options.RunOptions` exactly; a mismatch + silently skips the pre-inspection (see ``_init_from_ckpt``) that detects + compressed-checkpoint metadata before graph construction. + """ + if self.run_opt.init_mode == "init_from_frz_model": + self._init_from_frz_model() + elif self.run_opt.init_mode == "init_from_model": + self._init_from_ckpt(self.run_opt.init_model) + elif self.run_opt.init_mode == "restart": + self._init_from_ckpt(self.run_opt.restart) + elif self.run_opt.init_mode == "finetune": + self._init_from_pretrained_model(data=data, origin_type_map=origin_type_map) + def _init_from_frz_model(self) -> None: try: graph, graph_def = load_graph_def(self.run_opt.init_frz_model) diff --git a/source/tests/tf/test_trainer_init_mode.py b/source/tests/tf/test_trainer_init_mode.py new file mode 100644 index 0000000000..fa86363de3 --- /dev/null +++ b/source/tests/tf/test_trainer_init_mode.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Regression test for the TF trainer init-mode dispatch. + +``RunOptions`` records ``--init-model`` as ``init_mode == "init_from_model"``, +but ``DPTrainer`` used to dispatch on the literal ``"init_model"``. The +mismatch silently skipped ``_init_from_ckpt`` for ``--init-model``, so the +compressed-checkpoint pre-inspection that sets ``ckpt_meta`` before graph +construction never ran. This checks that each init mode routes to the right +initializer, using the mode strings ``RunOptions`` actually produces. +""" + +import types +import unittest +from unittest import ( + mock, +) + +from deepmd.tf.train.trainer import ( + DPTrainer, +) + + +def _dispatch(init_mode: str) -> str | None: + """Run the trainer's init-mode dispatch for ``init_mode`` and report the route. + + A bare ``DPTrainer`` instance is used (constructor bypassed) with the three + concrete initializers patched to record which one fires; the return value is + the name of the initializer that was called (or ``None`` for scratch). + """ + trainer = DPTrainer.__new__(DPTrainer) + trainer.run_opt = types.SimpleNamespace( + init_mode=init_mode, + init_model="some/init.ckpt", + restart="some/restart.ckpt", + init_frz_model="some/frozen.pb", + finetune="some/pretrained.pb", + ) + with ( + mock.patch.object(DPTrainer, "_init_from_frz_model") as frz, + mock.patch.object(DPTrainer, "_init_from_ckpt") as ckpt, + mock.patch.object(DPTrainer, "_init_from_pretrained_model") as pre, + ): + trainer._init_from_run_opt(data=None, origin_type_map=None) + if frz.called: + return "frz" + if ckpt.called: + return f"ckpt:{ckpt.call_args.args[0]}" + if pre.called: + return "pretrained" + return None + + +class TestTrainerInitMode(unittest.TestCase): + def test_init_from_model_uses_ckpt(self) -> None: + # RunOptions sets this string for `dp train --init-model`; it must reach + # _init_from_ckpt so compressed-checkpoint metadata is pre-inspected. + self.assertEqual(_dispatch("init_from_model"), "ckpt:some/init.ckpt") + + def test_restart_uses_ckpt(self) -> None: + self.assertEqual(_dispatch("restart"), "ckpt:some/restart.ckpt") + + def test_init_from_frz_model_uses_frz(self) -> None: + self.assertEqual(_dispatch("init_from_frz_model"), "frz") + + def test_finetune_uses_pretrained(self) -> None: + self.assertEqual(_dispatch("finetune"), "pretrained") + + def test_scratch_is_noop(self) -> None: + self.assertIsNone(_dispatch("init_from_scratch")) + + +if __name__ == "__main__": + unittest.main()