From 304d27ceb4a691d3e7a85f01d035d715c98b63e0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 00:35:29 +0800 Subject: [PATCH 01/14] feat(tf2): add training workflow --- deepmd/backend/tf2.py | 19 +- deepmd/dpmodel/common.py | 12 + deepmd/dpmodel/descriptor/dpa1.py | 3 +- deepmd/dpmodel/descriptor/se_e2_a.py | 3 +- deepmd/dpmodel/descriptor/se_r.py | 3 +- deepmd/main.py | 13 +- deepmd/tf2/atomic_model/dp_atomic_model.py | 8 + deepmd/tf2/entrypoints/__init__.py | 14 + deepmd/tf2/entrypoints/compress.py | 75 ++ deepmd/tf2/entrypoints/freeze.py | 79 ++ deepmd/tf2/entrypoints/main.py | 58 + deepmd/tf2/entrypoints/train.py | 279 +++++ deepmd/tf2/train/__init__.py | 16 + deepmd/tf2/train/trainer.py | 1193 ++++++++++++++++++++ deepmd/tf2/train/validation.py | 212 ++++ deepmd/tf2/utils/auto_batch_size.py | 19 + deepmd/tf2/utils/finetune.py | 38 + deepmd/tf2/utils/multi_task.py | 410 +++++++ deepmd/tf2/utils/serialization.py | 161 ++- deepmd/utils/argcheck.py | 2 +- source/tests/tf2/test_training.py | 242 ++++ 21 files changed, 2845 insertions(+), 14 deletions(-) create mode 100644 deepmd/tf2/entrypoints/__init__.py create mode 100644 deepmd/tf2/entrypoints/compress.py create mode 100644 deepmd/tf2/entrypoints/freeze.py create mode 100644 deepmd/tf2/entrypoints/main.py create mode 100644 deepmd/tf2/entrypoints/train.py create mode 100644 deepmd/tf2/train/__init__.py create mode 100644 deepmd/tf2/train/trainer.py create mode 100644 deepmd/tf2/train/validation.py create mode 100644 deepmd/tf2/utils/auto_batch_size.py create mode 100644 deepmd/tf2/utils/finetune.py create mode 100644 deepmd/tf2/utils/multi_task.py create mode 100644 source/tests/tf2/test_training.py diff --git a/deepmd/backend/tf2.py b/deepmd/backend/tf2.py index e0c750cb70..e661c7e99c 100644 --- a/deepmd/backend/tf2.py +++ b/deepmd/backend/tf2.py @@ -33,7 +33,12 @@ class TensorFlow2Backend(Backend): """TensorFlow 2 eager backend.""" name = "TensorFlow2" - features: ClassVar[Backend.Feature] = Backend.Feature.DEEP_EVAL | Backend.Feature.IO + features: ClassVar[Backend.Feature] = ( + Backend.Feature.ENTRY_POINT + | Backend.Feature.DEEP_EVAL + | Backend.Feature.NEIGHBOR_STAT + | Backend.Feature.IO + ) suffixes: ClassVar[list[str]] = [".savedmodeltf"] @classmethod @@ -45,7 +50,11 @@ def is_available(self) -> bool: @property def entry_point_hook(self) -> Callable[["Namespace"], None]: - raise NotImplementedError("Training entry point is not implemented for TF2") + from deepmd.tf2.entrypoints.main import ( + main, + ) + + return main @property def deep_eval(self) -> type["DeepEvalBackend"]: @@ -57,7 +66,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]: @property def neighbor_stat(self) -> type["NeighborStat"]: - raise NotImplementedError("Neighbor statistics are not implemented for TF2") + from deepmd.dpmodel.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat @property def serialize_hook(self) -> Callable[[str], dict]: diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 0d363614a7..d28b69d810 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -90,6 +90,18 @@ def get_xp_precision( raise ValueError(f"unsupported precision {precision} for {xp}") +def to_numpy_dtype(dtype: Any) -> np.dtype: + """Normalize backend dtype objects to a NumPy dtype.""" + dtype = getattr(dtype, "as_numpy_dtype", dtype) + try: + return np.dtype(dtype) + except TypeError: + dtype_name = getattr(dtype, "name", None) + if dtype_name is not None: + return np.dtype(dtype_name) + raise + + class NativeOP(ABC): """The unit operation of a native model.""" diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 13f8f3e351..c8145f5833 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -27,6 +27,7 @@ from deepmd.dpmodel.common import ( cast_precision, to_numpy_array, + to_numpy_dtype, ) from deepmd.dpmodel.utils import ( EmbeddingNet, @@ -1409,7 +1410,7 @@ def enable_compression( ) -> None: """Store tabulated geometric embedding-net data.""" net = "filter_net" - dtype = self.mean.dtype + dtype = to_numpy_dtype(self.mean.dtype) self.compress_info = [ np.asarray( [ diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index c7e4bc8257..1030aa23ff 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -22,6 +22,7 @@ from deepmd.dpmodel.common import ( cast_precision, to_numpy_array, + to_numpy_dtype, ) from deepmd.dpmodel.utils import ( EmbeddingNet, @@ -431,7 +432,7 @@ def _store_compress_data( """Store tabulated embedding-net data in the descriptor state.""" compress_data = [] compress_info = [] - dtype = self.davg.dtype + dtype = to_numpy_dtype(self.davg.dtype) ndim = 1 if self.type_one_side else 2 for embedding_idx in range(self.ntypes**ndim): if self.type_one_side: diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index e992409452..4c4d86b258 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -22,6 +22,7 @@ cast_precision, get_xp_precision, to_numpy_array, + to_numpy_dtype, ) from deepmd.dpmodel.utils import ( EmbeddingNet, @@ -410,7 +411,7 @@ def _store_compress_data( """Store tabulated embedding-net data in the descriptor state.""" compress_data = [] compress_info = [] - dtype = self.davg.dtype + dtype = to_numpy_dtype(self.davg.dtype) for embedding_idx in range(self.ntypes): net = "filter_-1_net_" + str(embedding_idx) if net not in table_data: diff --git a/deepmd/main.py b/deepmd/main.py index 1ec9d425c3..42a03bc3f0 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -327,7 +327,7 @@ def main_parser() -> argparse.ArgumentParser: "--output", type=str, default="frozen_model", - help="Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth; Paddle backend: suffix is .json and .pdiparams", + help="Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; TensorFlow2 backend: suffix is .savedmodeltf; PyTorch backend: suffix is .pth; Paddle backend: suffix is .json and .pdiparams", ) parser_frz.add_argument( "-n", @@ -605,14 +605,14 @@ def main_parser() -> argparse.ArgumentParser: "--input", default="frozen_model", type=str, - help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth; DPModel backend: suffix is .dp; JAX backend: suffix is .hlo or .jax", + help="The original frozen model or checkpoint, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; TensorFlow2 backend: .tf2 checkpoint directory or checkpoint prefix; PyTorch backend: suffix is .pth; DPModel backend: suffix is .dp; JAX backend: suffix is .hlo or .jax", ) parser_compress.add_argument( "-o", "--output", default="frozen_model_compressed", type=str, - help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth; DPModel backend: suffix is .dp; JAX backend: suffix is .hlo or .jax", + help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; TensorFlow2 backend: suffix is .savedmodeltf; PyTorch backend: suffix is .pth; DPModel backend: suffix is .dp; JAX backend: suffix is .hlo or .jax", ) parser_compress.add_argument( "-s", @@ -659,6 +659,13 @@ def main_parser() -> argparse.ArgumentParser: default=None, help="The training script of the input frozen model", ) + parser_compress.add_argument( + "--head", + "--model-branch", + default=None, + type=str, + help="Task head (alias: model branch) to compress if in multi-task mode.", + ) # * print docs script ************************************************************** parsers_doc = subparsers.add_parser( diff --git a/deepmd/tf2/atomic_model/dp_atomic_model.py b/deepmd/tf2/atomic_model/dp_atomic_model.py index 604ea5b6de..2d68ba81ef 100644 --- a/deepmd/tf2/atomic_model/dp_atomic_model.py +++ b/deepmd/tf2/atomic_model/dp_atomic_model.py @@ -11,6 +11,7 @@ ) from deepmd.tf2.env import ( stop_gradient, + tf, xp, ) from deepmd.tf2.fitting.base_fitting import ( @@ -41,6 +42,13 @@ class tf2_atomic_model(dpmodel_atomic_model): base_fitting_cls = BaseFitting """The base fitting class.""" + @tf.autograph.experimental.do_not_convert + def make_atom_mask( + self, + atype: xp.ndarray, + ) -> xp.ndarray: + return atype >= 0 + def forward_common_atomic( self, extended_coord: xp.ndarray, diff --git a/deepmd/tf2/entrypoints/__init__.py b/deepmd/tf2/entrypoints/__init__.py new file mode 100644 index 0000000000..f665d1cce8 --- /dev/null +++ b/deepmd/tf2/entrypoints/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Entry points for the TensorFlow 2 backend.""" + +from deepmd.tf2.entrypoints.compress import ( + enable_compression, +) +from deepmd.tf2.entrypoints.freeze import ( + freeze, +) +from deepmd.tf2.entrypoints.train import ( + train, +) + +__all__ = ["enable_compression", "freeze", "train"] diff --git a/deepmd/tf2/entrypoints/compress.py b/deepmd/tf2/entrypoints/compress.py new file mode 100644 index 0000000000..39ce71ee0b --- /dev/null +++ b/deepmd/tf2/entrypoints/compress.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Compress TensorFlow 2 checkpoints by tabulating embedding networks.""" + +from __future__ import ( + annotations, +) + +import logging +from typing import ( + Any, +) + +from deepmd.backend.suffix import ( + format_model_suffix, +) +from deepmd.dpmodel.entrypoints.compress_common import ( + enable_model_compression, + resolve_min_nbor_dist, +) +from deepmd.dpmodel.utils.update_sel import ( + UpdateSel, +) +from deepmd.tf2.entrypoints.freeze import ( + select_model_branch, +) +from deepmd.tf2.model.base_model import ( + BaseModel, +) +from deepmd.tf2.utils.serialization import ( + deserialize_to_file, + serialize_from_file, +) + +log = logging.getLogger(__name__) + + +def enable_compression( + input_file: str, + output: str, + stride: float = 0.01, + extrapolate: int = 5, + check_frequency: int = -1, + training_script: str | None = None, + head: str | None = None, + **kwargs: Any, +) -> None: + """Compress a TF2 training checkpoint and export a SavedModel.""" + del kwargs + output = format_model_suffix( + output, + preferred_backend="tf2", + strict_prefer=True, + ) + data = serialize_from_file(input_file) + data = select_model_branch(data, head=head) + model = BaseModel.deserialize(data["model"]) + min_nbor_dist = resolve_min_nbor_dist( + model, + [data], + training_script, + UpdateSel, + ) + enable_model_compression( + model, + min_nbor_dist, + stride, + extrapolate, + check_frequency, + ) + + compressed_data = data.copy() + compressed_data["model"] = model.serialize() + compressed_data["min_nbor_dist"] = float(min_nbor_dist) + deserialize_to_file(output, compressed_data) + log.info("Compressed TF2 model saved to %s", output) diff --git a/deepmd/tf2/entrypoints/freeze.py b/deepmd/tf2/entrypoints/freeze.py new file mode 100644 index 0000000000..ebc216408f --- /dev/null +++ b/deepmd/tf2/entrypoints/freeze.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Freeze utilities for the TensorFlow 2 backend.""" + +from __future__ import ( + annotations, +) + +from copy import ( + deepcopy, +) +from typing import ( + Any, +) + +from deepmd.backend.suffix import ( + format_model_suffix, +) +from deepmd.dpmodel.train import ( + DEFAULT_TASK_KEY, +) +from deepmd.tf2.utils.serialization import ( + deserialize_to_file, + serialize_from_file, +) +from deepmd.utils.model_branch_dict import ( + get_model_dict, +) + + +def freeze( + *, + checkpoint_folder: str, + output: str, + head: str | None = None, + **kwargs: Any, +) -> None: + """Freeze a TF2 training checkpoint into a TensorFlow SavedModel.""" + del kwargs + output = format_model_suffix( + output, + preferred_backend="tf2", + strict_prefer=True, + ) + data = serialize_from_file(checkpoint_folder) + data = select_model_branch(data, head=head) + deserialize_to_file(output, data) + + +def select_model_branch( + data: dict[str, Any], head: str | None = None +) -> dict[str, Any]: + """Select one branch from a single-task or multi-task serialized payload.""" + model_def_script = data["model_def_script"] + if "model_dict" not in model_def_script: + if head not in (None, "", DEFAULT_TASK_KEY): + raise ValueError( + f"Single-task TF2 checkpoints do not have a head named {head!r}." + ) + return data + + if head in (None, ""): + raise ValueError( + "Multi-task TF2 checkpoints require --head/--model-branch to select " + "which model branch to freeze." + ) + model_alias_dict, _ = get_model_dict(model_def_script["model_dict"]) + if head not in model_alias_dict: + raise ValueError( + f"No model branch or alias named {head!r}. Available branches are " + f"{list(model_def_script['model_dict'])}." + ) + resolved_head = model_alias_dict[head] + selected = deepcopy(data) + selected["model"] = data["model"]["model_dict"][resolved_head] + selected["model_def_script"] = model_def_script["model_dict"][resolved_head] + min_nbor_dist = data.get("min_nbor_dist") + if isinstance(min_nbor_dist, dict): + selected["min_nbor_dist"] = min_nbor_dist.get(resolved_head) + return selected diff --git a/deepmd/tf2/entrypoints/main.py b/deepmd/tf2/entrypoints/main.py new file mode 100644 index 0000000000..4729c1a8b5 --- /dev/null +++ b/deepmd/tf2/entrypoints/main.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""DeePMD-kit entry point for the TensorFlow 2 backend.""" + +import argparse +from pathlib import ( + Path, +) + +from deepmd.loggers.loggers import ( + set_log_handles, +) +from deepmd.main import ( + parse_args, +) +from deepmd.tf2.entrypoints.compress import ( + enable_compression, +) +from deepmd.tf2.entrypoints.freeze import ( + freeze, +) +from deepmd.tf2.entrypoints.train import ( + train, +) + +__all__ = ["main"] + + +def main(args: list[str] | argparse.Namespace | None = None) -> None: + """TensorFlow 2 backend command dispatcher.""" + if not isinstance(args, argparse.Namespace): + args = parse_args(args=args) + + set_log_handles( + args.log_level, + Path(args.log_path) if args.log_path else None, + mpi_log=None, + ) + + if args.command == "train": + train(**vars(args)) + elif args.command == "freeze": + freeze(**vars(args)) + elif args.command == "compress": + enable_compression( + input_file=args.input, + output=args.output, + stride=args.step, + extrapolate=args.extrapolate, + check_frequency=args.frequency, + training_script=args.training_script, + head=args.head, + ) + elif args.command is None: + pass + else: + raise RuntimeError( + f"Unsupported command '{args.command}' for the TensorFlow 2 backend." + ) diff --git a/deepmd/tf2/entrypoints/train.py b/deepmd/tf2/entrypoints/train.py new file mode 100644 index 0000000000..873c7df89a --- /dev/null +++ b/deepmd/tf2/entrypoints/train.py @@ -0,0 +1,279 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Training entrypoint for the TensorFlow 2 eager backend.""" + +from __future__ import ( + annotations, +) + +import logging +import time +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import h5py + +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) +from deepmd.dpmodel.train import ( + AbstractTrainEntrypoint, + TrainEntrypointOptions, + TrainingTaskConfig, + iter_training_task_configs, + make_task_maps, + print_data_summaries, +) +from deepmd.tf2.env import ( + tf, +) +from deepmd.tf2.train.trainer import ( + DPTrainer, +) +from deepmd.tf2.utils.serialization import ( + serialize_from_file, +) +from deepmd.utils import random as dp_random +from deepmd.utils.data_system import ( + get_data, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter + +__all__ = ["train", "update_sel"] + +log = logging.getLogger(__name__) + + +class SummaryPrinter(BaseSummaryPrinter): + """Summary printer for TensorFlow 2.""" + + def is_built_with_cuda(self) -> bool: + return bool(tf.test.is_built_with_cuda()) + + def is_built_with_rocm(self) -> bool: + return False + + def get_compute_device(self) -> str: + return "gpu" if tf.config.list_physical_devices("GPU") else "cpu" + + def get_ngpus(self) -> int: + return len(tf.config.list_physical_devices("GPU")) + + def get_backend_info(self) -> dict: + return { + "Backend": "TensorFlow2", + "TensorFlow ver": tf.__version__, + "Eager mode": str(tf.executing_eagerly()), + } + + def get_device_name(self) -> str | None: + devices = tf.config.list_physical_devices("GPU") + return devices[0].name if devices else None + + +class TF2TrainEntrypoint(AbstractTrainEntrypoint): + """TensorFlow 2 implementation of the common training entrypoint pipeline.""" + + def __init__(self) -> None: + self.finetune_links: dict[str, Any] | None = None + self.shared_links: dict[str, Any] | None = None + + def validate_options( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + ) -> None: + if options.init_frz_model: + raise NotImplementedError( + "TF2 training does not support init_frz_model yet." + ) + + def preprocess_config( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + ) -> dict[str, Any]: + """Apply TF2 fine-tuning and pretrained-script preprocessing.""" + self.finetune_links = None + self.shared_links = None + if self.is_multi_task(config): + if "RANDOM" in config["model"]["model_dict"]: + raise ValueError("Model name can not be 'RANDOM' in multi-task mode!") + if config["model"].get("shared_dict"): + from deepmd.tf2.utils.multi_task import ( + preprocess_shared_params, + ) + + config["model"], self.shared_links = preprocess_shared_params( + config["model"] + ) + if options.finetune is not None: + from deepmd.tf2.utils.finetune import ( + get_finetune_rules, + ) + + config["model"], self.finetune_links = get_finetune_rules( + options.finetune, + config["model"], + model_branch=options.model_branch, + change_model_params=options.use_pretrain_script, + ) + elif options.init_model is not None and options.use_pretrain_script: + model_data = serialize_from_file(options.init_model) + config["model"] = model_data["model_def_script"] + self.shared_links = model_data.get("shared_links") + return config + + def update_neighbor_stat( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + *, + multi_task: bool, + ) -> tuple[dict[str, Any], float | dict[str, float | None] | None]: + log.info( + "Calculate neighbor statistics... " + "(add --skip-neighbor-stat to skip this step)" + ) + return update_sel(config, multi_task=multi_task) + + def print_summary(self) -> None: + SummaryPrinter()() + + def run_training( + self, + config: dict[str, Any], + options: TrainEntrypointOptions, + neighbor_stat: Any, + ) -> None: + seed = config["training"].get("seed") + dp_random.seed(None if seed is None else int(seed) % (2**32)) + + def factory( + task_config: TrainingTaskConfig, + ) -> tuple[Any, Any | None, DPPath | None]: + type_map = list(task_config.model_params.get("type_map", [])) + ipt_type_map = type_map if type_map else None + train_data = get_data( + dict(task_config.training_data_params), + None, + ipt_type_map, + None, + ) + valid_data = None + if task_config.validation_data_params is not None: + valid_data = get_data( + dict(task_config.validation_data_params), + None, + train_data.type_map, + None, + ) + return train_data, valid_data, _make_stat_file_path(task_config.stat_file) + + train_data_map, valid_data_map, stat_file_map = make_task_maps( + config, + factory, + ) + print_data_summaries(train_data_map, valid_data_map) + + trainer = DPTrainer( + config, + train_data_map, + stat_file_path=stat_file_map, + validation_data=valid_data_map, + init_model=options.init_model, + restart_model=options.restart, + finetune_model=options.finetune, + finetune_links=self.finetune_links, + shared_links=self.shared_links, + min_nbor_dist=neighbor_stat, + ) + start_time = time.time() + trainer.run() + end_time = time.time() + log.info("finished training") + log.info("wall time: %.3f s", end_time - start_time) + + +def train( + *, + INPUT: str, + init_model: str | None, + restart: str | None, + output: str, + init_frz_model: str | None, + mpi_log: str, + log_level: int, + log_path: str | None, + skip_neighbor_stat: bool = False, + finetune: str | None = None, + model_branch: str = "", + use_pretrain_script: bool = False, + **kwargs: Any, +) -> None: + """Run DeePMD model training with TensorFlow 2 eager execution.""" + TF2TrainEntrypoint().run( + TrainEntrypointOptions( + input_file=INPUT, + output=output, + init_model=init_model, + restart=restart, + init_frz_model=init_frz_model, + finetune=finetune, + model_branch=model_branch, + use_pretrain_script=use_pretrain_script, + skip_neighbor_stat=skip_neighbor_stat, + ) + ) + + +def update_sel( + jdata: dict[str, Any], + *, + multi_task: bool | None = None, +) -> tuple[dict[str, Any], float | dict[str, float | None] | None]: + """Update descriptor selections from neighbor statistics when available.""" + jdata_cpy = jdata.copy() + if multi_task is None: + multi_task = "model_dict" in jdata["model"] + min_nbor_dist: dict[str, float | None] = {} + for task_config in iter_training_task_configs(jdata): + type_map = task_config.model_params.get("type_map") + train_data = get_data( + dict(task_config.training_data_params), + 0, + type_map, + None, + ) + updated_model, task_min_nbor_dist = BaseModel.update_sel( + train_data, + type_map, + dict(task_config.model_params), + ) + min_nbor_dist[task_config.key] = task_min_nbor_dist + if multi_task: + jdata_cpy["model"]["model_dict"][task_config.key] = updated_model + else: + jdata_cpy["model"] = updated_model + return jdata_cpy, task_min_nbor_dist + return jdata_cpy, min_nbor_dist + + +def _make_stat_file_path(stat_file_raw: str | None) -> DPPath | None: + if stat_file_raw is None: + return None + stat_file_target = Path(stat_file_raw) + stat_file_target.parent.mkdir(parents=True, exist_ok=True) + if not stat_file_target.exists(): + if stat_file_raw.endswith((".h5", ".hdf5")): + with h5py.File(stat_file_raw, "w"): + pass + else: + stat_file_target.mkdir(parents=True, exist_ok=True) + return DPPath(stat_file_raw, "a") diff --git a/deepmd/tf2/train/__init__.py b/deepmd/tf2/train/__init__.py new file mode 100644 index 0000000000..b5594b2397 --- /dev/null +++ b/deepmd/tf2/train/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Training utilities for the TensorFlow 2 eager backend.""" + +from .trainer import ( + DPTrainer, + Trainer, + get_additional_data_requirement, + get_loss, +) + +__all__ = [ + "DPTrainer", + "Trainer", + "get_additional_data_requirement", + "get_loss", +] diff --git a/deepmd/tf2/train/trainer.py b/deepmd/tf2/train/trainer.py new file mode 100644 index 0000000000..ffd9cc8c58 --- /dev/null +++ b/deepmd/tf2/train/trainer.py @@ -0,0 +1,1193 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""TensorFlow 2 eager training loop.""" + +from __future__ import ( + annotations, +) + +import functools +import json +import logging +import shutil +import time +from collections.abc import ( + Mapping, +) +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np + +from deepmd.dpmodel.loss import ( + DOSLoss, + EnergyLoss, + PropertyLoss, + TensorLoss, +) +from deepmd.dpmodel.train import ( + DEFAULT_TASK_KEY, + AbstractTrainer, + RankContext, + TrainerConfig, + TrainingTask, + TrainingTaskCollection, + TrainStepResult, +) +from deepmd.dpmodel.utils.batch import ( + normalize_batch, + split_batch, +) +from deepmd.dpmodel.utils.learning_rate import ( + LearningRateExp, +) +from deepmd.dpmodel.utils.training_utils import ( + resolve_model_prob, +) +from deepmd.tf2.common import ( + to_tensorflow_array, + to_tf_tensor, + unwrap_value, +) +from deepmd.tf2.env import ( + tf, +) +from deepmd.tf2.model.model import ( + get_model, +) +from deepmd.tf2.utils.multi_task import ( + apply_shared_links, + sanitize_shared_links, +) +from deepmd.utils.argcheck import ( + resolve_full_validation_start_step, +) +from deepmd.utils.data import ( + DataRequirementItem, +) +from deepmd.utils.finetune import ( + warn_configuration_mismatch_during_finetune, +) +from deepmd.utils.model_stat import ( + make_stat_input, +) + +if TYPE_CHECKING: + from deepmd.utils.data_system import ( + DeepmdDataSystem, + ) + from deepmd.utils.path import ( + DPPath, + ) + +log = logging.getLogger(__name__) + +TF2_TRAINING_STATE_FILE = "training_state.json" + + +def get_loss( + loss_params: dict[str, Any], + start_lr: float, + _ntypes: int, + _model: Any, +) -> EnergyLoss | DOSLoss | TensorLoss | PropertyLoss: + """Build a dpmodel-compatible loss object for TF2 training.""" + loss_type = loss_params.get("type", "ener") + loss_params = dict(loss_params) + if loss_type == "ener": + loss_params["starter_learning_rate"] = start_lr + return EnergyLoss(**loss_params) + if loss_type == "dos": + loss_params["starter_learning_rate"] = start_lr + loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size + return DOSLoss(**loss_params) + if loss_type == "tensor": + model_output_type = list(_model.model_output_type()) + if "mask" in model_output_type: + model_output_type.remove("mask") + tensor_name = model_output_type[0] + loss_params["tensor_size"] = _model.model_output_def()[tensor_name].output_size + loss_params["label_name"] = tensor_name + if tensor_name == "polarizability": + tensor_name = "polar" + loss_params["tensor_name"] = tensor_name + return TensorLoss(**loss_params) + if loss_type == "property": + loss_params["task_dim"] = _model.get_task_dim() + loss_params["var_name"] = _model.get_var_name() + loss_params["intensive"] = _model.get_intensive() + return PropertyLoss(**loss_params) + raise ValueError(f"Unsupported loss type for tf2: {loss_type}") + + +def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]: + """Return model-input data requirements not declared by the loss.""" + additional_data_requirement: list[DataRequirementItem] = [] + if _model.get_dim_fparam() > 0: + has_default_fparam = _model.has_default_fparam() + fparam_default = ( + np.asarray(_model.get_default_fparam()) if has_default_fparam else 0.0 + ) + additional_data_requirement.append( + DataRequirementItem( + "fparam", + _model.get_dim_fparam(), + atomic=False, + must=not has_default_fparam, + default=fparam_default, + ) + ) + if _model.get_dim_aparam() > 0: + additional_data_requirement.append( + DataRequirementItem( + "aparam", + _model.get_dim_aparam(), + atomic=True, + must=True, + ) + ) + if _model.has_chg_spin_ebd(): + has_default_cs = _model.has_default_chg_spin() + default_cs = ( + np.asarray(to_tf_tensor(_model.get_default_chg_spin()).numpy()) + if has_default_cs + else 0.0 + ) + additional_data_requirement.append( + DataRequirementItem( + "charge_spin", + ndof=2, + atomic=False, + must=not has_default_cs, + default=default_cs, + ) + ) + return additional_data_requirement + + +def _as_task_map( + value: Any, + *, + multi_task: bool, + model_keys: list[str], +) -> dict[str, Any]: + if isinstance(value, Mapping): + if all(model_key in value for model_key in model_keys): + return {model_key: value[model_key] for model_key in model_keys} + if multi_task: + return {model_key: value[model_key] for model_key in model_keys} + return {DEFAULT_TASK_KEY: value} + + +class _TaskModelContainer(tf.Module): + """Track task-keyed TF modules with stable attribute names.""" + + def __init__(self, models: Mapping[str, tf.Module]) -> None: + super().__init__(name="models") + self.task_keys = tuple(models) + for index, key in enumerate(self.task_keys): + setattr(self, f"task_{index}", models[key]) + + +class Trainer(AbstractTrainer): + """Training driver for TensorFlow 2 eager models.""" + + def __init__( + self, + config: dict[str, Any], + training_data: DeepmdDataSystem | Mapping[str, DeepmdDataSystem], + stat_file_path: DPPath | Mapping[str, DPPath | None] | None = None, + validation_data: DeepmdDataSystem + | Mapping[str, DeepmdDataSystem | None] + | None = None, + init_model: str | None = None, + restart_model: str | None = None, + finetune_model: str | None = None, + finetune_links: dict[str, Any] | None = None, + shared_links: dict[str, Any] | None = None, + min_nbor_dist: float | Mapping[str, float | None] | None = None, + ) -> None: + if finetune_model is not None and ( + init_model is not None or restart_model is not None + ): + raise ValueError( + "finetune_model cannot be combined with init_model or restart_model." + ) + if init_model is not None and restart_model is not None: + raise ValueError("init_model cannot be combined with restart_model.") + + self.config = config + self.init_model = init_model + self.restart_model = restart_model + self.finetune_model = finetune_model + self.finetune_links = finetune_links + self.shared_links = shared_links + self.restart_training = restart_model is not None + model_params = config["model"] + training_params = config["training"] + self.validating_params = config.get("validating", {}) or {} + self._validate_unsupported_config(config) + + self.multi_task = "model_dict" in model_params + self.model_def_script = deepcopy(model_params) + self.full_validation_state: dict[str, Any] = {} + self.model_keys = ( + list(model_params["model_dict"]) if self.multi_task else [DEFAULT_TASK_KEY] + ) + self.model_params_by_task = ( + { + model_key: model_params["model_dict"][model_key] + for model_key in self.model_keys + } + if self.multi_task + else {DEFAULT_TASK_KEY: model_params} + ) + self.training_data_by_task = _as_task_map( + training_data, + multi_task=self.multi_task, + model_keys=self.model_keys, + ) + self.validation_data_by_task = _as_task_map( + validation_data, + multi_task=self.multi_task, + model_keys=self.model_keys, + ) + self.stat_file_path_by_task = _as_task_map( + stat_file_path, + multi_task=self.multi_task, + model_keys=self.model_keys, + ) + + self.num_steps = int(training_params["numb_steps"]) + self.save_ckpt = str(training_params.get("save_ckpt", "model.ckpt")) + self.max_ckpt_keep = int(training_params.get("max_ckpt_keep", 5)) + self.gradient_max_norm = float(training_params.get("gradient_max_norm", 0.0)) + self.tensorboard = bool(training_params.get("tensorboard", False)) + self.tensorboard_log_dir = str( + training_params.get("tensorboard_log_dir", "log") + ) + self.tensorboard_freq = int(training_params.get("tensorboard_freq", 1)) + self.change_bias_after_training = bool( + training_params.get("change_bias_after_training", False) + ) + self.start_step = 0 + + self.models = { + model_key: get_model(deepcopy(self.model_params_by_task[model_key])) + for model_key in self.model_keys + } + self.set_min_nbor_dist(min_nbor_dist) + self.model = self.models if self.multi_task else self.models[DEFAULT_TASK_KEY] + + self.losses = {} + for model_key in self.model_keys: + loss_param = ( + config["loss_dict"][model_key] + if self.multi_task + else config.get("loss", {}) + ) + self.losses[model_key] = get_loss( + deepcopy(loss_param), + config["learning_rate"]["start_lr"], + len(self.model_params_by_task[model_key]["type_map"]), + self.models[model_key], + ) + self.loss = self.losses if self.multi_task else self.losses[DEFAULT_TASK_KEY] + + self.valid_numb_batch_by_task = {} + for model_key in self.model_keys: + data_requirement = list(self.losses[model_key].label_requirement) + data_requirement += get_additional_data_requirement(self.models[model_key]) + self.training_data_by_task[model_key].add_data_requirements( + data_requirement + ) + valid_data = self.validation_data_by_task[model_key] + if valid_data is not None: + valid_data.add_data_requirements(data_requirement) + valid_params = ( + training_params["data_dict"][model_key].get("validation_data", {}) + if self.multi_task + else training_params.get("validation_data", {}) + ) or {} + self.valid_numb_batch_by_task[model_key] = max( + int(valid_params.get("numb_btch", 1)), + 1, + ) + + self._sample_funcs = {} + for model_key in self.model_keys: + nbatch = int( + self.model_params_by_task[model_key].get("data_stat_nbatch", 10) + ) + train_data = self.training_data_by_task[model_key] + + @functools.lru_cache + def sample( + _data: DeepmdDataSystem = train_data, + _nbatch: int = nbatch, + ) -> list[dict[str, np.ndarray]]: + return make_stat_input(_data, _nbatch) + + self._sample_funcs[model_key] = sample + + if init_model is None and restart_model is None: + for model_key in self.model_keys: + finetune_has_new_type = ( + self.finetune_model is not None + and self.finetune_links is not None + and model_key in self.finetune_links + and self.finetune_links[model_key].get_has_new_type() + ) + if self.finetune_model is not None and not finetune_has_new_type: + continue + log.info( + "data stating for task %s... (this step may take long time)", + model_key, + ) + self.models[model_key].compute_or_load_stat( + self._sample_funcs[model_key], + stat_file_path=self.stat_file_path_by_task[model_key], + ) + + if self.finetune_model is not None: + self._apply_finetune() + self.model = ( + self.models if self.multi_task else self.models[DEFAULT_TASK_KEY] + ) + + self.model_prob = ( + resolve_model_prob( + self.model_keys, + training_params.get("model_prob"), + self.training_data_by_task, + ) + if self.multi_task + else None + ) + self._apply_shared_links( + resume=init_model is not None or restart_model is not None + ) + + lr_params = dict(config["learning_rate"]) + lr_params["num_steps"] = self.num_steps + self.lr_schedule = LearningRateExp(**lr_params) + self.optimizer = self._build_optimizer(config.get("optimizer", {})) + self.model_container = _TaskModelContainer(self.models) + self.step = tf.Variable(0, dtype=tf.int64, trainable=False, name="step") + self.checkpoint = tf.train.Checkpoint( + step=self.step, + optimizer=self.optimizer, + model=self.model_container, + ) + self.checkpoint_manager = tf.train.CheckpointManager( + self.checkpoint, + directory=self._checkpoint_directory(), + max_to_keep=self.max_ckpt_keep if self.max_ckpt_keep > 0 else None, + checkpoint_name=Path(self.save_ckpt).name, + ) + + if init_model is not None: + self._restore_model(init_model) + self.step.assign(0) + elif restart_model is not None: + self._restore_checkpoint(restart_model) + self.start_step = int(self.step.numpy()) + + self._build_optimizer_slots() + self._compiled_train_steps: dict[str, Any] = {} + self._compiled_eval_steps: dict[str, Any] = {} + self.training_tasks = self._make_training_tasks() + self.summary_writer: Any | None = None + self.full_validator: Any | None = None + super().__init__( + TrainerConfig.from_training_params( + training_params, + num_steps=self.num_steps, + start_step=self.start_step, + restart_training=self.restart_training, + ), + rank_context=RankContext(rank=0, world_size=1), + ) + self.full_validator = self._create_full_validator() + + def _validate_unsupported_config(self, config: Mapping[str, Any]) -> None: + training_params = config["training"] + unsupported_true_flags = { + "profiling": "profiling", + "enable_profiler": "TensorFlow profiler", + } + for flag_name, feature_name in unsupported_true_flags.items(): + if training_params.get(flag_name, False): + raise NotImplementedError( + f"TF2 training does not support {feature_name} yet." + ) + if training_params.get("mixed_precision") is not None: + raise NotImplementedError( + "TF2 training does not support mixed_precision yet." + ) + if (config.get("nvnmd", {}) or {}).get("enable", False): + raise NotImplementedError("TF2 training does not support NVNMD yet.") + if config["model"].get("modifier") is not None: + raise NotImplementedError( + "TF2 training does not support model.modifier yet." + ) + + def _create_full_validator(self) -> Any | None: + if not self._is_validation_requested("full_validation"): + return None + self._raise_if_full_validation_unsupported() + from deepmd.dpmodel.train.validation import ( + resolve_best_checkpoint_dir, + ) + from deepmd.tf2.train.validation import ( + TF2FullValidator, + ) + + return TF2FullValidator( + validating_params=self.validating_params, + validation_data=self.validation_data_by_task[DEFAULT_TASK_KEY], + model=self.models[DEFAULT_TASK_KEY], + state_store=self.full_validation_state, + num_steps=self.num_steps, + rank=0, + restart_training=self.restart_training, + checkpoint_dir=resolve_best_checkpoint_dir( + self.validating_params, self.save_ckpt + ), + ) + + def _is_validation_requested(self, flag_name: str) -> bool: + if not self.validating_params.get(flag_name, False): + return False + start_step = resolve_full_validation_start_step( + self.validating_params.get("full_val_start", 0.5), + self.num_steps, + ) + return start_step is not None and start_step <= self.num_steps + + def _raise_if_full_validation_unsupported(self) -> None: + if self.multi_task: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; multi-task training is not supported." + ) + if not isinstance(self.loss, EnergyLoss): + raise ValueError( + "validating.full_validation only supports single-task energy training." + ) + if self.validation_data_by_task[DEFAULT_TASK_KEY] is None: + raise ValueError( + "validating.full_validation requires `training.validation_data` " + "to be configured." + ) + + def _build_optimizer(self, optimizer_params: Mapping[str, Any]) -> Any: + optimizer_type = optimizer_params.get("type", "Adam") + beta1 = float(optimizer_params.get("adam_beta1", 0.9)) + beta2 = float(optimizer_params.get("adam_beta2", 0.999)) + weight_decay = float(optimizer_params.get("weight_decay", 0.0)) + learning_rate = float(self.lr_schedule.value(self.start_step)) + if optimizer_type == "Adam": + if weight_decay != 0.0: + raise RuntimeError( + "TF2 Adam optimizer does not support weight_decay. " + "Set optimizer/weight_decay to 0 or use AdamW." + ) + return tf.keras.optimizers.Adam( + learning_rate=learning_rate, + beta_1=beta1, + beta_2=beta2, + ) + if optimizer_type == "AdamW": + if not hasattr(tf.keras.optimizers, "AdamW"): + raise RuntimeError("This TensorFlow version does not provide AdamW.") + return tf.keras.optimizers.AdamW( + learning_rate=learning_rate, + beta_1=beta1, + beta_2=beta2, + weight_decay=weight_decay, + ) + raise ValueError(f"Unsupported optimizer type for tf2: {optimizer_type}") + + def _build_optimizer_slots(self) -> None: + variables = _unique_variables(self.model_container.trainable_variables) + build = getattr(self.optimizer, "build", None) + if callable(build): + build(variables) + + def _checkpoint_directory(self) -> str: + return str(Path(f"{self.save_ckpt}.tf2")) + + def _resolve_checkpoint_path(self, checkpoint_path: str) -> str: + path = Path(checkpoint_path) + candidates = [path] + if not str(path).endswith(".tf2"): + candidates.append(Path(f"{checkpoint_path}.tf2")) + for candidate in candidates: + if candidate.is_dir(): + latest = tf.train.latest_checkpoint(str(candidate)) + if latest is not None: + return latest + if path.exists() or Path(f"{path}.index").exists(): + return str(path) + raise FileNotFoundError( + f"Cannot find TF2 checkpoint {checkpoint_path!r}. Expected a " + "CheckpointManager directory or a checkpoint prefix." + ) + + def _restore_model(self, checkpoint_path: str) -> None: + resolved = self._resolve_checkpoint_path(checkpoint_path) + model_checkpoint = tf.train.Checkpoint(model=self.model_container) + model_checkpoint.restore(resolved).expect_partial() + log.info("Initialized TF2 model variables from %s", resolved) + + def _restore_checkpoint(self, checkpoint_path: str) -> None: + resolved = self._resolve_checkpoint_path(checkpoint_path) + self.checkpoint.restore(resolved).expect_partial() + log.info( + "Restarted TF2 training from %s at step %d", resolved, self.step.numpy() + ) + + @staticmethod + def _model_params_by_task( + model_params: dict[str, Any], + ) -> dict[str, dict[str, Any]]: + if "model_dict" in model_params: + return { + model_key: model_params["model_dict"][model_key] + for model_key in model_params["model_dict"] + } + return {DEFAULT_TASK_KEY: model_params} + + @staticmethod + def _deserialize_models(model_data: dict[str, Any]) -> dict[str, Any]: + from deepmd.tf2.model.base_model import ( + BaseModel, + ) + + if "model_dict" in model_data["model_def_script"]: + return { + model_key: BaseModel.deserialize( + model_data["model"]["model_dict"][model_key] + ) + for model_key in model_data["model_def_script"]["model_dict"] + } + return {DEFAULT_TASK_KEY: BaseModel.deserialize(model_data["model"])} + + def set_min_nbor_dist( + self, + min_nbor_dist: float | Mapping[str, float | None] | None, + ) -> None: + if min_nbor_dist is None: + return + if isinstance(min_nbor_dist, Mapping): + for model_key, value in min_nbor_dist.items(): + if value is not None and model_key in self.models: + self.models[model_key].min_nbor_dist = float(value) + return + self.models[DEFAULT_TASK_KEY].min_nbor_dist = float(min_nbor_dist) + + def _apply_finetune(self) -> None: + if self.finetune_model is None or self.finetune_links is None: + return + from deepmd.tf2.utils.serialization import ( + serialize_from_file, + ) + + pretrained_data = serialize_from_file(self.finetune_model) + pretrained_params = pretrained_data["model_def_script"] + pretrained_models = self._deserialize_models(pretrained_data) + for model_key in self.model_keys: + finetune_rule = self.finetune_links[model_key] + source_key = finetune_rule.get_model_branch() + if source_key not in pretrained_models: + raise ValueError( + f"Pretrained model branch {source_key!r} is not available." + ) + source_model = pretrained_models[source_key] + if finetune_rule.get_finetune_tmap() != source_model.get_type_map(): + model_with_new_type_stat = ( + self.models[model_key] if finetune_rule.get_has_new_type() else None + ) + source_model.change_type_map( + finetune_rule.get_finetune_tmap(), + model_with_new_type_stat=model_with_new_type_stat, + ) + self._warn_finetune_config_mismatch( + model_key, source_key, pretrained_params + ) + self.models[model_key] = self._copy_finetune_state( + self.models[model_key], + source_model, + random_fitting=finetune_rule.get_random_fitting(), + ) + if finetune_rule.get_resuming(): + log.info("Model branch %s will resume training.", model_key) + continue + bias_mode = ( + "change-by-statistic" + if not finetune_rule.get_random_fitting() + else "set-by-statistic" + ) + self.models[model_key].change_out_bias( + self._sample_funcs[model_key], + bias_adjust_mode=bias_mode, + ) + + def _apply_shared_links(self, *, resume: bool) -> None: + if self.shared_links is None: + return + model_key_prob_map = ( + { + model_key: float(prob) + for model_key, prob in zip( + self.model_keys, + self.model_prob, + strict=True, + ) + } + if self.model_prob is not None + else dict.fromkeys(self.model_keys, 1.0) + ) + apply_shared_links( + self.models, + self.shared_links, + model_key_prob_map=model_key_prob_map, + resume=resume, + ) + self.model = self.models if self.multi_task else self.models[DEFAULT_TASK_KEY] + + def _warn_finetune_config_mismatch( + self, + model_key: str, + source_key: str, + pretrained_params: dict[str, Any], + ) -> None: + input_model_params = self.model_params_by_task[model_key] + branch_pretrained_params = ( + pretrained_params["model_dict"][source_key] + if "model_dict" in pretrained_params + else pretrained_params + ) + if ( + "descriptor" in input_model_params + and "descriptor" in branch_pretrained_params + ): + warn_configuration_mismatch_during_finetune( + input_model_params["descriptor"], + branch_pretrained_params["descriptor"], + source_key, + ) + + @staticmethod + def _copy_finetune_state( + target_model: Any, + source_model: Any, + *, + random_fitting: bool, + ) -> Any: + from deepmd.tf2.model.base_model import ( + BaseModel, + ) + + copied = _copy_matching_state_tree( + target_model.serialize(), + source_model.serialize(), + random_fitting=random_fitting, + ) + return BaseModel.deserialize(copied) + + def _make_training_tasks(self) -> TrainingTaskCollection: + return TrainingTaskCollection( + [ + TrainingTask( + key=model_key, + training_data=self.training_data_by_task[model_key], + validation_data=self.validation_data_by_task[model_key], + valid_numb_batch=self.valid_numb_batch_by_task[model_key], + ) + for model_key in self.model_keys + ], + probabilities=self.model_prob, + ) + + def run(self) -> None: + """Run TF2 training through the backend-independent trainer loop.""" + log.info("Start to train %d steps.", self.num_steps) + wall_start = time.time() + super().run(self.training_tasks) + if self.change_bias_after_training: + self._change_bias_after_training() + if self.rank_context.is_chief: + self.save_checkpoint(self.num_steps) + log.info("Training finished. Total wall time: %.2fs", time.time() - wall_start) + + def on_train_begin(self, tasks: TrainingTaskCollection) -> None: + del tasks + if self.tensorboard and self.rank_context.is_chief: + self.summary_writer = tf.summary.create_file_writer( + self.tensorboard_log_dir + ) + + def on_train_end(self, tasks: TrainingTaskCollection) -> None: + del tasks + if self.summary_writer is not None: + self.summary_writer.close() + self.summary_writer = None + + def select_task(self, tasks: TrainingTaskCollection) -> TrainingTask: + if not tasks.is_multitask: + return tasks[tasks.keys[0]] + from deepmd.utils import random as dp_random + + model_index = dp_random.choice( + np.arange(len(tasks), dtype=np.int_), + p=tasks.probabilities, + ) + return tasks[tasks.keys[int(model_index)]] + + def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: + """Run one TensorFlow optimizer step.""" + task_key = task.key + cur_lr = float(self.lr_schedule.value(step)) + input_dict, label_dict, natoms = self.get_data(is_train=True, task_key=task_key) + more_loss = self._compiled_train_step( + task_key, + input_dict, + label_dict, + tf.constant(float(natoms), dtype=tf.float64), + tf.constant(cur_lr, dtype=tf.float64), + tf.constant(step + 1, dtype=tf.int64), + ) + self._write_tensorboard_step( + task_key, + display_step=step + 1, + learning_rate=cur_lr, + more_loss=more_loss, + ) + return TrainStepResult( + task_key=task_key, + step=step, + payload={ + "more_loss": more_loss, + "cur_lr": cur_lr, + }, + ) + + def _compiled_train_step( + self, + task_key: str, + input_dict: dict[str, Any], + label_dict: dict[str, Any], + natoms: Any, + cur_lr: Any, + next_step: Any, + ) -> dict[str, Any]: + if task_key not in self._compiled_train_steps: + self._compiled_train_steps[task_key] = self._make_compiled_train_step( + task_key + ) + return self._compiled_train_steps[task_key]( + input_dict, + label_dict, + natoms, + cur_lr, + next_step, + ) + + def _make_compiled_train_step(self, task_key: str) -> Any: + variables = _unique_variables(self.models[task_key].trainable_variables) + + @tf.function(reduce_retracing=True) + def compiled_train_step( + input_dict: dict[str, Any], + label_dict: dict[str, Any], + natoms: Any, + cur_lr: Any, + next_step: Any, + ) -> dict[str, Any]: + self._assign_learning_rate(cur_lr) + with tf.GradientTape() as tape: + model_pred = self._call_model(task_key, input_dict) + loss, more_loss = self.losses[task_key]( + learning_rate=cur_lr, + natoms=natoms, + model_dict=model_pred, + label_dict=label_dict, + ) + loss_tensor = to_tf_tensor(loss) + gradients = tape.gradient(loss_tensor, variables) + gradients_and_variables = [ + (grad, var) + for grad, var in zip(gradients, variables, strict=True) + if grad is not None + ] + if self.gradient_max_norm > 0.0 and gradients_and_variables: + grads, vars_ = zip(*gradients_and_variables, strict=True) + grads, _ = tf.clip_by_global_norm(grads, self.gradient_max_norm) + gradients_and_variables = list(zip(grads, vars_, strict=True)) + self.optimizer.apply_gradients(gradients_and_variables) + self.step.assign(next_step) + return unwrap_value(more_loss) + + return compiled_train_step + + def evaluate_training( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> dict[str, float]: + if step_result is not None and step_result.task_key == task.key: + return self._more_loss_to_float(step_result.payload["more_loss"]) + input_dict, label_dict, natoms = self.get_data(is_train=True, task_key=task.key) + return self._evaluate_batch(task.key, step, input_dict, label_dict, natoms) + + def evaluate_validation( + self, + task: TrainingTask, + step: int, + step_result: TrainStepResult | None, + ) -> dict[str, float] | None: + if task.validation_data is None: + return None + valid_results: dict[str, float] = {} + sum_natoms = 0 + for _ii in range(task.valid_numb_batch): + input_dict, label_dict, natoms = self.get_data( + is_train=False, + task_key=task.key, + ) + results = self._evaluate_batch( + task.key, + step, + input_dict, + label_dict, + natoms, + ) + sum_natoms += natoms + for key, value in results.items(): + valid_results[key] = valid_results.get(key, 0.0) + value * natoms + if sum_natoms == 0: + return valid_results + return {key: value / sum_natoms for key, value in valid_results.items()} + + def _evaluate_batch( + self, + task_key: str, + step: int, + input_dict: dict[str, Any], + label_dict: dict[str, Any], + natoms: int, + ) -> dict[str, float]: + cur_lr = float(self.lr_schedule.value(step)) + return self._compiled_eval_step( + task_key, + input_dict, + label_dict, + tf.constant(float(natoms), dtype=tf.float64), + tf.constant(cur_lr, dtype=tf.float64), + ) + + def _compiled_eval_step( + self, + task_key: str, + input_dict: dict[str, Any], + label_dict: dict[str, Any], + natoms: Any, + cur_lr: Any, + ) -> dict[str, float]: + if task_key not in self._compiled_eval_steps: + self._compiled_eval_steps[task_key] = self._make_compiled_eval_step( + task_key + ) + more_loss = self._compiled_eval_steps[task_key]( + input_dict, + label_dict, + natoms, + cur_lr, + ) + return self._more_loss_to_float(more_loss) + + def _make_compiled_eval_step(self, task_key: str) -> Any: + @tf.function(reduce_retracing=True) + def compiled_eval_step( + input_dict: dict[str, Any], + label_dict: dict[str, Any], + natoms: Any, + cur_lr: Any, + ) -> dict[str, Any]: + model_pred = self._call_model(task_key, input_dict) + _, more_loss = self.losses[task_key]( + learning_rate=cur_lr, + natoms=natoms, + model_dict=model_pred, + label_dict=label_dict, + ) + return unwrap_value(more_loss) + + return compiled_eval_step + + def learning_rate(self, step: int) -> float: + return float(self.lr_schedule.value(step)) + + def save_checkpoint(self, step: int) -> None: + self.step.assign(step) + save_path = self.checkpoint_manager.save(checkpoint_number=step) + self._write_training_state(Path(self._checkpoint_directory()), step=step) + log.info("Saved TF2 checkpoint to %s", save_path) + + def run_full_validation( + self, + *, + step: int, + display_step: int, + learning_rate: float, + ) -> None: + if self.full_validator is None: + return None + self.full_validator.model = self.models[DEFAULT_TASK_KEY] + self.full_validator.run( + step_id=display_step, + display_step=display_step, + lr=learning_rate, + save_checkpoint=self._save_full_validation_checkpoint, + ) + return None + + def _save_full_validation_checkpoint( + self, + save_path: Path, + lr: float = 0.0, + step: int = 0, + ) -> None: + del lr + self._write_checkpoint_directory(save_path, step=step) + + def _write_checkpoint_directory(self, directory: Path, *, step: int) -> None: + self.step.assign(step) + if directory.exists(): + shutil.rmtree(directory) + manager = tf.train.CheckpointManager( + self.checkpoint, + directory=str(directory), + max_to_keep=1, + checkpoint_name=directory.stem, + ) + manager.save(checkpoint_number=step) + self._write_training_state(directory, step=step) + + def _write_training_state(self, directory: Path, *, step: int) -> None: + directory.mkdir(parents=True, exist_ok=True) + state = { + "backend": "TensorFlow2", + "format_version": 1, + "current_step": int(step), + "model_def_script": deepcopy(self.model_def_script), + "shared_links": sanitize_shared_links(self.shared_links), + "multi_task": self.multi_task, + "model_keys": list(self.model_keys), + "min_nbor_dist": self._current_min_nbor_dist(), + "full_validation": deepcopy(self.full_validation_state), + } + with (directory / TF2_TRAINING_STATE_FILE).open("w") as fp: + json.dump(state, fp, indent=2) + + def _current_min_nbor_dist(self) -> Any: + values = {} + for model_key in self.model_keys: + value = self.models[model_key].get_min_nbor_dist() + values[model_key] = None if value is None else float(value) + if self.multi_task: + return values + return values[DEFAULT_TASK_KEY] + + def _write_tensorboard_step( + self, + task_key: str, + *, + display_step: int, + learning_rate: float, + more_loss: dict[str, Any], + ) -> None: + if ( + self.summary_writer is None + or self.tensorboard_freq <= 0 + or display_step % self.tensorboard_freq != 0 + ): + return + prefix = f"train/{task_key}" if self.multi_task else "train" + with self.summary_writer.as_default(): + tf.summary.scalar("learning_rate", learning_rate, step=display_step) + for key, value in self._more_loss_to_float(more_loss).items(): + tf.summary.scalar(f"{prefix}/{key}", value, step=display_step) + self.summary_writer.flush() + + def _change_bias_after_training(self) -> None: + log.info("Changing output bias after training.") + for model_key in self.model_keys: + self.models[model_key].change_out_bias( + self._sample_funcs[model_key], + bias_adjust_mode="change-by-statistic", + ) + + def get_data( + self, + *, + is_train: bool, + task_key: str, + ) -> tuple[dict[str, Any], dict[str, Any], int]: + task_key = task_key if self.multi_task else DEFAULT_TASK_KEY + data_sys = ( + self.training_data_by_task[task_key] + if is_train + else self.validation_data_by_task[task_key] + ) + if data_sys is None: + return {}, {}, 0 + batch = normalize_batch(data_sys.get_batch()) + input_dict, label_dict = split_batch(batch) + for opt_key in ("fparam", "charge_spin"): + find_key = f"find_{opt_key}" + if ( + opt_key in input_dict + and find_key in label_dict + and not bool(label_dict[find_key]) + ): + input_dict.pop(opt_key) + natoms = int(input_dict["atype"].shape[1]) + label_dict["type"] = input_dict["atype"] + input_tf = { + key: self._to_input_tensor(key, value) for key, value in input_dict.items() + } + label_tf = { + key: self._to_label_array(key, value) for key, value in label_dict.items() + } + return input_tf, label_tf, natoms + + def _call_model( + self, + task_key: str, + input_dict: dict[str, Any], + ) -> dict[str, Any]: + return self.models[task_key].call( + input_dict["coord"], + input_dict["atype"], + box=input_dict.get("box"), + fparam=input_dict.get("fparam"), + aparam=input_dict.get("aparam"), + charge_spin=input_dict.get("charge_spin"), + ) + + @staticmethod + def _to_input_tensor(key: str, value: Any) -> Any: + if value is None: + return None + if key == "atype": + return tf.convert_to_tensor(value, dtype=tf.int32) + if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.integer): + return tf.convert_to_tensor(value, dtype=tf.int64) + return tf.convert_to_tensor(value, dtype=tf.float64) + + @staticmethod + def _to_label_array(key: str, value: Any) -> Any: + if value is None: + return None + if key in {"type", "natoms"}: + return to_tensorflow_array(tf.convert_to_tensor(value, dtype=tf.int32)) + if key.startswith("find_"): + return to_tensorflow_array( + tf.constant(1.0 if bool(value) else 0.0, dtype=tf.float64) + ) + if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.integer): + return to_tensorflow_array(tf.convert_to_tensor(value, dtype=tf.int32)) + return to_tensorflow_array(tf.convert_to_tensor(value, dtype=tf.float64)) + + def _assign_learning_rate(self, learning_rate: float) -> None: + lr_attr = self.optimizer.learning_rate + if hasattr(lr_attr, "assign"): + lr_attr.assign(learning_rate) + else: + self.optimizer.learning_rate = learning_rate + + @staticmethod + def _to_float(value: Any) -> float: + tensor = to_tf_tensor(value) + if tensor is not None: + return float(tensor.numpy()) + return float(value) + + @classmethod + def _more_loss_to_float(cls, more_loss: dict[str, Any]) -> dict[str, float]: + return { + key: cls._to_float(value) + for key, value in more_loss.items() + if "l2_" not in key + } + + +DPTrainer = Trainer + + +def _copy_matching_state_tree( + target: Any, + source: Any, + *, + random_fitting: bool, + path: tuple[Any, ...] = (), +) -> Any: + if isinstance(target, dict): + if not isinstance(source, dict): + return target + return { + key: _copy_matching_state_tree( + value, + source.get(key), + random_fitting=random_fitting, + path=(*path, key), + ) + for key, value in target.items() + } + if source is None: + return target + if random_fitting and not any("descriptor" in str(part) for part in path): + return target + if _same_state_leaf(target, source): + return source + return target + + +def _same_state_leaf(target: Any, source: Any) -> bool: + target_array = np.asarray(target) if _is_array_like(target) else None + source_array = np.asarray(source) if _is_array_like(source) else None + if target_array is None or source_array is None: + return False + return ( + target_array.shape == source_array.shape + and target_array.dtype == source_array.dtype + ) + + +def _is_array_like(value: Any) -> bool: + if isinstance(value, (str, bytes)): + return False + return hasattr(value, "shape") and hasattr(value, "dtype") + + +def _unique_variables(variables: list[Any] | tuple[Any, ...]) -> list[Any]: + unique = [] + seen: set[int] = set() + for variable in variables: + variable_id = id(variable) + if variable_id in seen: + continue + seen.add(variable_id) + unique.append(variable) + return unique diff --git a/deepmd/tf2/train/validation.py b/deepmd/tf2/train/validation.py new file mode 100644 index 0000000000..999a12b925 --- /dev/null +++ b/deepmd/tf2/train/validation.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Full validation support for the TensorFlow 2 trainer.""" + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np + +from deepmd.dpmodel.train.validation import ( + LOG_COLUMN_ORDER, + FullValidatorBase, +) +from deepmd.tf2.common import ( + to_tf_tensor, +) +from deepmd.tf2.env import ( + tf, +) +from deepmd.tf2.utils.auto_batch_size import ( + AutoBatchSize, +) +from deepmd.utils.eval_metrics import ( + FULL_VALIDATION_WEIGHTED_METRIC_KEYS, + compute_energy_type_metrics, +) +from deepmd.utils.weight_avg import ( + weighted_average, +) + +if TYPE_CHECKING: + from deepmd.tf2.model.base_model import ( + BaseModel, + ) + + +class TF2FullValidator(FullValidatorBase): + """Run full validation for a single-task TF2 energy model.""" + + def __init__( + self, + *, + validating_params: dict[str, Any], + validation_data: Any, + model: BaseModel, + state_store: dict[str, Any], + num_steps: int, + rank: int, + restart_training: bool, + checkpoint_dir: Any = None, + ) -> None: + self.validation_data = validation_data + self.model = model + self.auto_batch_size = AutoBatchSize(silent=True) + super().__init__( + validating_params=validating_params, + state_store=state_store, + num_steps=num_steps, + rank=rank, + restart_training=restart_training, + checkpoint_dir=checkpoint_dir, + best_checkpoint_suffix=".tf2", + ) + + def evaluate_all_systems(self) -> dict[str, float]: + """Evaluate every validation system and aggregate metrics.""" + system_metrics = [ + self._evaluate_system(data_system) + for data_system in self._iter_validation_data_systems() + ] + aggregated = weighted_average([metric for metric in system_metrics if metric]) + return { + metric_key: float(aggregated[metric_key]) + for _, metric_key in LOG_COLUMN_ORDER + if metric_key in aggregated + } + + def _iter_validation_data_systems(self) -> Any: + validation_data = self.validation_data + if hasattr(validation_data, "data_systems"): + yield from validation_data.data_systems + return + if hasattr(validation_data, "get_test"): + yield validation_data + return + if hasattr(validation_data, "systems"): + for dataset in validation_data.systems: + yield getattr(dataset, "data_system", dataset) + return + raise TypeError( + "TF2 full validation expects a DeepmdDataSystem, DeepmdData-like " + f"object, or loader set; got {type(validation_data)!r}." + ) + + def _evaluate_system(self, data_system: Any) -> dict[str, tuple[float, float]]: + test_data = data_system.get_test() + natoms = int(test_data["type"].shape[1]) + nframes = int(test_data["coord"].shape[0]) + has_pbc = bool(getattr(data_system, "pbc", False)) + include_virial = has_pbc and bool(test_data.get("find_virial", 0.0)) + prediction = self._predict_outputs( + coord=test_data["coord"].reshape(nframes, -1), + atom_types=test_data["type"], + box=test_data["box"] if has_pbc else None, + fparam=test_data["fparam"] + if self.model.get_dim_fparam() > 0 + and bool(test_data.get("find_fparam", 0.0)) + else None, + aparam=test_data["aparam"] if self.model.get_dim_aparam() > 0 else None, + include_virial=include_virial, + natoms=natoms, + nframes=nframes, + ) + shared_metrics = compute_energy_type_metrics( + prediction=prediction, + test_data=test_data, + natoms=natoms, + has_pbc=has_pbc, + ) + return shared_metrics.as_weighted_average_errors( + FULL_VALIDATION_WEIGHTED_METRIC_KEYS + ) + + def _predict_outputs( + self, + *, + coord: np.ndarray, + atom_types: np.ndarray, + box: np.ndarray | None, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + include_virial: bool, + natoms: int, + nframes: int, + ) -> dict[str, np.ndarray]: + """Predict energy, force, and virial for the full validation batch.""" + + def predict_batch( + coord_batch: np.ndarray, + atom_types_batch: np.ndarray, + box_batch: np.ndarray | None, + fparam_batch: np.ndarray | None, + aparam_batch: np.ndarray | None, + ) -> dict[str, np.ndarray]: + batch_nframes = coord_batch.shape[0] + model_output = self.model.call( + tf.convert_to_tensor( + coord_batch.reshape(batch_nframes, -1), tf.float64 + ), + tf.convert_to_tensor(atom_types_batch, tf.int32), + box=tf.convert_to_tensor( + box_batch.reshape(batch_nframes, 9), tf.float64 + ) + if box_batch is not None + else None, + fparam=tf.convert_to_tensor( + fparam_batch.reshape(batch_nframes, self.model.get_dim_fparam()), + tf.float64, + ) + if fparam_batch is not None + else None, + aparam=tf.convert_to_tensor( + aparam_batch.reshape( + batch_nframes, natoms, self.model.get_dim_aparam() + ), + tf.float64, + ) + if aparam_batch is not None + else None, + do_atomic_virial=include_virial, + ) + prediction = { + "energy": np.asarray(to_tf_tensor(model_output["energy"])).reshape( + -1, 1 + ), + "force": np.asarray(to_tf_tensor(model_output["force"])).reshape( + -1, natoms * 3 + ), + } + if include_virial: + if "virial" not in model_output: + raise KeyError( + "Full validation requested virial metrics, but model " + "output does not contain `virial`." + ) + prediction["virial"] = np.asarray( + to_tf_tensor(model_output["virial"]) + ).reshape(-1, 9) + return prediction + + batch_prediction = self.auto_batch_size.execute_all( + predict_batch, + nframes, + natoms, + coord, + atom_types, + box, + fparam, + aparam, + ) + prediction = { + "energy": np.asarray(batch_prediction["energy"]), + "force": np.asarray(batch_prediction["force"]), + } + if include_virial: + prediction["virial"] = np.asarray(batch_prediction["virial"]) + return prediction diff --git a/deepmd/tf2/utils/auto_batch_size.py b/deepmd/tf2/utils/auto_batch_size.py new file mode 100644 index 0000000000..3036a2a465 --- /dev/null +++ b/deepmd/tf2/utils/auto_batch_size.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Automatic inference batch sizing for TensorFlow 2.""" + +from deepmd.tf2.env import ( + tf, +) +from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase + + +class AutoBatchSize(AutoBatchSizeBase): + """Auto batch size helper for TF2 eager inference.""" + + def is_gpu_available(self) -> bool: + """Return whether a GPU is visible to TensorFlow.""" + return bool(tf.config.list_physical_devices("GPU")) + + def is_oom_error(self, e: Exception) -> bool: + """Return whether an exception is TensorFlow's OOM signal.""" + return isinstance(e, tf.errors.ResourceExhaustedError) diff --git a/deepmd/tf2/utils/finetune.py b/deepmd/tf2/utils/finetune.py new file mode 100644 index 0000000000..65ca32a938 --- /dev/null +++ b/deepmd/tf2/utils/finetune.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Fine-tuning config utilities for the TensorFlow 2 backend.""" + +from __future__ import ( + annotations, +) + +from typing import ( + Any, +) + +from deepmd.tf2.utils.serialization import ( + serialize_from_file, +) +from deepmd.utils.finetune import ( + FinetuneRuleItem, + get_finetune_rules_from_model_params, +) + + +def _load_model_params(finetune_model: str) -> dict[str, Any]: + """Extract model params from a TF2 training checkpoint.""" + return serialize_from_file(finetune_model)["model_def_script"] + + +def get_finetune_rules( + finetune_model: str, + model_config: dict[str, Any], + model_branch: str = "", + change_model_params: bool = True, +) -> tuple[dict[str, Any], dict[str, FinetuneRuleItem]]: + """Build TF2 fine-tuning rules for single-task or multi-task configs.""" + return get_finetune_rules_from_model_params( + _load_model_params(finetune_model), + model_config, + model_branch=model_branch, + change_model_params=change_model_params, + ) diff --git a/deepmd/tf2/utils/multi_task.py b/deepmd/tf2/utils/multi_task.py new file mode 100644 index 0000000000..a2f1b3111f --- /dev/null +++ b/deepmd/tf2/utils/multi_task.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Multi-task sharing helpers for the TensorFlow 2 backend.""" + +from __future__ import ( + annotations, +) + +import logging +from copy import ( + deepcopy, +) +from typing import ( + Any, +) + +import array_api_compat +import numpy as np + +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) +from deepmd.tf2.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.tf2.env import ( + tf, + xp, +) +from deepmd.tf2.fitting.base_fitting import ( + BaseFitting, +) + +log = logging.getLogger(__name__) + + +def preprocess_shared_params( + model_config: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: + """Expand ``shared_dict`` references and generate runtime sharing links.""" + assert "model_dict" in model_config, "only multi-task model can use this method!" + supported_types = ["type_map", "descriptor", "fitting_net"] + shared_dict = model_config.get("shared_dict", {}) + shared_links: dict[str, Any] = {} + type_map_keys: list[str] = [] + + def replace_one_item( + params_dict: dict[str, Any] | list[Any], + model_key: str, + key_type: str, + key_in_dict: str, + suffix: str = "", + index: int | None = None, + ) -> None: + shared_type = key_type + shared_key = key_in_dict + shared_level = 0 + if ":" in key_in_dict: + shared_key = key_in_dict.split(":")[0] + shared_level = int(key_in_dict.split(":")[1]) + assert shared_key in shared_dict, ( + f"Appointed {shared_type} {shared_key} are not in the shared_dict! " + "Please check the input params." + ) + if index is None: + assert isinstance(params_dict, dict) + params_dict[shared_type] = deepcopy(shared_dict[shared_key]) + else: + params_dict[index] = deepcopy(shared_dict[shared_key]) + if shared_type == "type_map": + if key_in_dict not in type_map_keys: + type_map_keys.append(key_in_dict) + else: + if shared_key not in shared_links: + shared_links[shared_key] = { + "type": get_class_name(shared_type, shared_dict[shared_key]), + "links": [], + } + shared_links[shared_key]["links"].append( + { + "model_key": model_key, + "shared_type": shared_type + suffix, + "shared_level": shared_level, + } + ) + + for model_key in model_config["model_dict"]: + model_params_item = model_config["model_dict"][model_key] + for item_key in model_params_item: + if item_key not in supported_types: + continue + item_params = model_params_item[item_key] + if isinstance(item_params, str): + replace_one_item(model_params_item, model_key, item_key, item_params) + elif ( + item_key == "descriptor" + and isinstance(item_params, dict) + and item_params.get("type", "") == "hybrid" + ): + for ii, hybrid_item in enumerate(item_params["list"]): + if isinstance(hybrid_item, str): + replace_one_item( + model_params_item[item_key]["list"], + model_key, + item_key, + hybrid_item, + suffix=f"_hybrid_{ii}", + index=ii, + ) + + for shared_key in shared_links: + shared_links[shared_key]["links"] = sorted( + shared_links[shared_key]["links"], + key=lambda x: ( + x["shared_level"] + - ("spin" in model_config["model_dict"][x["model_key"]]) * 100 + ), + ) + assert len(type_map_keys) == 1, "Multitask model must have only one type_map!" + return model_config, shared_links + + +def get_class_name(item_key: str, item_params: dict[str, Any]) -> type: + if item_key == "descriptor": + return BaseDescriptor.get_class_by_type(item_params.get("type", "se_e2_a")) + if item_key == "fitting_net": + return BaseFitting.get_class_by_type(item_params.get("type", "ener")) + raise RuntimeError(f"Unknown class_name type {item_key}") + + +def sanitize_shared_links(shared_links: dict[str, Any] | None) -> dict[str, Any] | None: + """Return a JSON-safe copy of ``shared_links``.""" + if shared_links is None: + return None + sanitized: dict[str, Any] = {} + for shared_key, shared_info in shared_links.items(): + class_type = shared_info.get("type") + sanitized[shared_key] = { + "type": getattr(class_type, "__name__", str(class_type)), + "links": deepcopy(shared_info.get("links", [])), + } + return sanitized + + +def apply_shared_links( + models: dict[str, Any], + shared_links: dict[str, Any] | None, + *, + model_key_prob_map: dict[str, float] | None = None, + data_stat_protect: float = 1e-2, + resume: bool = False, +) -> None: + """Share TF2 model parameters according to ``shared_links``.""" + if not shared_links: + return + if model_key_prob_map is None: + model_key_prob_map = dict.fromkeys(models, 1.0) + + for shared_item, shared_info in shared_links.items(): + links = shared_info.get("links", []) + if len(links) < 2: + continue + shared_base = links[0] + class_type_base = shared_base["shared_type"] + model_key_base = shared_base["model_key"] + shared_level_base = int(shared_base["shared_level"]) + if "descriptor" in class_type_base: + base_class = _get_descriptor(models[model_key_base], class_type_base) + for link_item in links[1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + assert shared_level_link >= shared_level_base, ( + "The shared_links must be sorted by shared_level!" + ) + assert "descriptor" in class_type_link, ( + f"Class type mismatched: {class_type_base} vs {class_type_link}!" + ) + link_class = _get_descriptor(models[model_key_link], class_type_link) + frac_prob = _model_prob_ratio( + model_key_prob_map, model_key_base, model_key_link + ) + _share_descriptor( + models[model_key_link], + class_type_link, + link_class, + base_class, + shared_level_link, + frac_prob, + resume=resume, + ) + log.warning( + "Shared params of %s.%s and %s.%s!", + model_key_base, + class_type_base, + model_key_link, + class_type_link, + ) + else: + if not hasattr(models[model_key_base].atomic_model, class_type_base): + continue + base_class = getattr(models[model_key_base].atomic_model, class_type_base) + for link_item in links[1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + assert shared_level_link >= shared_level_base, ( + "The shared_links must be sorted by shared_level!" + ) + assert class_type_base == class_type_link, ( + f"Class type mismatched: {class_type_base} vs {class_type_link}!" + ) + link_class = getattr( + models[model_key_link].atomic_model, class_type_link + ) + frac_prob = _model_prob_ratio( + model_key_prob_map, model_key_base, model_key_link + ) + _share_fitting( + link_class, + base_class, + shared_level_link, + frac_prob, + protection=data_stat_protect, + resume=resume, + ) + log.warning( + "Shared params of %s.%s and %s.%s!", + model_key_base, + class_type_base, + model_key_link, + class_type_link, + ) + + +def _model_prob_ratio( + model_key_prob_map: dict[str, float], + model_key_base: str, + model_key_link: str, +) -> float: + base_prob = float(model_key_prob_map.get(model_key_base, 1.0)) + link_prob = float(model_key_prob_map.get(model_key_link, 1.0)) + if base_prob == 0.0: + return 1.0 + return link_prob / base_prob + + +def _get_descriptor(model: Any, shared_type: str) -> Any: + if shared_type == "descriptor": + return model.get_descriptor() + if "hybrid" in shared_type: + hybrid_index = int(shared_type.split("_")[-1]) + return model.get_descriptor().descrpt_list[hybrid_index] + raise RuntimeError(f"Unknown class_type {shared_type}!") + + +def _set_descriptor(model: Any, shared_type: str, descriptor: Any) -> None: + if shared_type == "descriptor": + model.atomic_model.descriptor = descriptor + return + if "hybrid" in shared_type: + hybrid_index = int(shared_type.split("_")[-1]) + model.get_descriptor().descrpt_list[hybrid_index] = descriptor + return + raise RuntimeError(f"Unknown class_type {shared_type}!") + + +def _share_descriptor( + link_model: Any, + link_type: str, + link_class: Any, + base_class: Any, + shared_level: int, + model_prob: float, + *, + resume: bool, +) -> None: + assert link_class.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + if not resume: + merge_env_stat(base_class, link_class, model_prob) + _set_descriptor(link_model, link_type, base_class) + return + if shared_level == 1 and hasattr(link_class, "type_embedding"): + link_class.type_embedding = base_class.type_embedding + return + raise NotImplementedError( + f"TF2 descriptor shared level {shared_level} is not supported for " + f"{link_class.__class__.__name__}." + ) + + +def _share_fitting( + link_class: Any, + base_class: Any, + shared_level: int, + model_prob: float, + *, + protection: float, + resume: bool, +) -> None: + assert link_class.__class__ == base_class.__class__, ( + "Only fitting nets of the same type can share params!" + ) + if shared_level != 0: + raise NotImplementedError( + f"TF2 fitting_net shared level {shared_level} is not supported for " + f"{link_class.__class__.__name__}." + ) + + _merge_and_share_param_stats( + link_class, + base_class, + "fparam", + "fparam_avg", + "fparam_inv_std", + model_prob, + protection=protection, + resume=resume, + ) + _merge_and_share_param_stats( + link_class, + base_class, + "aparam", + "aparam_avg", + "aparam_inv_std", + model_prob, + protection=protection, + resume=resume, + ) + _share_tf2_state_attrs( + link_class, + base_class, + excluded={ + "bias_atom_e", + "case_embd", + "fparam_avg", + "fparam_inv_std", + "aparam_avg", + "aparam_inv_std", + }, + ) + + +def _merge_and_share_param_stats( + link_class: Any, + base_class: Any, + stat_name: str, + avg_attr: str, + inv_std_attr: str, + model_prob: float, + *, + protection: float, + resume: bool, +) -> None: + avg_value = getattr(link_class, avg_attr, None) + inv_std_value = getattr(link_class, inv_std_attr, None) + if avg_value is None or inv_std_value is None: + return + if not resume: + base_stats = base_class.get_param_stats().get(stat_name, []) + link_stats = link_class.get_param_stats().get(stat_name, []) + if base_stats and link_stats: + assert len(base_stats) == len(link_stats) + merged = [ + base_stats[ii] + link_stats[ii] * model_prob + for ii in range(len(base_stats)) + ] + avg = np.array([stat.compute_avg() for stat in merged], dtype=np.float64) + inv_std = 1.0 / np.array( + [stat.compute_std(protection=protection) for stat in merged], + dtype=np.float64, + ) + _assign_array_like(base_class, avg_attr, avg) + _assign_array_like(base_class, inv_std_attr, inv_std) + base_class._param_stats[stat_name] = merged + setattr(link_class, avg_attr, getattr(base_class, avg_attr)) + setattr(link_class, inv_std_attr, getattr(base_class, inv_std_attr)) + + +def _assign_array_like(obj: Any, attr: str, value: Any) -> None: + current = getattr(obj, attr) + current_xp = array_api_compat.array_namespace(current) + setattr( + obj, + attr, + current_xp.asarray( + value, + dtype=current.dtype, + device=array_api_compat.device(current), + ), + ) + + +def _share_tf2_state_attrs( + link_class: Any, + base_class: Any, + *, + excluded: set[str], +) -> None: + for name, value in list(vars(link_class).items()): + if name in excluded or name.startswith("_"): + continue + if _is_shareable_tf2_state(value): + setattr(link_class, name, getattr(base_class, name)) + + +def _is_shareable_tf2_state(value: Any) -> bool: + return isinstance(value, (tf.Module, tf.Variable, tf.Tensor, xp.Array)) diff --git a/deepmd/tf2/utils/serialization.py b/deepmd/tf2/utils/serialization.py index f4aac52fd7..95fc2ad2dd 100644 --- a/deepmd/tf2/utils/serialization.py +++ b/deepmd/tf2/utils/serialization.py @@ -3,6 +3,13 @@ import os from collections.abc import ( Callable, + Mapping, +) +from copy import ( + deepcopy, +) +from pathlib import ( + Path, ) from typing import ( Any, @@ -12,6 +19,9 @@ import tensorflow as tf from deepmd._vendors import ndtensorflow as xp +from deepmd.dpmodel.train import ( + DEFAULT_TASK_KEY, +) from deepmd.tf2.common import ( unwrap_value, ) @@ -21,9 +31,18 @@ from deepmd.tf2.model.base_model import ( BaseModel, ) +from deepmd.tf2.model.model import ( + get_model, +) +from deepmd.tf2.train.trainer import ( + TF2_TRAINING_STATE_FILE, +) from deepmd.tf2.utils._dpmodel import ( format_nlist, ) +from deepmd.tf2.utils.multi_task import ( + apply_shared_links, +) def _env_flag(name: str) -> bool: @@ -528,9 +547,143 @@ def get_default_fparam() -> tf.Tensor: def serialize_from_file(model_file: str) -> dict: - """Serialize a TF2 SavedModel to a dictionary. + """Serialize a TF2 training checkpoint to a DeePMD model dictionary. - SavedModel does not currently carry enough structured variable metadata to - round-trip back to the DeePMD dictionary format. + TensorFlow SavedModel exports are inference artifacts and do not currently + carry enough structured variable metadata to round-trip back to DeePMD's + dictionary format. The lossless source for TF2 is the ``.tf2`` + CheckpointManager directory written by ``dp --tf2 train``. """ - raise ValueError(f"TF2 backend cannot serialize {model_file!r} to a model dict") + path = Path(model_file) + if str(path).lower().endswith(".savedmodeltf"): + raise ValueError( + "TF2 SavedModel files cannot be serialized back to a DeePMD model " + "dict. Use the .tf2 training checkpoint directory or checkpoint " + "prefix instead." + ) + checkpoint_path, state = _load_checkpoint_state(path) + model_def_script = state["model_def_script"] + models = _restore_models_from_checkpoint(checkpoint_path, model_def_script, state) + model_payload = _serialize_models(models, model_def_script) + min_nbor_dist = _normalize_json_value(state.get("min_nbor_dist")) + data = { + "backend": "TensorFlow2", + "model": model_payload, + "model_def_script": model_def_script, + "shared_links": state.get("shared_links"), + "@variables": { + "current_step": int(state.get("current_step", 0)), + }, + "min_nbor_dist": min_nbor_dist, + } + if state.get("full_validation"): + data["full_validation"] = state["full_validation"] + return data + + +class _TaskModelContainer(tf.Module): + """Track task-keyed TF modules with the same object graph as training.""" + + def __init__(self, models: Mapping[str, tf.Module]) -> None: + super().__init__(name="models") + self.task_keys = tuple(models) + for index, key in enumerate(self.task_keys): + setattr(self, f"task_{index}", models[key]) + + +def _load_checkpoint_state(path: Path) -> tuple[str, dict[str, Any]]: + checkpoint_path, state_dir = _resolve_checkpoint_path(path) + state_path = state_dir / TF2_TRAINING_STATE_FILE + if not state_path.is_file(): + raise FileNotFoundError( + f"Cannot find TF2 checkpoint metadata {state_path!s}. " + "Only checkpoints produced by the TF2 trainer can be frozen or " + "compressed losslessly." + ) + with state_path.open() as fp: + state = json.load(fp) + if state.get("backend") != "TensorFlow2": + raise ValueError(f"{state_path!s} is not a TensorFlow2 training state file.") + if "model_def_script" not in state: + raise ValueError(f"{state_path!s} does not contain model_def_script.") + return checkpoint_path, state + + +def _resolve_checkpoint_path(path: Path) -> tuple[str, Path]: + candidates = [path] + if not str(path).endswith(".tf2"): + candidates.append(Path(f"{path}.tf2")) + for candidate in candidates: + if candidate.is_dir(): + latest = tf.train.latest_checkpoint(str(candidate)) + if latest is not None: + return latest, candidate + if path.exists() or Path(f"{path}.index").exists(): + return str(path), path.parent + raise FileNotFoundError( + f"Cannot find TF2 checkpoint {str(path)!r}. Expected a CheckpointManager " + "directory ending with .tf2 or a checkpoint prefix." + ) + + +def _restore_models_from_checkpoint( + checkpoint_path: str, + model_def_script: dict[str, Any], + state: dict[str, Any], +) -> dict[str, BaseModel]: + models = _build_models(model_def_script) + _set_min_nbor_dist(models, state.get("min_nbor_dist")) + apply_shared_links(models, state.get("shared_links"), resume=True) + container = _TaskModelContainer(models) + checkpoint = tf.train.Checkpoint(model=container) + checkpoint.restore(checkpoint_path).expect_partial() + return models + + +def _build_models(model_def_script: dict[str, Any]) -> dict[str, BaseModel]: + if "model_dict" in model_def_script: + return { + model_key: get_model(deepcopy(model_def_script["model_dict"][model_key])) + for model_key in model_def_script["model_dict"] + } + return {DEFAULT_TASK_KEY: get_model(deepcopy(model_def_script))} + + +def _serialize_models( + models: dict[str, BaseModel], + model_def_script: dict[str, Any], +) -> dict[str, Any]: + if "model_dict" in model_def_script: + return { + "model_dict": { + model_key: models[model_key].serialize() + for model_key in model_def_script["model_dict"] + } + } + return models[DEFAULT_TASK_KEY].serialize() + + +def _set_min_nbor_dist( + models: dict[str, BaseModel], + min_nbor_dist: Any, +) -> None: + if min_nbor_dist is None: + return + if isinstance(min_nbor_dist, Mapping): + for model_key, value in min_nbor_dist.items(): + if value is not None and model_key in models: + models[model_key].min_nbor_dist = float(value) + return + models[DEFAULT_TASK_KEY].min_nbor_dist = float(min_nbor_dist) + + +def _normalize_json_value(value: Any) -> Any: + if isinstance(value, dict): + return {key: _normalize_json_value(item) for key, item in value.items()} + if isinstance(value, list): + return [_normalize_json_value(item) for item in value] + if value is None: + return None + if hasattr(value, "item"): + return value.item() + return value diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 84b04e02ba..0dc733cc44 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -5549,7 +5549,7 @@ def validating_args() -> Argument: """Generate full validation arguments.""" valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) doc_full_validation_supported = ( - "(Supported Backend: PyTorch, PyTorch Experimental, JAX) " + "(Supported Backend: PyTorch, PyTorch Experimental, JAX, TensorFlow2) " ) doc_full_validation = ( "Whether to run an additional full validation pass over the entire " diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py new file mode 100644 index 0000000000..3b9084d967 --- /dev/null +++ b/source/tests/tf2/test_training.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for TensorFlow 2 training internals.""" + +import importlib +from types import ( + SimpleNamespace, +) +from typing import ( + Any, + ClassVar, +) + +import numpy as np +import pytest + +from deepmd.dpmodel.train import ( + DEFAULT_TASK_KEY, + TrainEntrypointOptions, + TrainingTask, +) +from deepmd.tf2.entrypoints.train import ( + TF2TrainEntrypoint, +) +from deepmd.tf2.env import ( + tf, +) +from deepmd.tf2.train.trainer import ( + Trainer, +) + +pytestmark = pytest.mark.filterwarnings( + "ignore:.*__init__ missing .*:DeprecationWarning:gast\\.astn" +) + + +class _LinearModel(tf.Module): + def __init__(self) -> None: + super().__init__() + self.weight = tf.Variable(2.0, dtype=tf.float64) + + def call( + self, + coord: Any, + atype: Any, + *, + box: Any = None, + fparam: Any = None, + aparam: Any = None, + charge_spin: Any = None, + ) -> dict[str, Any]: + del atype, box, fparam, aparam, charge_spin + return {"prediction": coord * self.weight} + + +class _SquaredLoss: + label_requirement: ClassVar[list[Any]] = [] + + def __call__( + self, + *, + learning_rate: Any, + natoms: Any, + model_dict: dict[str, Any], + label_dict: dict[str, Any], + ) -> tuple[Any, dict[str, Any]]: + del learning_rate + diff = model_dict["prediction"] - label_dict["target"] + loss = tf.reduce_mean(tf.square(diff)) + return loss, { + "rmse": tf.sqrt(loss), + "natoms": natoms, + "l2_regularization": tf.constant(100.0, dtype=tf.float64), + } + + +def _make_minimal_trainer() -> tuple[Trainer, _LinearModel]: + trainer = object.__new__(Trainer) + model = _LinearModel() + trainer.models = {DEFAULT_TASK_KEY: model} + trainer.losses = {DEFAULT_TASK_KEY: _SquaredLoss()} + trainer.optimizer = tf.keras.optimizers.SGD(learning_rate=0.0) + build = getattr(trainer.optimizer, "build", None) + if callable(build): + build(model.trainable_variables) + trainer.gradient_max_norm = 0.0 + trainer.step = tf.Variable(0, dtype=tf.int64, trainable=False) + trainer._compiled_train_steps = {} + trainer._compiled_eval_steps = {} + return trainer, model + + +def test_compiled_train_step_is_tf_function_and_updates_model() -> None: + trainer, model = _make_minimal_trainer() + compiled = trainer._make_compiled_train_step(DEFAULT_TASK_KEY) + + assert hasattr(compiled, "get_concrete_function") + more_loss = compiled( + { + "coord": tf.constant([[1.0]], dtype=tf.float64), + "atype": tf.constant([[0]], dtype=tf.int32), + }, + {"target": tf.constant([[0.0]], dtype=tf.float64)}, + tf.constant(1.0, dtype=tf.float64), + tf.constant(0.1, dtype=tf.float64), + tf.constant(1, dtype=tf.int64), + ) + + np.testing.assert_allclose(model.weight.numpy(), 1.6) + assert int(trainer.step.numpy()) == 1 + assert more_loss["rmse"].numpy() == 2.0 + + +def test_compiled_eval_step_returns_python_floats_without_l2_terms() -> None: + trainer, model = _make_minimal_trainer() + + result = trainer._compiled_eval_step( + DEFAULT_TASK_KEY, + { + "coord": tf.constant([[2.0]], dtype=tf.float64), + "atype": tf.constant([[0]], dtype=tf.int32), + }, + {"target": tf.constant([[1.0]], dtype=tf.float64)}, + tf.constant(3.0, dtype=tf.float64), + tf.constant(0.1, dtype=tf.float64), + ) + + np.testing.assert_allclose(model.weight.numpy(), 2.0) + assert result == {"rmse": 3.0, "natoms": 3.0} + + +def test_compiled_train_step_is_cached_per_task() -> None: + trainer = object.__new__(Trainer) + trainer._compiled_train_steps = {} + calls: list[str] = [] + + def make_compiled_step(task_key: str) -> Any: + calls.append(task_key) + + def compiled_step(*args: Any) -> dict[str, Any]: + del args + return {"task": task_key} + + return compiled_step + + trainer._make_compiled_train_step = make_compiled_step + + assert trainer._compiled_train_step("a", {}, {}, 1.0, 0.1, 1)["task"] == "a" + assert trainer._compiled_train_step("a", {}, {}, 1.0, 0.1, 2)["task"] == "a" + assert trainer._compiled_train_step("b", {}, {}, 1.0, 0.1, 1)["task"] == "b" + assert calls == ["a", "b"] + + +def test_train_step_passes_float_natoms_to_compiled_step() -> None: + trainer = object.__new__(Trainer) + trainer.lr_schedule = SimpleNamespace(value=lambda step: 0.25) + trainer.get_data = lambda *, is_train, task_key: ({}, {}, 7) + trainer._write_tensorboard_step = lambda *args, **kwargs: None + captured: dict[str, Any] = {} + + def compiled_train_step( + task_key: str, + input_dict: dict[str, Any], + label_dict: dict[str, Any], + natoms: Any, + cur_lr: Any, + next_step: Any, + ) -> dict[str, Any]: + del task_key, input_dict, label_dict + captured["natoms"] = natoms + captured["cur_lr"] = cur_lr + captured["next_step"] = next_step + return {"rmse": tf.constant(1.0, dtype=tf.float64)} + + trainer._compiled_train_step = compiled_train_step + + result = Trainer.train_step( + trainer, + TrainingTask(DEFAULT_TASK_KEY, SimpleNamespace()), + 4, + ) + + assert result.payload["cur_lr"] == 0.25 + assert captured["natoms"].dtype == tf.float64 + assert captured["natoms"].numpy() == 7.0 + assert captured["cur_lr"].dtype == tf.float64 + assert captured["next_step"].dtype == tf.int64 + assert captured["next_step"].numpy() == 5 + + +def test_train_entrypoint_builds_data_without_descriptor_rcut( + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[tuple[dict[str, Any], Any, Any, Any]] = [] + trainer_calls: list[dict[str, Any]] = [] + + class FakeData: + type_map: ClassVar[list[str]] = ["O", "H"] + + def print_summary(self, *args: Any) -> None: + del args + + def fake_get_data( + params: dict[str, Any], + rcut: Any, + type_map: Any, + optional_type_map: Any, + ) -> FakeData: + calls.append((params, rcut, type_map, optional_type_map)) + return FakeData() + + class FakeTrainer: + def __init__(self, *args: Any, **kwargs: Any) -> None: + trainer_calls.append({"args": args, "kwargs": kwargs}) + + def run(self) -> None: + trainer_calls[-1]["ran"] = True + + module = importlib.import_module(TF2TrainEntrypoint.__module__) + monkeypatch.setattr(module, "get_data", fake_get_data) + monkeypatch.setattr(module, "DPTrainer", FakeTrainer) + + config = { + "model": {"type_map": ["O", "H"]}, + "training": { + "training_data": {"systems": ["train"]}, + "validation_data": {"systems": ["valid"]}, + "numb_steps": 1, + }, + } + + TF2TrainEntrypoint().run_training( + config, + TrainEntrypointOptions(input_file="input.json"), + neighbor_stat=0.5, + ) + + assert calls == [ + ({"systems": ["train"]}, None, ["O", "H"], None), + ({"systems": ["valid"]}, None, ["O", "H"], None), + ] + assert trainer_calls[-1]["ran"] is True + assert trainer_calls[-1]["kwargs"]["min_nbor_dist"] == 0.5 From 7ba3e523ab314a9c1f31c0e9a0cde5985cef4e9b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 01:13:12 +0800 Subject: [PATCH 02/14] perf(tf2): reuse atomic forward during training --- deepmd/tf2/model/base_model.py | 136 ++++++++++++++++++------------ source/tests/tf2/test_training.py | 70 +++++++++++++++ 2 files changed, 153 insertions(+), 53 deletions(-) diff --git a/deepmd/tf2/model/base_model.py b/deepmd/tf2/model/base_model.py index dbb22b6f76..45d9ed1268 100644 --- a/deepmd/tf2/model/base_model.py +++ b/deepmd/tf2/model/base_model.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + from deepmd.dpmodel.model.base_model import ( make_base_model, ) @@ -20,35 +24,12 @@ BaseModel = make_base_model() -def forward_common_atomic( - self: "BaseModel", - extended_coord: xp.ndarray, - extended_atype: xp.ndarray, - nlist: xp.ndarray, - mapping: xp.ndarray | None = None, - fparam: xp.ndarray | None = None, - aparam: xp.ndarray | None = None, - do_atomic_virial: bool = False, - extended_coord_corr: xp.ndarray | None = None, - comm_dict: dict | None = None, - charge_spin: xp.ndarray | None = None, -) -> dict[str, xp.ndarray]: - del comm_dict # tf2 path has no MPI ghost exchange - - coord_tensor = to_tf_tensor(extended_coord) - assert coord_tensor is not None - coord_array = wrap_tensor(coord_tensor) - atomic_ret = self.atomic_model.forward_common_atomic( - coord_array, - extended_atype, - nlist, - mapping=mapping, - fparam=fparam, - aparam=aparam, - charge_spin=charge_spin, - ) - atomic_output_def = self.atomic_output_def() - model_predict = {} +def _collect_model_predict( + atomic_ret: dict[str, xp.ndarray], + atomic_output_def: Any, +) -> tuple[dict[str, xp.ndarray], dict[str, tf.Tensor]]: + model_predict: dict[str, xp.ndarray] = {} + reduced_output_tensors: dict[str, tf.Tensor] = {} for kk, vv in atomic_ret.items(): model_predict[kk] = vv vdef = atomic_output_def[kk] @@ -68,33 +49,81 @@ def forward_common_atomic( else: model_predict[kk_redu] = xp.sum(vv, axis=atom_axis) - kk_derv_r, kk_derv_c = get_deriv_name(kk) if vdef.r_differentiable: - with tf.GradientTape() as tape: - tape.watch(coord_tensor) - grad_atomic_ret = self.atomic_model.forward_common_atomic( - wrap_tensor(coord_tensor), - extended_atype, - nlist, - mapping=mapping, - fparam=fparam, - aparam=aparam, - charge_spin=charge_spin, - ) - reduced_output = xp.sum(grad_atomic_ret[kk], axis=atom_axis) - reduced_output_tensor = to_tf_tensor(reduced_output) - assert reduced_output_tensor is not None - ff_tensor = -tape.batch_jacobian(reduced_output_tensor, coord_tensor) - ff = wrap_tensor(ff_tensor) + reduced_output_tensor = to_tf_tensor(model_predict[kk_redu]) + assert reduced_output_tensor is not None + reduced_output_tensors[kk] = reduced_output_tensor + return model_predict, reduced_output_tensors + - # extended_force: [nf, nall, *def, 3] - def_ndim = len(vdef.shape) - model_predict[kk_derv_r] = xp.transpose( - ff, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] +def forward_common_atomic( + self: "BaseModel", + extended_coord: xp.ndarray, + extended_atype: xp.ndarray, + nlist: xp.ndarray, + mapping: xp.ndarray | None = None, + fparam: xp.ndarray | None = None, + aparam: xp.ndarray | None = None, + do_atomic_virial: bool = False, + extended_coord_corr: xp.ndarray | None = None, + comm_dict: dict | None = None, + charge_spin: xp.ndarray | None = None, +) -> dict[str, xp.ndarray]: + del comm_dict # tf2 path has no MPI ghost exchange + + coord_tensor = to_tf_tensor(extended_coord) + assert coord_tensor is not None + atomic_output_def = self.atomic_output_def() + derivative_keys = [ + kk for kk in atomic_output_def.keys() if atomic_output_def[kk].r_differentiable + ] + tape: tf.GradientTape | None = None + if derivative_keys: + tape = tf.GradientTape(persistent=len(derivative_keys) > 1) + with tape: + tape.watch(coord_tensor) + atomic_ret = self.atomic_model.forward_common_atomic( + wrap_tensor(coord_tensor), + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, ) - if vdef.r_hessian: - kk_hessian = get_hessian_name(kk) - model_predict[kk_hessian] = None + model_predict, reduced_output_tensors = _collect_model_predict( + atomic_ret, atomic_output_def + ) + else: + atomic_ret = self.atomic_model.forward_common_atomic( + wrap_tensor(coord_tensor), + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + ) + model_predict, reduced_output_tensors = _collect_model_predict( + atomic_ret, atomic_output_def + ) + + for kk in derivative_keys: + vdef = atomic_output_def[kk] + kk_derv_r, kk_derv_c = get_deriv_name(kk) + assert tape is not None + reduced_output_tensor = reduced_output_tensors[kk] + ff_tensor = -tape.batch_jacobian(reduced_output_tensor, coord_tensor) + ff = wrap_tensor(ff_tensor) + + # extended_force: [nf, nall, *def, 3] + def_ndim = len(vdef.shape) + model_predict[kk_derv_r] = xp.transpose( + ff, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] + ) + if vdef.r_hessian: + kk_hessian = get_hessian_name(kk) + model_predict[kk_hessian] = None if vdef.c_differentiable: assert vdef.r_differentiable @@ -151,4 +180,5 @@ def forward_common_atomic( model_predict[kk_derv_c] = extended_virial # [nf, *def, 9] model_predict[kk_derv_c + "_redu"] = xp.sum(extended_virial, axis=1) + del tape return model_predict diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 3b9084d967..488c0ede8d 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -13,17 +13,28 @@ import numpy as np import pytest +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) from deepmd.dpmodel.train import ( DEFAULT_TASK_KEY, TrainEntrypointOptions, TrainingTask, ) +from deepmd.tf2.common import ( + to_tf_tensor, + wrap_tensor, +) from deepmd.tf2.entrypoints.train import ( TF2TrainEntrypoint, ) from deepmd.tf2.env import ( tf, ) +from deepmd.tf2.model.base_model import ( + forward_common_atomic, +) from deepmd.tf2.train.trainer import ( Trainer, ) @@ -73,6 +84,40 @@ def __call__( } +class _CountingAtomicModel: + def __init__(self) -> None: + self.calls = 0 + + def forward_common_atomic( + self, + extended_coord: Any, + extended_atype: Any, + nlist: Any, + **kwargs: Any, + ) -> dict[str, Any]: + del extended_atype, nlist, kwargs + self.calls += 1 + coord = to_tf_tensor(extended_coord) + return {"energy": wrap_tensor(tf.reduce_sum(coord * coord, axis=-1)[..., None])} + + +class _FakeEnergyModel: + def __init__(self) -> None: + self.atomic_model = _CountingAtomicModel() + + def atomic_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + "energy", + [1], + reducible=True, + r_differentiable=True, + ) + ] + ) + + def _make_minimal_trainer() -> tuple[Trainer, _LinearModel]: trainer = object.__new__(Trainer) model = _LinearModel() @@ -89,6 +134,31 @@ def _make_minimal_trainer() -> tuple[Trainer, _LinearModel]: return trainer, model +def test_forward_common_atomic_reuses_taped_atomic_forward() -> None: + model = _FakeEnergyModel() + coord = tf.constant( + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]], + dtype=tf.float64, + ) + + result = forward_common_atomic( + model, + wrap_tensor(coord), + tf.constant([[0, 1]], dtype=tf.int32), + tf.constant([[[0], [1]]], dtype=tf.int32), + ) + + assert model.atomic_model.calls == 1 + np.testing.assert_allclose( + to_tf_tensor(result["energy_redu"]).numpy(), + [[91.0]], + ) + np.testing.assert_allclose( + to_tf_tensor(result["energy_derv_r"]).numpy(), + (-2.0 * coord[:, :, tf.newaxis, :]).numpy(), + ) + + def test_compiled_train_step_is_tf_function_and_updates_model() -> None: trainer, model = _make_minimal_trainer() compiled = trainer._make_compiled_train_step(DEFAULT_TASK_KEY) From 3c2810d0dbbbcfa6049ed9321044f515b2ce7230 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 01:29:48 +0800 Subject: [PATCH 03/14] perf(tf2): avoid tensorboard scalar host sync --- deepmd/tf2/train/trainer.py | 14 +++++++++++--- source/tests/tf2/test_training.py | 27 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/deepmd/tf2/train/trainer.py b/deepmd/tf2/train/trainer.py index ffd9cc8c58..a1326e83d7 100644 --- a/deepmd/tf2/train/trainer.py +++ b/deepmd/tf2/train/trainer.py @@ -1025,9 +1025,17 @@ def _write_tensorboard_step( return prefix = f"train/{task_key}" if self.multi_task else "train" with self.summary_writer.as_default(): - tf.summary.scalar("learning_rate", learning_rate, step=display_step) - for key, value in self._more_loss_to_float(more_loss).items(): - tf.summary.scalar(f"{prefix}/{key}", value, step=display_step) + tf.summary.scalar( + "learning_rate", + tf.convert_to_tensor(learning_rate, dtype=tf.float64), + step=display_step, + ) + for key, value in more_loss.items(): + if "l2_" in key: + continue + tf.summary.scalar( + f"{prefix}/{key}", to_tf_tensor(value), step=display_step + ) self.summary_writer.flush() def _change_bias_after_training(self) -> None: diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 488c0ede8d..afda659883 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -257,6 +257,33 @@ def compiled_train_step( assert captured["next_step"].numpy() == 5 +def test_tensorboard_step_writes_tensors_without_float_sync( + tmp_path: Any, +) -> None: + trainer = object.__new__(Trainer) + trainer.summary_writer = tf.summary.create_file_writer(str(tmp_path)) + trainer.tensorboard_freq = 1 + trainer.multi_task = False + + def fail_if_float_sync_is_used(more_loss: dict[str, Any]) -> dict[str, float]: + del more_loss + raise AssertionError("tensorboard path should not convert tensors to floats") + + trainer._more_loss_to_float = fail_if_float_sync_is_used + + Trainer._write_tensorboard_step( + trainer, + DEFAULT_TASK_KEY, + display_step=1, + learning_rate=0.1, + more_loss={ + "rmse": tf.constant(1.0, dtype=tf.float64), + "l2_regularization": tf.constant(2.0, dtype=tf.float64), + }, + ) + trainer.summary_writer.close() + + def test_train_entrypoint_builds_data_without_descriptor_rcut( monkeypatch: pytest.MonkeyPatch, ) -> None: From 2dd8b1981f6dad35f6cee3f13de142a46ef10beb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 02:57:10 +0800 Subject: [PATCH 04/14] perf(tf2): reduce training graph overhead --- deepmd/tf2/make_model.py | 157 ++++++++++++----- deepmd/tf2/model/base_model.py | 28 +++- deepmd/tf2/model/dp_model.py | 55 +++++- deepmd/tf2/train/trainer.py | 72 +++++++- source/tests/tf2/test_training.py | 268 +++++++++++++++++++++++++++++- 5 files changed, 527 insertions(+), 53 deletions(-) diff --git a/deepmd/tf2/make_model.py b/deepmd/tf2/make_model.py index 027e41fa4a..2976807508 100644 --- a/deepmd/tf2/make_model.py +++ b/deepmd/tf2/make_model.py @@ -11,20 +11,27 @@ from deepmd.dpmodel.array_api import ( Array, ) -from deepmd.dpmodel.model.transform_output import ( - communicate_extended_output, -) from deepmd.dpmodel.output_def import ( ModelOutputDef, ) +from deepmd.dpmodel.utils.neighbor_list import ( + NeighborList, +) +from deepmd.dpmodel.utils.nlist import ( + nlist_distinguish_types, +) from deepmd.tf2.common import ( to_tensorflow_array, to_tf_tensor, + unwrap_value, wrap_value, ) from deepmd.tf2.env import ( xp, ) +from deepmd.tf2.transform_output import ( + communicate_extended_output, +) from deepmd.tf2.utils._dpmodel import ( build_neighbor_list, extend_coord_with_ghosts, @@ -56,6 +63,11 @@ def model_call_from_call_lower( fparam: Array | None, aparam: Array | None, do_atomic_virial: bool = False, + do_deriv_c: bool = True, + coord_corr_for_virial: Array | None = None, + charge_spin: Array | None = None, + neighbor_list: NeighborList | None = None, + pass_lower_kwargs: bool = False, ) -> dict[str, Array]: """Return model prediction from lower interface. @@ -74,6 +86,13 @@ def model_call_from_call_lower( atomic parameter. nf x nloc x nda do_atomic_virial If calculate the atomic virial. + neighbor_list + Optional dense-neighbor-list strategy. ``None`` uses the native TF2 + all-pairs builder. + pass_lower_kwargs + Pass optional lower-interface keyword arguments. SavedModel export wraps + the lower with fixed signatures and keeps this disabled; direct TF2 model + calls enable it. Returns ------- @@ -87,60 +106,112 @@ def model_call_from_call_lower( bb = to_tensorflow_array(box) fp = to_tensorflow_array(fparam) ap = to_tensorflow_array(aparam) - del coord, box, fparam, aparam + cs = to_tensorflow_array(charge_spin) + coord_corr = to_tensorflow_array(coord_corr_for_virial) + del coord, box, fparam, aparam, charge_spin, coord_corr_for_virial nframes, nloc = atype.shape[:2] - def with_pbc() -> tuple[Array, Array, Array]: + def with_pbc() -> tuple[Array, Array, Array, Array]: assert bb is not None coord_normalized = normalize_coord( xp.reshape(cc, (nframes, nloc, 3)), xp.reshape(bb, (nframes, 3, 3)), ) - return extend_coord_with_ghosts(coord_normalized, atype, bb, rcut) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, bb, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + distinguish_types=False, + ) + return extended_coord, extended_atype, nlist, mapping - def no_pbc() -> tuple[Array, Array, Array]: - return extend_coord_with_ghosts(cc, atype, None, rcut) + def no_pbc() -> tuple[Array, Array, Array, Array]: + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + cc, atype, None, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + distinguish_types=False, + ) + return extended_coord, extended_atype, nlist, mapping - has_pbc = _box_has_pbc(bb) - if has_pbc is True: - extended_coord, extended_atype, mapping = with_pbc() - elif has_pbc is False: - extended_coord, extended_atype, mapping = no_pbc() + uses_native_nlist_builder = neighbor_list is None + if neighbor_list is not None: + extended_coord, extended_atype, nlist, mapping = neighbor_list.build( + cc, atype, bb, rcut, sel + ) else: - assert bb is not None - extended_coord_tensor, extended_atype_tensor, mapping_tensor = tf.cond( - tf.shape(to_tf_tensor(bb))[-1] != 0, - lambda: _unwrap_tuple(with_pbc()), - lambda: _unwrap_tuple(no_pbc()), + has_pbc = _box_has_pbc(bb) + if has_pbc is True: + extended_coord, extended_atype, nlist, mapping = with_pbc() + elif has_pbc is False: + extended_coord, extended_atype, nlist, mapping = no_pbc() + else: + assert bb is not None + ( + extended_coord_tensor, + extended_atype_tensor, + nlist_tensor, + mapping_tensor, + ) = tf.cond( + tf.shape(to_tf_tensor(bb))[-1] != 0, + lambda: _unwrap_tuple(with_pbc()), + lambda: _unwrap_tuple(no_pbc()), + ) + extended_coord = to_tensorflow_array(extended_coord_tensor) + extended_atype = to_tensorflow_array(extended_atype_tensor) + nlist = to_tensorflow_array(nlist_tensor) + mapping = to_tensorflow_array(mapping_tensor) + extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) + if coord_corr is not None: + mapping_idx = xp.tile( + xp.reshape(mapping, (nframes, -1, 1)), + (1, 1, 3), + ) + extended_coord_corr = xp.take_along_axis(coord_corr, mapping_idx, axis=1) + else: + extended_coord_corr = None + lower_kwargs: dict[str, Any] = {"fparam": fp, "aparam": ap} + if pass_lower_kwargs: + if uses_native_nlist_builder: + if not mixed_types: + nlist = nlist_distinguish_types(nlist, extended_atype, sel) + lower_kwargs["nlist_is_formatted"] = True + lower_kwargs.update( + { + "do_atomic_virial": do_atomic_virial, + "do_deriv_c": do_deriv_c, + "charge_spin": cs, + } ) - extended_coord = to_tensorflow_array(extended_coord_tensor) - extended_atype = to_tensorflow_array(extended_atype_tensor) - mapping = to_tensorflow_array(mapping_tensor) - nlist = build_neighbor_list( + if extended_coord_corr is not None: + lower_kwargs["extended_coord_corr"] = extended_coord_corr + model_predict_lower = call_lower( extended_coord, extended_atype, - nloc, - rcut, - sel, - # types will be distinguished in the lower interface, - # so it doesn't need to be distinguished here - distinguish_types=False, + nlist, + mapping, + **lower_kwargs, ) - extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) - model_predict_lower = wrap_value( - call_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fparam=fp, - aparam=ap, + model_predict = wrap_value( + communicate_extended_output( + unwrap_value(model_predict_lower), + model_output_def, + to_tf_tensor(mapping), + do_atomic_virial=do_atomic_virial, ) ) - model_predict = communicate_extended_output( - model_predict_lower, - model_output_def, - mapping, - do_atomic_virial=do_atomic_virial, - ) return model_predict diff --git a/deepmd/tf2/model/base_model.py b/deepmd/tf2/model/base_model.py index 45d9ed1268..9fa57f33c2 100644 --- a/deepmd/tf2/model/base_model.py +++ b/deepmd/tf2/model/base_model.py @@ -56,6 +56,22 @@ def _collect_model_predict( return model_predict, reduced_output_tensors +def _negative_coordinate_derivative( + tape: tf.GradientTape, + reduced_output_tensor: tf.Tensor, + coord_tensor: tf.Tensor, + output_size: int, +) -> tf.Tensor: + if output_size == 1: + grad = tape.gradient( + reduced_output_tensor, + coord_tensor, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + return -grad[:, tf.newaxis, :, :] + return -tape.batch_jacobian(reduced_output_tensor, coord_tensor) + + def forward_common_atomic( self: "BaseModel", extended_coord: xp.ndarray, @@ -65,6 +81,7 @@ def forward_common_atomic( fparam: xp.ndarray | None = None, aparam: xp.ndarray | None = None, do_atomic_virial: bool = False, + do_deriv_c: bool = True, extended_coord_corr: xp.ndarray | None = None, comm_dict: dict | None = None, charge_spin: xp.ndarray | None = None, @@ -113,7 +130,12 @@ def forward_common_atomic( kk_derv_r, kk_derv_c = get_deriv_name(kk) assert tape is not None reduced_output_tensor = reduced_output_tensors[kk] - ff_tensor = -tape.batch_jacobian(reduced_output_tensor, coord_tensor) + ff_tensor = _negative_coordinate_derivative( + tape, + reduced_output_tensor, + coord_tensor, + vdef.output_size, + ) ff = wrap_tensor(ff_tensor) # extended_force: [nf, nall, *def, 3] @@ -127,6 +149,10 @@ def forward_common_atomic( if vdef.c_differentiable: assert vdef.r_differentiable + if not do_deriv_c: + model_predict[kk_derv_c] = None + model_predict[kk_derv_c + "_redu"] = None + continue # avr: [nf, *def, nall, 3, 3] avr = xp.einsum("f...ai,faj->f...aij", ff, extended_coord) if extended_coord_corr is not None: diff --git a/deepmd/tf2/model/dp_model.py b/deepmd/tf2/model/dp_model.py index 146e399050..78d1886ba0 100644 --- a/deepmd/tf2/model/dp_model.py +++ b/deepmd/tf2/model/dp_model.py @@ -16,6 +16,9 @@ stop_gradient, xp, ) +from deepmd.tf2.make_model import ( + model_call_from_call_lower as tf2_model_call_from_call_lower, +) from deepmd.tf2.model.base_model import ( forward_common_atomic, ) @@ -49,21 +52,37 @@ def call_common( fparam: xp.ndarray | None = None, aparam: xp.ndarray | None = None, do_atomic_virial: bool = False, + do_deriv_c: bool = True, coord_corr_for_virial: xp.ndarray | None = None, charge_spin: xp.ndarray | None = None, neighbor_list: NeighborList | None = None, ) -> dict[str, xp.ndarray]: - return super().call_common( + cc, bb, fp, ap, cs, input_prec = self._input_type_cast( to_tensorflow_array(coord), - to_tensorflow_array(atype), box=to_tensorflow_array(box), fparam=to_tensorflow_array(fparam), aparam=to_tensorflow_array(aparam), + charge_spin=to_tensorflow_array(charge_spin), + ) + model_predict = tf2_model_call_from_call_lower( + call_lower=self.call_common_lower, + rcut=self.get_rcut(), + sel=self.get_sel(), + mixed_types=self.mixed_types(), + model_output_def=self.model_output_def(), + coord=cc, + atype=to_tensorflow_array(atype), + box=bb, + fparam=fp, + aparam=ap, do_atomic_virial=do_atomic_virial, + do_deriv_c=do_deriv_c, coord_corr_for_virial=to_tensorflow_array(coord_corr_for_virial), - charge_spin=to_tensorflow_array(charge_spin), + charge_spin=cs, neighbor_list=neighbor_list, + pass_lower_kwargs=True, ) + return self._output_type_cast(model_predict, input_prec) def call_common_lower( self, @@ -74,10 +93,38 @@ def call_common_lower( fparam: xp.ndarray | None = None, aparam: xp.ndarray | None = None, do_atomic_virial: bool = False, + do_deriv_c: bool = True, extended_coord_corr: xp.ndarray | None = None, comm_dict: dict | None = None, charge_spin: xp.ndarray | None = None, + nlist_is_formatted: bool = False, ) -> dict[str, xp.ndarray]: + if nlist_is_formatted: + del comm_dict # tf2 path has no MPI ghost exchange + extended_coord = to_tensorflow_array(extended_coord) + extended_atype = to_tensorflow_array(extended_atype) + nlist = to_tensorflow_array(nlist) + nframes, _nall = extended_atype.shape[:2] + extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) + cc_ext, _, fp, ap, cs, input_prec = self._input_type_cast( + extended_coord, + fparam=to_tensorflow_array(fparam), + aparam=to_tensorflow_array(aparam), + charge_spin=to_tensorflow_array(charge_spin), + ) + model_predict = self.forward_common_atomic( + cc_ext, + extended_atype, + nlist, + mapping=to_tensorflow_array(mapping), + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + do_deriv_c=do_deriv_c, + extended_coord_corr=to_tensorflow_array(extended_coord_corr), + charge_spin=cs, + ) + return self._output_type_cast(model_predict, input_prec) return super().call_common_lower( to_tensorflow_array(extended_coord), to_tensorflow_array(extended_atype), @@ -100,6 +147,7 @@ def forward_common_atomic( fparam: xp.ndarray | None = None, aparam: xp.ndarray | None = None, do_atomic_virial: bool = False, + do_deriv_c: bool = True, extended_coord_corr: xp.ndarray | None = None, comm_dict: dict | None = None, charge_spin: xp.ndarray | None = None, @@ -114,6 +162,7 @@ def forward_common_atomic( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + do_deriv_c=do_deriv_c, extended_coord_corr=extended_coord_corr, charge_spin=charge_spin, ) diff --git a/deepmd/tf2/train/trainer.py b/deepmd/tf2/train/trainer.py index a1326e83d7..b2b6e0e792 100644 --- a/deepmd/tf2/train/trainer.py +++ b/deepmd/tf2/train/trainer.py @@ -758,6 +758,7 @@ def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: task_key = task.key cur_lr = float(self.lr_schedule.value(step)) input_dict, label_dict, natoms = self.get_data(is_train=True, task_key=task_key) + do_virial = bool(label_dict.pop("_do_virial", True)) more_loss = self._compiled_train_step( task_key, input_dict, @@ -765,6 +766,7 @@ def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: tf.constant(float(natoms), dtype=tf.float64), tf.constant(cur_lr, dtype=tf.float64), tf.constant(step + 1, dtype=tf.int64), + do_virial, ) self._write_tensorboard_step( task_key, @@ -789,6 +791,7 @@ def _compiled_train_step( natoms: Any, cur_lr: Any, next_step: Any, + do_virial: bool, ) -> dict[str, Any]: if task_key not in self._compiled_train_steps: self._compiled_train_steps[task_key] = self._make_compiled_train_step( @@ -800,6 +803,7 @@ def _compiled_train_step( natoms, cur_lr, next_step, + do_virial, ) def _make_compiled_train_step(self, task_key: str) -> Any: @@ -812,10 +816,16 @@ def compiled_train_step( natoms: Any, cur_lr: Any, next_step: Any, + do_virial: bool, ) -> dict[str, Any]: self._assign_learning_rate(cur_lr) with tf.GradientTape() as tape: - model_pred = self._call_model(task_key, input_dict) + model_pred = self._call_model( + task_key, + input_dict, + label_dict=label_dict, + do_virial=do_virial, + ) loss, more_loss = self.losses[task_key]( learning_rate=cur_lr, natoms=natoms, @@ -848,7 +858,15 @@ def evaluate_training( if step_result is not None and step_result.task_key == task.key: return self._more_loss_to_float(step_result.payload["more_loss"]) input_dict, label_dict, natoms = self.get_data(is_train=True, task_key=task.key) - return self._evaluate_batch(task.key, step, input_dict, label_dict, natoms) + do_virial = bool(label_dict.pop("_do_virial", True)) + return self._evaluate_batch( + task.key, + step, + input_dict, + label_dict, + natoms, + do_virial=do_virial, + ) def evaluate_validation( self, @@ -865,12 +883,14 @@ def evaluate_validation( is_train=False, task_key=task.key, ) + do_virial = bool(label_dict.pop("_do_virial", True)) results = self._evaluate_batch( task.key, step, input_dict, label_dict, natoms, + do_virial=do_virial, ) sum_natoms += natoms for key, value in results.items(): @@ -886,6 +906,7 @@ def _evaluate_batch( input_dict: dict[str, Any], label_dict: dict[str, Any], natoms: int, + do_virial: bool = True, ) -> dict[str, float]: cur_lr = float(self.lr_schedule.value(step)) return self._compiled_eval_step( @@ -894,6 +915,7 @@ def _evaluate_batch( label_dict, tf.constant(float(natoms), dtype=tf.float64), tf.constant(cur_lr, dtype=tf.float64), + do_virial, ) def _compiled_eval_step( @@ -903,6 +925,7 @@ def _compiled_eval_step( label_dict: dict[str, Any], natoms: Any, cur_lr: Any, + do_virial: bool, ) -> dict[str, float]: if task_key not in self._compiled_eval_steps: self._compiled_eval_steps[task_key] = self._make_compiled_eval_step( @@ -913,6 +936,7 @@ def _compiled_eval_step( label_dict, natoms, cur_lr, + do_virial, ) return self._more_loss_to_float(more_loss) @@ -923,8 +947,14 @@ def compiled_eval_step( label_dict: dict[str, Any], natoms: Any, cur_lr: Any, + do_virial: bool, ) -> dict[str, Any]: - model_pred = self._call_model(task_key, input_dict) + model_pred = self._call_model( + task_key, + input_dict, + label_dict=label_dict, + do_virial=do_virial, + ) _, more_loss = self.losses[task_key]( learning_rate=cur_lr, natoms=natoms, @@ -1072,19 +1102,45 @@ def get_data( input_dict.pop(opt_key) natoms = int(input_dict["atype"].shape[1]) label_dict["type"] = input_dict["atype"] + do_virial = self._batch_needs_virial(task_key, label_dict) input_tf = { key: self._to_input_tensor(key, value) for key, value in input_dict.items() } label_tf = { key: self._to_label_array(key, value) for key, value in label_dict.items() } + label_tf["_do_virial"] = do_virial return input_tf, label_tf, natoms def _call_model( self, task_key: str, input_dict: dict[str, Any], + *, + label_dict: dict[str, Any] | None = None, + do_virial: bool = True, ) -> dict[str, Any]: + if isinstance(self.losses[task_key], EnergyLoss) and not do_virial: + model_ret = self.models[task_key].call_common( + input_dict["coord"], + input_dict["atype"], + box=input_dict.get("box"), + fparam=input_dict.get("fparam"), + aparam=input_dict.get("aparam"), + charge_spin=input_dict.get("charge_spin"), + do_deriv_c=False, + ) + model_pred = { + "atom_energy": model_ret["energy"], + "energy": model_ret["energy_redu"], + } + if model_ret.get("energy_derv_r") is not None: + model_pred["force"] = model_ret["energy_derv_r"].squeeze(-2) + if label_dict is not None and "virial" in label_dict: + model_pred["virial"] = label_dict["virial"] + if "mask" in model_ret: + model_pred["mask"] = model_ret["mask"] + return model_pred return self.models[task_key].call( input_dict["coord"], input_dict["atype"], @@ -1094,6 +1150,16 @@ def _call_model( charge_spin=input_dict.get("charge_spin"), ) + def _batch_needs_virial( + self, + task_key: str, + label_dict: dict[str, Any], + ) -> bool: + loss = self.losses[task_key] + if not isinstance(loss, EnergyLoss) or not loss.has_v: + return False + return bool(np.asarray(label_dict.get("find_virial", False)).any()) + @staticmethod def _to_input_tensor(key: str, value: Any) -> Any: if value is None: diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index afda659883..33000bf573 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -13,8 +13,12 @@ import numpy as np import pytest +from deepmd.dpmodel.loss import ( + EnergyLoss, +) from deepmd.dpmodel.output_def import ( FittingOutputDef, + ModelOutputDef, OutputVariableDef, ) from deepmd.dpmodel.train import ( @@ -118,6 +122,21 @@ def atomic_output_def(self) -> FittingOutputDef: ) +class _FakeVirialEnergyModel(_FakeEnergyModel): + def atomic_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + "energy", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ) + ] + ) + + def _make_minimal_trainer() -> tuple[Trainer, _LinearModel]: trainer = object.__new__(Trainer) model = _LinearModel() @@ -159,6 +178,216 @@ def test_forward_common_atomic_reuses_taped_atomic_forward() -> None: ) +def test_forward_common_atomic_scalar_output_avoids_batch_jacobian( + monkeypatch: pytest.MonkeyPatch, +) -> None: + model = _FakeEnergyModel() + coord = tf.constant( + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]], + dtype=tf.float64, + ) + + def fail_batch_jacobian(*args: Any, **kwargs: Any) -> Any: + del args, kwargs + raise AssertionError("scalar output should use tape.gradient") + + monkeypatch.setattr(tf.GradientTape, "batch_jacobian", fail_batch_jacobian) + + result = forward_common_atomic( + model, + wrap_tensor(coord), + tf.constant([[0, 1]], dtype=tf.int32), + tf.constant([[[0], [1]]], dtype=tf.int32), + ) + + np.testing.assert_allclose( + to_tf_tensor(result["energy_derv_r"]).numpy(), + (-2.0 * coord[:, :, tf.newaxis, :]).numpy(), + ) + + +def test_forward_common_atomic_can_skip_virial_derivative() -> None: + model = _FakeVirialEnergyModel() + coord = tf.constant( + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]], + dtype=tf.float64, + ) + + result = forward_common_atomic( + model, + wrap_tensor(coord), + tf.constant([[0, 1]], dtype=tf.int32), + tf.constant([[[0], [1]]], dtype=tf.int32), + do_deriv_c=False, + ) + + np.testing.assert_allclose( + to_tf_tensor(result["energy_derv_r"]).numpy(), + (-2.0 * coord[:, :, tf.newaxis, :]).numpy(), + ) + assert result["energy_derv_c"] is None + assert result["energy_derv_c_redu"] is None + + +def test_model_call_from_call_lower_uses_tf2_native_communicate( + monkeypatch: pytest.MonkeyPatch, +) -> None: + make_model_module = importlib.import_module("deepmd.tf2.make_model") + captured: dict[str, Any] = {} + + def fake_communicate( + model_ret: dict[str, Any], + model_output_def: ModelOutputDef, + mapping: Any, + do_atomic_virial: bool = False, + ) -> dict[str, Any]: + captured["model_ret_is_tensor"] = all( + value is None or isinstance(value, tf.Tensor) + for value in model_ret.values() + ) + captured["mapping_is_tensor"] = isinstance(mapping, tf.Tensor) + captured["do_atomic_virial"] = do_atomic_virial + captured["model_output_def"] = model_output_def + return { + "energy": model_ret["energy"], + "energy_redu": tf.reduce_sum(model_ret["energy"], axis=1), + } + + monkeypatch.setattr( + make_model_module, + "communicate_extended_output", + fake_communicate, + ) + + def call_lower( + extended_coord: Any, + extended_atype: Any, + nlist: Any, + mapping: Any, + *, + fparam: Any = None, + aparam: Any = None, + **kwargs: Any, + ) -> dict[str, Any]: + del extended_coord, nlist, mapping, fparam, aparam + captured["lower_kwargs"] = kwargs + atype = to_tf_tensor(extended_atype) + assert atype is not None + return { + "energy": tf.ones( + tf.concat([tf.shape(atype), tf.constant([1], dtype=tf.int32)], axis=0), + dtype=tf.float64, + ) + } + + model_output_def = ModelOutputDef( + FittingOutputDef( + [ + OutputVariableDef( + "energy", + [1], + reducible=True, + r_differentiable=True, + ) + ] + ) + ) + + result = make_model_module.model_call_from_call_lower( + call_lower=call_lower, + rcut=1.0, + sel=[1], + mixed_types=False, + model_output_def=model_output_def, + coord=tf.constant([[[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]]], dtype=tf.float64), + atype=tf.constant([[0, 0]], dtype=tf.int32), + box=None, + fparam=None, + aparam=None, + pass_lower_kwargs=True, + ) + + assert captured == { + "model_ret_is_tensor": True, + "mapping_is_tensor": True, + "do_atomic_virial": False, + "model_output_def": model_output_def, + "lower_kwargs": { + "nlist_is_formatted": True, + "do_atomic_virial": False, + "do_deriv_c": True, + "charge_spin": None, + }, + } + assert isinstance(to_tf_tensor(result["energy_redu"]), tf.Tensor) + + +def test_tf2_dp_model_call_common_uses_tf2_helper( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dp_model_module = importlib.import_module("deepmd.tf2.model.dp_model") + captured: dict[str, Any] = {} + + class FakeDPModel: + def call_common(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + del args, kwargs + raise AssertionError("generic dpmodel call_common should not be used") + + def call_common_lower(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + del args, kwargs + return {} + + def _input_type_cast( + self, + coord: Any, + *, + box: Any = None, + fparam: Any = None, + aparam: Any = None, + charge_spin: Any = None, + ) -> tuple[Any, Any, Any, Any, Any, Any]: + return coord, box, fparam, aparam, charge_spin, coord.dtype + + def _output_type_cast( + self, + model_ret: dict[str, Any], + input_prec: Any, + ) -> dict[str, Any]: + captured["input_prec"] = input_prec + return model_ret + + def get_rcut(self) -> float: + return 1.0 + + def get_sel(self) -> list[int]: + return [1] + + def mixed_types(self) -> bool: + return False + + def model_output_def(self) -> str: + return "output_def" + + def fake_helper(**kwargs: Any) -> dict[str, Any]: + captured.update(kwargs) + return {"energy": tf.constant([[[1.0]]], dtype=tf.float64)} + + monkeypatch.setattr(dp_model_module, "tf2_model_call_from_call_lower", fake_helper) + model_class = dp_model_module.make_tf2_dp_model_from_dpmodel(FakeDPModel, object) + model = model_class() + + result = model.call_common( + tf.constant([[[0.0, 0.0, 0.0]]], dtype=tf.float64), + tf.constant([[0]], dtype=tf.int32), + ) + + assert result["energy"].shape == (1, 1, 1) + assert captured["model_output_def"] == "output_def" + assert captured["pass_lower_kwargs"] is True + assert captured["call_lower"] == model.call_common_lower + assert isinstance(to_tf_tensor(captured["coord"]), tf.Tensor) + + def test_compiled_train_step_is_tf_function_and_updates_model() -> None: trainer, model = _make_minimal_trainer() compiled = trainer._make_compiled_train_step(DEFAULT_TASK_KEY) @@ -173,6 +402,7 @@ def test_compiled_train_step_is_tf_function_and_updates_model() -> None: tf.constant(1.0, dtype=tf.float64), tf.constant(0.1, dtype=tf.float64), tf.constant(1, dtype=tf.int64), + True, ) np.testing.assert_allclose(model.weight.numpy(), 1.6) @@ -192,6 +422,7 @@ def test_compiled_eval_step_returns_python_floats_without_l2_terms() -> None: {"target": tf.constant([[1.0]], dtype=tf.float64)}, tf.constant(3.0, dtype=tf.float64), tf.constant(0.1, dtype=tf.float64), + True, ) np.testing.assert_allclose(model.weight.numpy(), 2.0) @@ -214,9 +445,9 @@ def compiled_step(*args: Any) -> dict[str, Any]: trainer._make_compiled_train_step = make_compiled_step - assert trainer._compiled_train_step("a", {}, {}, 1.0, 0.1, 1)["task"] == "a" - assert trainer._compiled_train_step("a", {}, {}, 1.0, 0.1, 2)["task"] == "a" - assert trainer._compiled_train_step("b", {}, {}, 1.0, 0.1, 1)["task"] == "b" + assert trainer._compiled_train_step("a", {}, {}, 1.0, 0.1, 1, True)["task"] == "a" + assert trainer._compiled_train_step("a", {}, {}, 1.0, 0.1, 2, True)["task"] == "a" + assert trainer._compiled_train_step("b", {}, {}, 1.0, 0.1, 1, True)["task"] == "b" assert calls == ["a", "b"] @@ -234,11 +465,13 @@ def compiled_train_step( natoms: Any, cur_lr: Any, next_step: Any, + do_virial: bool, ) -> dict[str, Any]: del task_key, input_dict, label_dict captured["natoms"] = natoms captured["cur_lr"] = cur_lr captured["next_step"] = next_step + captured["do_virial"] = do_virial return {"rmse": tf.constant(1.0, dtype=tf.float64)} trainer._compiled_train_step = compiled_train_step @@ -255,6 +488,35 @@ def compiled_train_step( assert captured["cur_lr"].dtype == tf.float64 assert captured["next_step"].dtype == tf.int64 assert captured["next_step"].numpy() == 5 + assert captured["do_virial"] is True + + +def test_batch_needs_virial_handles_numpy_find_flags() -> None: + trainer = object.__new__(Trainer) + trainer.losses = { + DEFAULT_TASK_KEY: EnergyLoss( + starter_learning_rate=1.0, + start_pref_v=1.0, + limit_pref_v=1.0, + ) + } + + assert ( + Trainer._batch_needs_virial( + trainer, + DEFAULT_TASK_KEY, + {"find_virial": np.asarray([False, True])}, + ) + is True + ) + assert ( + Trainer._batch_needs_virial( + trainer, + DEFAULT_TASK_KEY, + {"find_virial": np.asarray([False])}, + ) + is False + ) def test_tensorboard_step_writes_tensors_without_float_sync( From 6a644aaa4b389cc54032de1ca53984d7fbc56ec8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 12:38:28 +0800 Subject: [PATCH 05/14] fix(tf2): keep atomic virial disabled in training --- deepmd/tf2/train/trainer.py | 9 ++++-- source/tests/tf2/test_training.py | 53 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/deepmd/tf2/train/trainer.py b/deepmd/tf2/train/trainer.py index b2b6e0e792..5c4b676df9 100644 --- a/deepmd/tf2/train/trainer.py +++ b/deepmd/tf2/train/trainer.py @@ -1120,7 +1120,7 @@ def _call_model( label_dict: dict[str, Any] | None = None, do_virial: bool = True, ) -> dict[str, Any]: - if isinstance(self.losses[task_key], EnergyLoss) and not do_virial: + if isinstance(self.losses[task_key], EnergyLoss): model_ret = self.models[task_key].call_common( input_dict["coord"], input_dict["atype"], @@ -1128,7 +1128,8 @@ def _call_model( fparam=input_dict.get("fparam"), aparam=input_dict.get("aparam"), charge_spin=input_dict.get("charge_spin"), - do_deriv_c=False, + do_atomic_virial=False, + do_deriv_c=do_virial, ) model_pred = { "atom_energy": model_ret["energy"], @@ -1136,7 +1137,9 @@ def _call_model( } if model_ret.get("energy_derv_r") is not None: model_pred["force"] = model_ret["energy_derv_r"].squeeze(-2) - if label_dict is not None and "virial" in label_dict: + if model_ret.get("energy_derv_c_redu") is not None: + model_pred["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + elif label_dict is not None and "virial" in label_dict: model_pred["virial"] = label_dict["virial"] if "mask" in model_ret: model_pred["mask"] = model_ret["mask"] diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 33000bf573..246b509c39 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -388,6 +388,59 @@ def fake_helper(**kwargs: Any) -> dict[str, Any]: assert isinstance(to_tf_tensor(captured["coord"]), tf.Tensor) +def test_training_energy_call_keeps_atomic_virial_disabled() -> None: + trainer = object.__new__(Trainer) + captured: dict[str, Any] = {} + + class SpyEnergyModel: + def call_common( + self, + coord: Any, + atype: Any, + *, + box: Any = None, + fparam: Any = None, + aparam: Any = None, + charge_spin: Any = None, + do_atomic_virial: bool = False, + do_deriv_c: bool = True, + ) -> dict[str, Any]: + del coord, atype, box, fparam, aparam, charge_spin + captured["do_atomic_virial"] = do_atomic_virial + captured["do_deriv_c"] = do_deriv_c + return { + "energy": wrap_tensor(tf.constant([[[1.0]]], dtype=tf.float64)), + "energy_redu": wrap_tensor(tf.constant([[1.0]], dtype=tf.float64)), + "energy_derv_r": wrap_tensor(tf.zeros((1, 1, 1, 3), dtype=tf.float64)), + "energy_derv_c_redu": wrap_tensor( + tf.ones((1, 1, 1, 9), dtype=tf.float64) + ), + } + + trainer.models = {DEFAULT_TASK_KEY: SpyEnergyModel()} + trainer.losses = {DEFAULT_TASK_KEY: EnergyLoss(starter_learning_rate=1.0)} + + result = Trainer._call_model( + trainer, + DEFAULT_TASK_KEY, + { + "coord": tf.constant([[[0.0, 0.0, 0.0]]], dtype=tf.float64), + "atype": tf.constant([[0]], dtype=tf.int32), + }, + label_dict={"virial": tf.zeros((1, 9), dtype=tf.float64)}, + do_virial=True, + ) + + assert captured == { + "do_atomic_virial": False, + "do_deriv_c": True, + } + np.testing.assert_allclose( + to_tf_tensor(result["virial"]).numpy(), np.ones((1, 1, 9)) + ) + assert "atom_virial" not in result + + def test_compiled_train_step_is_tf_function_and_updates_model() -> None: trainer, model = _make_minimal_trainer() compiled = trainer._make_compiled_train_step(DEFAULT_TASK_KEY) From 8defd7c753110c0a580d3079f7786359aaaac10d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 13:16:06 +0800 Subject: [PATCH 06/14] perf(tf2): jit lower forward with DP_JIT --- deepmd/tf2/model/dp_model.py | 101 +++++++++++++++++++++++------- deepmd/tf2/utils/jit.py | 18 ++++++ deepmd/tf2/utils/serialization.py | 14 ++--- source/tests/tf2/test_training.py | 100 +++++++++++++++++++++++++++++ 4 files changed, 199 insertions(+), 34 deletions(-) create mode 100644 deepmd/tf2/utils/jit.py diff --git a/deepmd/tf2/model/dp_model.py b/deepmd/tf2/model/dp_model.py index 78d1886ba0..a0a44fba35 100644 --- a/deepmd/tf2/model/dp_model.py +++ b/deepmd/tf2/model/dp_model.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + from deepmd.dpmodel.model import ( DPModelCommon, ) @@ -11,9 +15,12 @@ from deepmd.tf2.common import ( tf2_module, to_tensorflow_array, + unwrap_value, + wrap_value, ) from deepmd.tf2.env import ( stop_gradient, + tf, xp, ) from deepmd.tf2.make_model import ( @@ -22,6 +29,9 @@ from deepmd.tf2.model.base_model import ( forward_common_atomic, ) +from deepmd.tf2.utils.jit import ( + default_jit_compile, +) def make_tf2_dp_model_from_dpmodel( @@ -44,6 +54,19 @@ def make_tf2_dp_model_from_dpmodel( @tf2_module class tf2_model(dpmodel_model): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + if default_jit_compile(): + self._tf2_call_common_lower_formatted = tf.function( + self._call_common_lower_formatted, + reduce_retracing=True, + jit_compile=True, + ) + else: + self._tf2_call_common_lower_formatted = ( + self._call_common_lower_formatted + ) + def call_common( self, coord: xp.ndarray, @@ -100,31 +123,21 @@ def call_common_lower( nlist_is_formatted: bool = False, ) -> dict[str, xp.ndarray]: if nlist_is_formatted: - del comm_dict # tf2 path has no MPI ghost exchange - extended_coord = to_tensorflow_array(extended_coord) - extended_atype = to_tensorflow_array(extended_atype) - nlist = to_tensorflow_array(nlist) - nframes, _nall = extended_atype.shape[:2] - extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) - cc_ext, _, fp, ap, cs, input_prec = self._input_type_cast( - extended_coord, - fparam=to_tensorflow_array(fparam), - aparam=to_tensorflow_array(aparam), - charge_spin=to_tensorflow_array(charge_spin), - ) - model_predict = self.forward_common_atomic( - cc_ext, - extended_atype, - nlist, - mapping=to_tensorflow_array(mapping), - fparam=fp, - aparam=ap, - do_atomic_virial=do_atomic_virial, - do_deriv_c=do_deriv_c, - extended_coord_corr=to_tensorflow_array(extended_coord_corr), - charge_spin=cs, + return wrap_value( + self._tf2_call_common_lower_formatted( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + do_deriv_c=do_deriv_c, + extended_coord_corr=extended_coord_corr, + comm_dict=comm_dict, + charge_spin=charge_spin, + ) ) - return self._output_type_cast(model_predict, input_prec) return super().call_common_lower( to_tensorflow_array(extended_coord), to_tensorflow_array(extended_atype), @@ -138,6 +151,46 @@ def call_common_lower( charge_spin=to_tensorflow_array(charge_spin), ) + def _call_common_lower_formatted( + self, + extended_coord: xp.ndarray, + extended_atype: xp.ndarray, + nlist: xp.ndarray, + mapping: xp.ndarray | None = None, + fparam: xp.ndarray | None = None, + aparam: xp.ndarray | None = None, + do_atomic_virial: bool = False, + do_deriv_c: bool = True, + extended_coord_corr: xp.ndarray | None = None, + comm_dict: dict | None = None, + charge_spin: xp.ndarray | None = None, + ) -> dict[str, tf.Tensor]: + del comm_dict # tf2 path has no MPI ghost exchange + extended_coord = to_tensorflow_array(extended_coord) + extended_atype = to_tensorflow_array(extended_atype) + nlist = to_tensorflow_array(nlist) + nframes, _nall = extended_atype.shape[:2] + extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) + cc_ext, _, fp, ap, cs, input_prec = self._input_type_cast( + extended_coord, + fparam=to_tensorflow_array(fparam), + aparam=to_tensorflow_array(aparam), + charge_spin=to_tensorflow_array(charge_spin), + ) + model_predict = self.forward_common_atomic( + cc_ext, + extended_atype, + nlist, + mapping=to_tensorflow_array(mapping), + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + do_deriv_c=do_deriv_c, + extended_coord_corr=to_tensorflow_array(extended_coord_corr), + charge_spin=cs, + ) + return unwrap_value(self._output_type_cast(model_predict, input_prec)) + def forward_common_atomic( self, extended_coord: xp.ndarray, diff --git a/deepmd/tf2/utils/jit.py b/deepmd/tf2/utils/jit.py new file mode 100644 index 0000000000..50aaca1b9f --- /dev/null +++ b/deepmd/tf2/utils/jit.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""TensorFlow 2 JIT configuration helpers.""" + +from __future__ import ( + annotations, +) + +import os + + +def env_flag(name: str) -> bool: + """Return whether an environment flag is enabled.""" + return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"} + + +def default_jit_compile() -> bool: + """Return the default tf.function XLA setting for TF2 code paths.""" + return env_flag("DP_JIT") diff --git a/deepmd/tf2/utils/serialization.py b/deepmd/tf2/utils/serialization.py index 95fc2ad2dd..2f4b01e173 100644 --- a/deepmd/tf2/utils/serialization.py +++ b/deepmd/tf2/utils/serialization.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json -import os from collections.abc import ( Callable, Mapping, @@ -40,19 +39,14 @@ from deepmd.tf2.utils._dpmodel import ( format_nlist, ) +from deepmd.tf2.utils.jit import ( + default_jit_compile, +) from deepmd.tf2.utils.multi_task import ( apply_shared_links, ) -def _env_flag(name: str) -> bool: - return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"} - - -def _default_jit_compile() -> bool: - return _env_flag("DP_JIT") - - class _ExportConstantArray: """Array-like export constant that traces as ``tf.constant``.""" @@ -279,7 +273,7 @@ def deserialize_to_savedmodel( ) -> None: """Deserialize the dictionary to a TensorFlow SavedModel directory.""" if jit_compile is None: - jit_compile = _default_jit_compile() + jit_compile = default_jit_compile() # Import model registrations before deserializing the dpmodel payload. import deepmd.tf2.model.model # noqa: F401 diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 246b509c39..7370e821bd 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -42,6 +42,9 @@ from deepmd.tf2.train.trainer import ( Trainer, ) +from deepmd.tf2.utils.jit import ( + default_jit_compile, +) pytestmark = pytest.mark.filterwarnings( "ignore:.*__init__ missing .*:DeprecationWarning:gast\\.astn" @@ -504,6 +507,103 @@ def compiled_step(*args: Any) -> dict[str, Any]: assert calls == ["a", "b"] +def test_default_jit_compile_reads_dp_jit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("DP_JIT", raising=False) + assert default_jit_compile() is False + + monkeypatch.setenv("DP_JIT", "1") + assert default_jit_compile() is True + + +def test_compiled_steps_do_not_jit_whole_train_step( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer, _ = _make_minimal_trainer() + trainer_module = importlib.import_module("deepmd.tf2.train.trainer") + captured: list[dict[str, Any]] = [] + + def fake_tf_function(*args: Any, **kwargs: Any) -> Any: + assert not args + captured.append(kwargs) + + def decorate(fn: Any) -> Any: + return fn + + return decorate + + monkeypatch.setenv("DP_JIT", "1") + monkeypatch.setattr(trainer_module.tf, "function", fake_tf_function) + + trainer._make_compiled_train_step(DEFAULT_TASK_KEY) + trainer._make_compiled_eval_step(DEFAULT_TASK_KEY) + + assert captured == [ + { + "reduce_retracing": True, + }, + { + "reduce_retracing": True, + }, + ] + + +def test_tf2_formatted_lower_forwards_dp_jit_to_tf_function( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dp_model_module = importlib.import_module("deepmd.tf2.model.dp_model") + captured: list[dict[str, Any]] = [] + + def fake_tf_function(fn: Any = None, *args: Any, **kwargs: Any) -> Any: + assert not args + captured.append(kwargs) + + def decorate(inner: Any) -> Any: + return inner + + return decorate(fn) if fn is not None else decorate + + class FakeDPModel: + pass + + monkeypatch.setenv("DP_JIT", "1") + monkeypatch.setattr(dp_model_module.tf, "function", fake_tf_function) + + model_class = dp_model_module.make_tf2_dp_model_from_dpmodel(FakeDPModel, object) + model_class() + + assert captured == [ + { + "reduce_retracing": True, + "jit_compile": True, + } + ] + + +def test_tf2_formatted_lower_does_not_wrap_without_dp_jit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dp_model_module = importlib.import_module("deepmd.tf2.model.dp_model") + captured: list[dict[str, Any]] = [] + + def fake_tf_function(fn: Any = None, *args: Any, **kwargs: Any) -> Any: + del fn, args + captured.append(kwargs) + raise AssertionError("formatted lower should not be wrapped when DP_JIT is off") + + class FakeDPModel: + pass + + monkeypatch.delenv("DP_JIT", raising=False) + monkeypatch.setattr(dp_model_module.tf, "function", fake_tf_function) + + model_class = dp_model_module.make_tf2_dp_model_from_dpmodel(FakeDPModel, object) + model_class() + + assert captured == [] + + def test_train_step_passes_float_natoms_to_compiled_step() -> None: trainer = object.__new__(Trainer) trainer.lr_schedule = SimpleNamespace(value=lambda step: 0.25) From 84657a440b176a323d3cb5ca67e1a6ee45d5a522 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 14:13:38 +0800 Subject: [PATCH 07/14] feat(tf2): support training compile option --- deepmd/jax/utils/type_embed.py | 28 ++++++++++++++- deepmd/tf2/model/dp_model.py | 6 +++- deepmd/tf2/train/trainer.py | 17 +++++++++ deepmd/utils/argcheck.py | 10 +++--- source/tests/tf2/test_training.py | 57 +++++++++++++++++++++++++++++++ 5 files changed, 112 insertions(+), 6 deletions(-) diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py index eead05f978..f45d1aaab8 100644 --- a/deepmd/jax/utils/type_embed.py +++ b/deepmd/jax/utils/type_embed.py @@ -1,5 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat + import deepmd.jax.utils.network as _jax_network # noqa: F401 +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP from deepmd.jax.common import ( flax_module, @@ -8,4 +13,25 @@ @flax_module class TypeEmbedNet(TypeEmbedNetDP): - pass + def call(self) -> Array: + """Compute type embeddings without querying tracer devices.""" + sample_array = self.embedding_net[0]["w"] + xp = array_api_compat.array_namespace(sample_array) + if not self.use_econf_tebd: + embed = self.embedding_net( + xp.eye( + self.ntypes, + dtype=sample_array.dtype, + device=None, + ) + ) + else: + embed = self.embedding_net(self.econf_tebd) + if self.padding: + embed_pad = xp.zeros( + (1, embed.shape[-1]), + dtype=embed.dtype, + device=None, + ) + embed = xp.concat([embed, embed_pad], axis=0) + return embed diff --git a/deepmd/tf2/model/dp_model.py b/deepmd/tf2/model/dp_model.py index a0a44fba35..87a939a890 100644 --- a/deepmd/tf2/model/dp_model.py +++ b/deepmd/tf2/model/dp_model.py @@ -56,7 +56,11 @@ def make_tf2_dp_model_from_dpmodel( class tf2_model(dpmodel_model): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - if default_jit_compile(): + self.set_enable_compile(default_jit_compile()) + + def set_enable_compile(self, enable_compile: bool) -> None: + """Enable or disable XLA compilation for the formatted lower path.""" + if enable_compile: self._tf2_call_common_lower_formatted = tf.function( self._call_common_lower_formatted, reduce_retracing=True, diff --git a/deepmd/tf2/train/trainer.py b/deepmd/tf2/train/trainer.py index 5c4b676df9..683d89495c 100644 --- a/deepmd/tf2/train/trainer.py +++ b/deepmd/tf2/train/trainer.py @@ -274,6 +274,7 @@ def __init__( training_params.get("tensorboard_log_dir", "log") ) self.tensorboard_freq = int(training_params.get("tensorboard_freq", 1)) + self.enable_compile = bool(training_params.get("enable_compile", False)) self.change_bias_after_training = bool( training_params.get("change_bias_after_training", False) ) @@ -283,6 +284,7 @@ def __init__( model_key: get_model(deepcopy(self.model_params_by_task[model_key])) for model_key in self.model_keys } + self._configure_model_compile() self.set_min_nbor_dist(min_nbor_dist) self.model = self.models if self.multi_task else self.models[DEFAULT_TASK_KEY] @@ -439,6 +441,21 @@ def _validate_unsupported_config(self, config: Mapping[str, Any]) -> None: "TF2 training does not support model.modifier yet." ) + def _configure_model_compile(self) -> None: + """Apply training.enable_compile to TF2 models that support lower XLA.""" + if not self.enable_compile: + return + log.info("Enabling TF2 lower-forward XLA compilation.") + for model_key, model in self.models.items(): + set_enable_compile = getattr(model, "set_enable_compile", None) + if not callable(set_enable_compile): + log.warning( + "Model %s does not support training.enable_compile; ignoring.", + model_key, + ) + continue + set_enable_compile(True) + def _create_full_validator(self) -> Any | None: if not self._is_validation_requested("full_validation"): return None diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0dc733cc44..0a7e854487 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -5441,10 +5441,12 @@ def training_args( bool, optional=True, default=False, - doc=doc_only_pt_expt_supported - + "Enable torch.compile to accelerate training. " - "Uses make_fx to decompose autograd into primitive ops, " - "then compiles with torch.compile/Inductor for kernel fusion. " + doc="(Supported Backend: PyTorch Experimental, TensorFlow2) " + "Enable backend compiler acceleration during training. " + "PyTorch Experimental uses make_fx to decompose autograd into " + "primitive ops, then compiles with torch.compile/Inductor for " + "kernel fusion. TensorFlow2 enables XLA jit_compile for the " + "formatted lower-forward path. " "The first training step will be slower due to one-time compilation.", ), ] diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 7370e821bd..87daab4ca3 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -581,6 +581,41 @@ class FakeDPModel: ] +def test_tf2_formatted_lower_forwards_enable_compile_to_tf_function( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dp_model_module = importlib.import_module("deepmd.tf2.model.dp_model") + captured: list[dict[str, Any]] = [] + + def fake_tf_function(fn: Any = None, *args: Any, **kwargs: Any) -> Any: + assert not args + captured.append(kwargs) + + def decorate(inner: Any) -> Any: + return inner + + return decorate(fn) if fn is not None else decorate + + class FakeDPModel: + pass + + monkeypatch.delenv("DP_JIT", raising=False) + monkeypatch.setattr(dp_model_module.tf, "function", fake_tf_function) + + model_class = dp_model_module.make_tf2_dp_model_from_dpmodel(FakeDPModel, object) + model = model_class() + assert captured == [] + + model.set_enable_compile(True) + + assert captured == [ + { + "reduce_retracing": True, + "jit_compile": True, + } + ] + + def test_tf2_formatted_lower_does_not_wrap_without_dp_jit( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -604,6 +639,28 @@ class FakeDPModel: assert captured == [] +def test_trainer_applies_enable_compile_to_models() -> None: + trainer = object.__new__(Trainer) + calls: list[tuple[str, bool]] = [] + + class FakeModel: + def __init__(self, key: str) -> None: + self.key = key + + def set_enable_compile(self, enable_compile: bool) -> None: + calls.append((self.key, enable_compile)) + + trainer.enable_compile = True + trainer.models = { + "a": FakeModel("a"), + "b": FakeModel("b"), + } + + Trainer._configure_model_compile(trainer) + + assert calls == [("a", True), ("b", True)] + + def test_train_step_passes_float_natoms_to_compiled_step() -> None: trainer = object.__new__(Trainer) trainer.lr_schedule = SimpleNamespace(value=lambda step: 0.25) From 1f1cf09aa9edd0973ea93b1b7f48d7e4595fa7e3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 14:30:38 +0800 Subject: [PATCH 08/14] fix(tf2): address training review comments --- deepmd/jax/utils/type_embed.py | 28 +---- deepmd/tf2/entrypoints/freeze.py | 5 +- deepmd/tf2/make_model.py | 1 + deepmd/tf2/model/dp_model.py | 1 + deepmd/tf2/train/trainer.py | 40 ++++-- deepmd/tf2/utils/multi_task.py | 16 +-- deepmd/tf2/utils/serialization.py | 13 +- source/tests/tf2/test_training.py | 202 +++++++++++++++++++++++++++++- 8 files changed, 245 insertions(+), 61 deletions(-) diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py index f45d1aaab8..eead05f978 100644 --- a/deepmd/jax/utils/type_embed.py +++ b/deepmd/jax/utils/type_embed.py @@ -1,10 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import array_api_compat - import deepmd.jax.utils.network as _jax_network # noqa: F401 -from deepmd.dpmodel.array_api import ( - Array, -) from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP from deepmd.jax.common import ( flax_module, @@ -13,25 +8,4 @@ @flax_module class TypeEmbedNet(TypeEmbedNetDP): - def call(self) -> Array: - """Compute type embeddings without querying tracer devices.""" - sample_array = self.embedding_net[0]["w"] - xp = array_api_compat.array_namespace(sample_array) - if not self.use_econf_tebd: - embed = self.embedding_net( - xp.eye( - self.ntypes, - dtype=sample_array.dtype, - device=None, - ) - ) - else: - embed = self.embedding_net(self.econf_tebd) - if self.padding: - embed_pad = xp.zeros( - (1, embed.shape[-1]), - dtype=embed.dtype, - device=None, - ) - embed = xp.concat([embed, embed_pad], axis=0) - return embed + pass diff --git a/deepmd/tf2/entrypoints/freeze.py b/deepmd/tf2/entrypoints/freeze.py index ebc216408f..af65523c87 100644 --- a/deepmd/tf2/entrypoints/freeze.py +++ b/deepmd/tf2/entrypoints/freeze.py @@ -5,9 +5,6 @@ annotations, ) -from copy import ( - deepcopy, -) from typing import ( Any, ) @@ -70,7 +67,7 @@ def select_model_branch( f"{list(model_def_script['model_dict'])}." ) resolved_head = model_alias_dict[head] - selected = deepcopy(data) + selected = data.copy() selected["model"] = data["model"]["model_dict"][resolved_head] selected["model_def_script"] = model_def_script["model_dict"][resolved_head] min_nbor_dist = data.get("min_nbor_dist") diff --git a/deepmd/tf2/make_model.py b/deepmd/tf2/make_model.py index 2976807508..6ff2091505 100644 --- a/deepmd/tf2/make_model.py +++ b/deepmd/tf2/make_model.py @@ -177,6 +177,7 @@ def no_pbc() -> tuple[Array, Array, Array, Array]: mapping = to_tensorflow_array(mapping_tensor) extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) if coord_corr is not None: + coord_corr = xp.reshape(coord_corr, (nframes, nloc, 3)) mapping_idx = xp.tile( xp.reshape(mapping, (nframes, -1, 1)), (1, 1, 3), diff --git a/deepmd/tf2/model/dp_model.py b/deepmd/tf2/model/dp_model.py index 87a939a890..df8b45b129 100644 --- a/deepmd/tf2/model/dp_model.py +++ b/deepmd/tf2/model/dp_model.py @@ -150,6 +150,7 @@ def call_common_lower( fparam=to_tensorflow_array(fparam), aparam=to_tensorflow_array(aparam), do_atomic_virial=do_atomic_virial, + do_deriv_c=do_deriv_c, extended_coord_corr=to_tensorflow_array(extended_coord_corr), comm_dict=comm_dict, charge_spin=to_tensorflow_array(charge_spin), diff --git a/deepmd/tf2/train/trainer.py b/deepmd/tf2/train/trainer.py index 683d89495c..1c5fae165c 100644 --- a/deepmd/tf2/train/trainer.py +++ b/deepmd/tf2/train/trainer.py @@ -395,14 +395,23 @@ def sample( checkpoint_name=Path(self.save_ckpt).name, ) + restart_restore: tuple[str, Any] | None = None if init_model is not None: self._restore_model(init_model) self.step.assign(0) elif restart_model is not None: - self._restore_checkpoint(restart_model) - self.start_step = int(self.step.numpy()) + restart_restore = self._restore_checkpoint(restart_model) self._build_optimizer_slots() + if restart_restore is not None: + resolved, restore_status = restart_restore + restore_status.assert_existing_objects_matched() + self.start_step = int(self.step.numpy()) + log.info( + "Restarted TF2 training from %s at step %d", + resolved, + self.start_step, + ) self._compiled_train_steps: dict[str, Any] = {} self._compiled_eval_steps: dict[str, Any] = {} self.training_tasks = self._make_training_tasks() @@ -552,7 +561,7 @@ def _resolve_checkpoint_path(self, checkpoint_path: str) -> str: latest = tf.train.latest_checkpoint(str(candidate)) if latest is not None: return latest - if path.exists() or Path(f"{path}.index").exists(): + if path.is_file() or Path(f"{path}.index").is_file(): return str(path) raise FileNotFoundError( f"Cannot find TF2 checkpoint {checkpoint_path!r}. Expected a " @@ -562,15 +571,14 @@ def _resolve_checkpoint_path(self, checkpoint_path: str) -> str: def _restore_model(self, checkpoint_path: str) -> None: resolved = self._resolve_checkpoint_path(checkpoint_path) model_checkpoint = tf.train.Checkpoint(model=self.model_container) - model_checkpoint.restore(resolved).expect_partial() + restore_status = model_checkpoint.restore(resolved).expect_partial() + restore_status.assert_existing_objects_matched() log.info("Initialized TF2 model variables from %s", resolved) - def _restore_checkpoint(self, checkpoint_path: str) -> None: + def _restore_checkpoint(self, checkpoint_path: str) -> tuple[str, Any]: resolved = self._resolve_checkpoint_path(checkpoint_path) - self.checkpoint.restore(resolved).expect_partial() - log.info( - "Restarted TF2 training from %s at step %d", resolved, self.step.numpy() - ) + restore_status = self.checkpoint.restore(resolved).expect_partial() + return resolved, restore_status @staticmethod def _model_params_by_task( @@ -735,11 +743,13 @@ def _make_training_tasks(self) -> TrainingTaskCollection: probabilities=self.model_prob, ) - def run(self) -> None: + def run(self, tasks: TrainingTaskCollection | None = None) -> None: """Run TF2 training through the backend-independent trainer loop.""" + if tasks is None: + tasks = self.training_tasks log.info("Start to train %d steps.", self.num_steps) wall_start = time.time() - super().run(self.training_tasks) + super().run(tasks) if self.change_bias_after_training: self._change_bias_after_training() if self.rank_context.is_chief: @@ -1019,11 +1029,15 @@ def _save_full_validation_checkpoint( self._write_checkpoint_directory(save_path, step=step) def _write_checkpoint_directory(self, directory: Path, *, step: int) -> None: - self.step.assign(step) if directory.exists(): shutil.rmtree(directory) + checkpoint = tf.train.Checkpoint( + step=tf.Variable(step, dtype=tf.int64, trainable=False, name="step"), + optimizer=self.optimizer, + model=self.model_container, + ) manager = tf.train.CheckpointManager( - self.checkpoint, + checkpoint, directory=str(directory), max_to_keep=1, checkpoint_name=directory.stem, diff --git a/deepmd/tf2/utils/multi_task.py b/deepmd/tf2/utils/multi_task.py index a2f1b3111f..de794c70fb 100644 --- a/deepmd/tf2/utils/multi_task.py +++ b/deepmd/tf2/utils/multi_task.py @@ -332,14 +332,7 @@ def _share_fitting( _share_tf2_state_attrs( link_class, base_class, - excluded={ - "bias_atom_e", - "case_embd", - "fparam_avg", - "fparam_inv_std", - "aparam_avg", - "aparam_inv_std", - }, + shared_attr_names={"nets"}, ) @@ -397,11 +390,12 @@ def _share_tf2_state_attrs( link_class: Any, base_class: Any, *, - excluded: set[str], + shared_attr_names: set[str], ) -> None: - for name, value in list(vars(link_class).items()): - if name in excluded or name.startswith("_"): + for name in shared_attr_names: + if not hasattr(link_class, name) or not hasattr(base_class, name): continue + value = getattr(link_class, name) if _is_shareable_tf2_state(value): setattr(link_class, name, getattr(base_class, name)) diff --git a/deepmd/tf2/utils/serialization.py b/deepmd/tf2/utils/serialization.py index 2f4b01e173..9b17af4462 100644 --- a/deepmd/tf2/utils/serialization.py +++ b/deepmd/tf2/utils/serialization.py @@ -612,7 +612,7 @@ def _resolve_checkpoint_path(path: Path) -> tuple[str, Path]: latest = tf.train.latest_checkpoint(str(candidate)) if latest is not None: return latest, candidate - if path.exists() or Path(f"{path}.index").exists(): + if path.is_file() or Path(f"{path}.index").is_file(): return str(path), path.parent raise FileNotFoundError( f"Cannot find TF2 checkpoint {str(path)!r}. Expected a CheckpointManager " @@ -630,10 +630,19 @@ def _restore_models_from_checkpoint( apply_shared_links(models, state.get("shared_links"), resume=True) container = _TaskModelContainer(models) checkpoint = tf.train.Checkpoint(model=container) - checkpoint.restore(checkpoint_path).expect_partial() + _materialize_module_variables(container) + restore_status = checkpoint.restore(checkpoint_path).expect_partial() + _materialize_module_variables(container) + restore_status.assert_existing_objects_matched() return models +def _materialize_module_variables(module: tf.Module) -> None: + """Touch tracked variables before validating checkpoint restore status.""" + tuple(module.variables) + tuple(module.trainable_variables) + + def _build_models(model_def_script: dict[str, Any]) -> dict[str, BaseModel]: if "model_dict" in model_def_script: return { diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 87daab4ca3..1978776e17 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -46,9 +46,12 @@ default_jit_compile, ) -pytestmark = pytest.mark.filterwarnings( - "ignore:.*__init__ missing .*:DeprecationWarning:gast\\.astn" -) +pytestmark = [ + pytest.mark.filterwarnings( + "ignore:.*__init__ missing .*:DeprecationWarning:gast\\.astn" + ), + pytest.mark.timeout(60), +] class _LinearModel(tf.Module): @@ -325,6 +328,83 @@ def call_lower( assert isinstance(to_tf_tensor(result["energy_redu"]), tf.Tensor) +def test_model_call_from_call_lower_reshapes_coord_corr_for_mapping( + monkeypatch: pytest.MonkeyPatch, +) -> None: + make_model_module = importlib.import_module("deepmd.tf2.make_model") + captured: dict[str, Any] = {} + + def fake_communicate( + model_ret: dict[str, Any], + model_output_def: ModelOutputDef, + mapping: Any, + do_atomic_virial: bool = False, + ) -> dict[str, Any]: + del model_output_def, mapping, do_atomic_virial + return model_ret + + monkeypatch.setattr( + make_model_module, + "communicate_extended_output", + fake_communicate, + ) + + def call_lower( + extended_coord: Any, + extended_atype: Any, + nlist: Any, + mapping: Any, + **kwargs: Any, + ) -> dict[str, Any]: + del extended_coord, nlist, mapping + captured["extended_coord_corr"] = kwargs["extended_coord_corr"] + atype = to_tf_tensor(extended_atype) + assert atype is not None + return { + "energy": tf.ones( + tf.concat([tf.shape(atype), tf.constant([1], dtype=tf.int32)], axis=0), + dtype=tf.float64, + ) + } + + model_output_def = ModelOutputDef( + FittingOutputDef( + [ + OutputVariableDef( + "energy", + [1], + reducible=True, + r_differentiable=True, + ) + ] + ) + ) + + make_model_module.model_call_from_call_lower( + call_lower=call_lower, + rcut=1.0, + sel=[1], + mixed_types=False, + model_output_def=model_output_def, + coord=tf.constant([[[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]]], dtype=tf.float64), + atype=tf.constant([[0, 0]], dtype=tf.int32), + box=None, + fparam=None, + aparam=None, + do_deriv_c=True, + coord_corr_for_virial=tf.constant( + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], + dtype=tf.float64, + ), + pass_lower_kwargs=True, + ) + + np.testing.assert_allclose( + to_tf_tensor(captured["extended_coord_corr"]).numpy(), + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]], + ) + + def test_tf2_dp_model_call_common_uses_tf2_helper( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -391,6 +471,30 @@ def fake_helper(**kwargs: Any) -> dict[str, Any]: assert isinstance(to_tf_tensor(captured["coord"]), tf.Tensor) +def test_tf2_dp_model_call_common_lower_forwards_do_deriv_c() -> None: + dp_model_module = importlib.import_module("deepmd.tf2.model.dp_model") + captured: dict[str, Any] = {} + + class FakeDPModel: + def call_common_lower(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + del args + captured.update(kwargs) + return {} + + model_class = dp_model_module.make_tf2_dp_model_from_dpmodel(FakeDPModel, object) + model = model_class() + + model.call_common_lower( + tf.zeros((1, 1, 3), dtype=tf.float64), + tf.zeros((1, 1), dtype=tf.int32), + tf.zeros((1, 1, 1), dtype=tf.int32), + do_deriv_c=False, + nlist_is_formatted=False, + ) + + assert captured["do_deriv_c"] is False + + def test_training_energy_call_keeps_atomic_virial_disabled() -> None: trainer = object.__new__(Trainer) captured: dict[str, Any] = {} @@ -661,6 +765,96 @@ def set_enable_compile(self, enable_compile: bool) -> None: assert calls == [("a", True), ("b", True)] +def test_write_checkpoint_directory_does_not_mutate_training_step( + tmp_path: Any, + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = object.__new__(Trainer) + trainer.step = tf.Variable(7, dtype=tf.int64, trainable=False) + trainer.optimizer = tf.keras.optimizers.Adam() + trainer.model_container = tf.Module() + + def fake_write_training_state(directory: Any, *, step: int) -> None: + del directory, step + + monkeypatch.setattr(trainer, "_write_training_state", fake_write_training_state) + + Trainer._write_checkpoint_directory(trainer, tmp_path / "best", step=3) + + assert int(trainer.step.numpy()) == 7 + + +def test_serialization_checkpoint_directory_without_latest_is_not_prefix( + tmp_path: Any, + monkeypatch: pytest.MonkeyPatch, +) -> None: + serialization = importlib.import_module("deepmd.tf2.utils.serialization") + checkpoint_dir = tmp_path / "model.tf2" + checkpoint_dir.mkdir() + monkeypatch.setattr(serialization.tf.train, "latest_checkpoint", lambda path: None) + + with pytest.raises(FileNotFoundError): + serialization._resolve_checkpoint_path(checkpoint_dir) + + +def test_restore_models_from_checkpoint_validates_restore_status( + monkeypatch: pytest.MonkeyPatch, +) -> None: + serialization = importlib.import_module("deepmd.tf2.utils.serialization") + calls: list[str] = [] + + class FakeStatus: + def expect_partial(self) -> "FakeStatus": + calls.append("expect_partial") + return self + + def assert_existing_objects_matched(self) -> None: + calls.append("assert_existing_objects_matched") + + class FakeCheckpoint: + def __init__(self, **kwargs: Any) -> None: + del kwargs + + def restore(self, checkpoint_path: str) -> FakeStatus: + calls.append(checkpoint_path) + return FakeStatus() + + monkeypatch.setattr( + serialization, + "_build_models", + lambda model_def_script: {DEFAULT_TASK_KEY: tf.Module()}, + ) + monkeypatch.setattr(serialization, "_set_min_nbor_dist", lambda models, value: None) + monkeypatch.setattr( + serialization, + "apply_shared_links", + lambda models, shared_links, resume: None, + ) + monkeypatch.setattr(serialization.tf.train, "Checkpoint", FakeCheckpoint) + + serialization._restore_models_from_checkpoint( + "ckpt-1", + {"type_map": ["O"]}, + {"min_nbor_dist": None, "shared_links": None}, + ) + + assert calls == ["ckpt-1", "expect_partial", "assert_existing_objects_matched"] + + +def test_share_tf2_state_attrs_uses_allowlist_only() -> None: + multi_task = importlib.import_module("deepmd.tf2.utils.multi_task") + base_extra = tf.Variable(2.0) + link_extra = tf.Variable(1.0) + base = SimpleNamespace(nets=tf.Module(), extra=base_extra) + link = SimpleNamespace(nets=tf.Module(), extra=link_extra) + + multi_task._share_tf2_state_attrs(link, base, shared_attr_names={"nets"}) + + assert link.nets is base.nets + assert link.extra is link_extra + assert link.extra is not base_extra + + def test_train_step_passes_float_natoms_to_compiled_step() -> None: trainer = object.__new__(Trainer) trainer.lr_schedule = SimpleNamespace(value=lambda step: 0.25) @@ -766,7 +960,7 @@ class FakeData: type_map: ClassVar[list[str]] = ["O", "H"] def print_summary(self, *args: Any) -> None: - del args + pass def fake_get_data( params: dict[str, Any], From f23d383ef7920db40e017d9b25f6a8bd009144e9 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 15:26:19 +0800 Subject: [PATCH 09/14] perf(tf2): compile prepared energy train step --- deepmd/tf2/make_model.py | 113 +++++++++--- deepmd/tf2/train/trainer.py | 281 +++++++++++++++++++++++++++--- source/tests/tf2/test_training.py | 16 ++ 3 files changed, 362 insertions(+), 48 deletions(-) diff --git a/deepmd/tf2/make_model.py b/deepmd/tf2/make_model.py index 6ff2091505..458e0b5a06 100644 --- a/deepmd/tf2/make_model.py +++ b/deepmd/tf2/make_model.py @@ -101,6 +101,85 @@ def model_call_from_call_lower( The keys are defined by the `ModelOutputDef`. """ + ( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + cs, + extended_coord_corr, + nlist_is_formatted, + ) = prepare_lower_inputs( + rcut=rcut, + sel=sel, + mixed_types=mixed_types, + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, + coord_corr_for_virial=coord_corr_for_virial, + charge_spin=charge_spin, + neighbor_list=neighbor_list, + ) + lower_kwargs: dict[str, Any] = {"fparam": fp, "aparam": ap} + if pass_lower_kwargs: + if nlist_is_formatted: + lower_kwargs["nlist_is_formatted"] = True + lower_kwargs.update( + { + "do_atomic_virial": do_atomic_virial, + "do_deriv_c": do_deriv_c, + "charge_spin": cs, + } + ) + if extended_coord_corr is not None: + lower_kwargs["extended_coord_corr"] = extended_coord_corr + model_predict_lower = call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + **lower_kwargs, + ) + model_predict = wrap_value( + communicate_extended_output( + unwrap_value(model_predict_lower), + model_output_def, + to_tf_tensor(mapping), + do_atomic_virial=do_atomic_virial, + ) + ) + return model_predict + + +def prepare_lower_inputs( + *, + rcut: float, + sel: list[int], + mixed_types: bool, + coord: Array, + atype: Array, + box: Array | None, + fparam: Array | None, + aparam: Array | None, + coord_corr_for_virial: Array | None = None, + charge_spin: Array | None = None, + neighbor_list: NeighborList | None = None, +) -> tuple[ + Array, + Array, + Array, + Array, + Array | None, + Array | None, + Array | None, + Array | None, + bool, +]: + """Build lower-interface tensors outside the train-step compiler boundary.""" cc = to_tensorflow_array(coord) atype = to_tensorflow_array(atype) bb = to_tensorflow_array(box) @@ -185,34 +264,16 @@ def no_pbc() -> tuple[Array, Array, Array, Array]: extended_coord_corr = xp.take_along_axis(coord_corr, mapping_idx, axis=1) else: extended_coord_corr = None - lower_kwargs: dict[str, Any] = {"fparam": fp, "aparam": ap} - if pass_lower_kwargs: - if uses_native_nlist_builder: - if not mixed_types: - nlist = nlist_distinguish_types(nlist, extended_atype, sel) - lower_kwargs["nlist_is_formatted"] = True - lower_kwargs.update( - { - "do_atomic_virial": do_atomic_virial, - "do_deriv_c": do_deriv_c, - "charge_spin": cs, - } - ) - if extended_coord_corr is not None: - lower_kwargs["extended_coord_corr"] = extended_coord_corr - model_predict_lower = call_lower( + if uses_native_nlist_builder and not mixed_types: + nlist = nlist_distinguish_types(nlist, extended_atype, sel) + return ( extended_coord, extended_atype, nlist, mapping, - **lower_kwargs, + fp, + ap, + cs, + extended_coord_corr, + uses_native_nlist_builder, ) - model_predict = wrap_value( - communicate_extended_output( - unwrap_value(model_predict_lower), - model_output_def, - to_tf_tensor(mapping), - do_atomic_virial=do_atomic_virial, - ) - ) - return model_predict diff --git a/deepmd/tf2/train/trainer.py b/deepmd/tf2/train/trainer.py index 1c5fae165c..2ec59b30a0 100644 --- a/deepmd/tf2/train/trainer.py +++ b/deepmd/tf2/train/trainer.py @@ -55,13 +55,20 @@ to_tensorflow_array, to_tf_tensor, unwrap_value, + wrap_value, ) from deepmd.tf2.env import ( tf, ) +from deepmd.tf2.make_model import ( + prepare_lower_inputs, +) from deepmd.tf2.model.model import ( get_model, ) +from deepmd.tf2.transform_output import ( + communicate_extended_output, +) from deepmd.tf2.utils.multi_task import ( apply_shared_links, sanitize_shared_links, @@ -90,6 +97,7 @@ log = logging.getLogger(__name__) TF2_TRAINING_STATE_FILE = "training_state.json" +TF2_FULL_STEP_XLA_DESCRIPTOR_TYPES = {"se_e2_a"} def get_loss( @@ -413,6 +421,8 @@ def sample( self.start_step, ) self._compiled_train_steps: dict[str, Any] = {} + self._compiled_prepare_steps: dict[str, Any] = {} + self._compiled_prepared_train_steps: dict[str, Any] = {} self._compiled_eval_steps: dict[str, Any] = {} self.training_tasks = self._make_training_tasks() self.summary_writer: Any | None = None @@ -786,15 +796,27 @@ def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: cur_lr = float(self.lr_schedule.value(step)) input_dict, label_dict, natoms = self.get_data(is_train=True, task_key=task_key) do_virial = bool(label_dict.pop("_do_virial", True)) - more_loss = self._compiled_train_step( - task_key, - input_dict, - label_dict, - tf.constant(float(natoms), dtype=tf.float64), - tf.constant(cur_lr, dtype=tf.float64), - tf.constant(step + 1, dtype=tf.int64), - do_virial, - ) + if self._use_prepared_energy_step(task_key): + prepared = self._prepare_energy_batch(task_key, input_dict) + more_loss = self._compiled_prepared_energy_train_step( + task_key, + label_dict, + prepared, + tf.constant(float(natoms), dtype=tf.float64), + tf.constant(cur_lr, dtype=tf.float64), + tf.constant(step + 1, dtype=tf.int64), + do_virial, + ) + else: + more_loss = self._compiled_train_step( + task_key, + input_dict, + label_dict, + tf.constant(float(natoms), dtype=tf.float64), + tf.constant(cur_lr, dtype=tf.float64), + tf.constant(step + 1, dtype=tf.int64), + do_virial, + ) self._write_tensorboard_step( task_key, display_step=step + 1, @@ -810,6 +832,151 @@ def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: }, ) + def _use_prepared_energy_step(self, task_key: str) -> bool: + if not bool(getattr(self, "enable_compile", False)): + return False + if not isinstance(self.losses[task_key], EnergyLoss): + return False + return self._descriptor_type(task_key) in TF2_FULL_STEP_XLA_DESCRIPTOR_TYPES + + def _descriptor_type(self, task_key: str) -> str | None: + descriptor = self.model_params_by_task[task_key].get("descriptor") + if not isinstance(descriptor, Mapping): + return None + return str(descriptor.get("type", "se_e2_a")) + + def _prepare_energy_batch( + self, + task_key: str, + input_dict: dict[str, Any], + ) -> tuple[Any, Any, Any, Any, Any, Any, Any, Any]: + if task_key not in self._compiled_prepare_steps: + self._compiled_prepare_steps[task_key] = ( + self._make_compiled_prepare_energy_batch(task_key) + ) + prepared = self._compiled_prepare_steps[task_key]( + input_dict["coord"], + input_dict["atype"], + input_dict.get("box"), + input_dict.get("fparam"), + input_dict.get("aparam"), + input_dict.get("charge_spin"), + ) + return prepared[:-1] + + def _make_compiled_prepare_energy_batch(self, task_key: str) -> Any: + model = self.models[task_key] + + @tf.function(reduce_retracing=True) + def compiled_prepare_energy_batch( + coord: Any, + atype: Any, + box: Any, + fparam: Any, + aparam: Any, + charge_spin: Any, + ) -> tuple[Any, Any, Any, Any, Any, Any, Any, Any, bool]: + cc, bb, fp, ap, cs, _input_prec = model._input_type_cast( + to_tensorflow_array(coord), + box=to_tensorflow_array(box), + fparam=to_tensorflow_array(fparam), + aparam=to_tensorflow_array(aparam), + charge_spin=to_tensorflow_array(charge_spin), + ) + return prepare_lower_inputs( + rcut=model.get_rcut(), + sel=model.get_sel(), + mixed_types=model.mixed_types(), + coord=cc, + atype=to_tensorflow_array(atype), + box=bb, + fparam=fp, + aparam=ap, + charge_spin=cs, + ) + + return compiled_prepare_energy_batch + + def _compiled_prepared_energy_train_step( + self, + task_key: str, + label_dict: dict[str, Any], + prepared: tuple[Any, Any, Any, Any, Any, Any, Any, Any], + natoms: Any, + cur_lr: Any, + next_step: Any, + do_virial: bool, + ) -> dict[str, Any]: + if task_key not in self._compiled_prepared_train_steps: + self._compiled_prepared_train_steps[task_key] = ( + self._make_compiled_prepared_energy_train_step(task_key) + ) + return self._compiled_prepared_train_steps[task_key]( + label_dict, + *prepared, + natoms, + cur_lr, + next_step, + do_virial, + ) + + def _make_compiled_prepared_energy_train_step(self, task_key: str) -> Any: + variables = _unique_variables(self.models[task_key].trainable_variables) + + @tf.function(reduce_retracing=True, jit_compile=True) + def compiled_prepared_energy_train_step( + label_dict: dict[str, Any], + extended_coord: Any, + extended_atype: Any, + nlist: Any, + mapping: Any, + fparam: Any, + aparam: Any, + charge_spin: Any, + extended_coord_corr: Any, + natoms: Any, + cur_lr: Any, + next_step: Any, + do_virial: bool, + ) -> dict[str, Any]: + self._assign_learning_rate(cur_lr) + with tf.GradientTape() as tape: + model_pred = self._call_prepared_energy_model( + task_key, + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + charge_spin, + extended_coord_corr, + label_dict=label_dict, + do_virial=do_virial, + ) + loss, more_loss = self.losses[task_key]( + learning_rate=cur_lr, + natoms=natoms, + model_dict=model_pred, + label_dict=label_dict, + ) + loss_tensor = to_tf_tensor(loss) + gradients = tape.gradient(loss_tensor, variables) + gradients_and_variables = [ + (grad, var) + for grad, var in zip(gradients, variables, strict=True) + if grad is not None + ] + if self.gradient_max_norm > 0.0 and gradients_and_variables: + grads, vars_ = zip(*gradients_and_variables, strict=True) + grads, _ = tf.clip_by_global_norm(grads, self.gradient_max_norm) + gradients_and_variables = list(zip(grads, vars_, strict=True)) + self.optimizer.apply_gradients(gradients_and_variables) + self.step.assign(next_step) + return unwrap_value(more_loss) + + return compiled_prepared_energy_train_step + def _compiled_train_step( self, task_key: str, @@ -1162,19 +1329,10 @@ def _call_model( do_atomic_virial=False, do_deriv_c=do_virial, ) - model_pred = { - "atom_energy": model_ret["energy"], - "energy": model_ret["energy_redu"], - } - if model_ret.get("energy_derv_r") is not None: - model_pred["force"] = model_ret["energy_derv_r"].squeeze(-2) - if model_ret.get("energy_derv_c_redu") is not None: - model_pred["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) - elif label_dict is not None and "virial" in label_dict: - model_pred["virial"] = label_dict["virial"] - if "mask" in model_ret: - model_pred["mask"] = model_ret["mask"] - return model_pred + return self._energy_model_ret_to_loss_dict( + model_ret, + label_dict=label_dict, + ) return self.models[task_key].call( input_dict["coord"], input_dict["atype"], @@ -1184,6 +1342,85 @@ def _call_model( charge_spin=input_dict.get("charge_spin"), ) + def _call_prepared_energy_model( + self, + task_key: str, + extended_coord: Any, + extended_atype: Any, + nlist: Any, + mapping: Any, + fparam: Any, + aparam: Any, + charge_spin: Any, + extended_coord_corr: Any, + *, + label_dict: dict[str, Any] | None = None, + do_virial: bool = True, + ) -> dict[str, Any]: + model = self.models[task_key] + call_lower_formatted = getattr(model, "_call_common_lower_formatted", None) + if callable(call_lower_formatted): + model_ret_lower = wrap_value( + call_lower_formatted( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=False, + do_deriv_c=do_virial, + extended_coord_corr=extended_coord_corr, + charge_spin=charge_spin, + ) + ) + else: + model_ret_lower = model.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=False, + do_deriv_c=do_virial, + extended_coord_corr=extended_coord_corr, + charge_spin=charge_spin, + nlist_is_formatted=True, + ) + model_ret = wrap_value( + communicate_extended_output( + unwrap_value(model_ret_lower), + model.model_output_def(), + to_tf_tensor(mapping), + do_atomic_virial=False, + ) + ) + return self._energy_model_ret_to_loss_dict( + model_ret, + label_dict=label_dict, + ) + + @staticmethod + def _energy_model_ret_to_loss_dict( + model_ret: dict[str, Any], + *, + label_dict: dict[str, Any] | None = None, + ) -> dict[str, Any]: + model_pred = { + "atom_energy": model_ret["energy"], + "energy": model_ret["energy_redu"], + } + if model_ret.get("energy_derv_r") is not None: + model_pred["force"] = model_ret["energy_derv_r"].squeeze(-2) + if model_ret.get("energy_derv_c_redu") is not None: + model_pred["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + elif label_dict is not None and "virial" in label_dict: + model_pred["virial"] = label_dict["virial"] + if "mask" in model_ret: + model_pred["mask"] = model_ret["mask"] + return model_pred + def _batch_needs_virial( self, task_key: str, diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 1978776e17..076002c9a6 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -765,6 +765,22 @@ def set_enable_compile(self, enable_compile: bool) -> None: assert calls == [("a", True), ("b", True)] +def test_prepared_energy_step_only_uses_validated_descriptors() -> None: + trainer = object.__new__(Trainer) + trainer.enable_compile = True + trainer.losses = { + "se": EnergyLoss(starter_learning_rate=1.0), + "dpa3": EnergyLoss(starter_learning_rate=1.0), + } + trainer.model_params_by_task = { + "se": {"descriptor": {"type": "se_e2_a"}}, + "dpa3": {"descriptor": {"type": "dpa3"}}, + } + + assert Trainer._use_prepared_energy_step(trainer, "se") is True + assert Trainer._use_prepared_energy_step(trainer, "dpa3") is False + + def test_write_checkpoint_directory_does_not_mutate_training_step( tmp_path: Any, monkeypatch: pytest.MonkeyPatch, From 51679c288e703d73402977cfe672ca4b59c55f0b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 17:37:12 +0800 Subject: [PATCH 10/14] fix(tf2): enable compiled step for all outputs --- .github/workflows/test_python.yml | 6 +- deepmd/tf2/train/trainer.py | 138 ++++++++++++++++++++---------- dpa_adapt/cli.py | 1 - source/tests/tf2/test_training.py | 80 ++++++++++++++--- 4 files changed, 167 insertions(+), 58 deletions(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index d690cb054c..a1ca8d64af 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -60,7 +60,7 @@ jobs: restore-keys: | test2-durations-combined-${{ matrix.python }}-${{ github.sha }} test2-durations-combined-${{ matrix.python }} - - run: pytest --cov=deepmd source/tests --splits 12 --group ${{ matrix.group }} --store-durations --clean-durations --durations-path=.test_durations --splitting-algorithm least_duration + - run: pytest --cov=deepmd source/tests --ignore=source/tests/tf2 --splits 12 --group ${{ matrix.group }} --store-durations --clean-durations --durations-path=.test_durations --splitting-algorithm least_duration env: NUM_WORKERS: 0 DP_CI_IMPORT_PADDLE_BEFORE_TF: 1 @@ -81,6 +81,10 @@ jobs: return "$status" } + run_pytest_allow_no_tests --cov=deepmd --cov-append \ + source/tests/tf2 \ + --splits 12 \ + --group ${{ matrix.group }} run_pytest_allow_no_tests --cov=deepmd --cov-append \ source/tests/consistent/io/test_io.py \ source/jax2tf_tests \ diff --git a/deepmd/tf2/train/trainer.py b/deepmd/tf2/train/trainer.py index 2ec59b30a0..3d28ee3c93 100644 --- a/deepmd/tf2/train/trainer.py +++ b/deepmd/tf2/train/trainer.py @@ -97,7 +97,6 @@ log = logging.getLogger(__name__) TF2_TRAINING_STATE_FILE = "training_state.json" -TF2_FULL_STEP_XLA_DESCRIPTOR_TYPES = {"se_e2_a"} def get_loss( @@ -796,9 +795,9 @@ def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: cur_lr = float(self.lr_schedule.value(step)) input_dict, label_dict, natoms = self.get_data(is_train=True, task_key=task_key) do_virial = bool(label_dict.pop("_do_virial", True)) - if self._use_prepared_energy_step(task_key): - prepared = self._prepare_energy_batch(task_key, input_dict) - more_loss = self._compiled_prepared_energy_train_step( + if self._use_prepared_step(task_key): + prepared = self._prepare_lower_batch(task_key, input_dict) + more_loss = self._compiled_prepared_train_step( task_key, label_dict, prepared, @@ -832,27 +831,17 @@ def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: }, ) - def _use_prepared_energy_step(self, task_key: str) -> bool: - if not bool(getattr(self, "enable_compile", False)): - return False - if not isinstance(self.losses[task_key], EnergyLoss): - return False - return self._descriptor_type(task_key) in TF2_FULL_STEP_XLA_DESCRIPTOR_TYPES - - def _descriptor_type(self, task_key: str) -> str | None: - descriptor = self.model_params_by_task[task_key].get("descriptor") - if not isinstance(descriptor, Mapping): - return None - return str(descriptor.get("type", "se_e2_a")) + def _use_prepared_step(self, task_key: str) -> bool: + return bool(getattr(self, "enable_compile", False)) - def _prepare_energy_batch( + def _prepare_lower_batch( self, task_key: str, input_dict: dict[str, Any], ) -> tuple[Any, Any, Any, Any, Any, Any, Any, Any]: if task_key not in self._compiled_prepare_steps: self._compiled_prepare_steps[task_key] = ( - self._make_compiled_prepare_energy_batch(task_key) + self._make_compiled_prepare_lower_batch(task_key) ) prepared = self._compiled_prepare_steps[task_key]( input_dict["coord"], @@ -864,11 +853,11 @@ def _prepare_energy_batch( ) return prepared[:-1] - def _make_compiled_prepare_energy_batch(self, task_key: str) -> Any: + def _make_compiled_prepare_lower_batch(self, task_key: str) -> Any: model = self.models[task_key] @tf.function(reduce_retracing=True) - def compiled_prepare_energy_batch( + def compiled_prepare_lower_batch( coord: Any, atype: Any, box: Any, @@ -895,9 +884,9 @@ def compiled_prepare_energy_batch( charge_spin=cs, ) - return compiled_prepare_energy_batch + return compiled_prepare_lower_batch - def _compiled_prepared_energy_train_step( + def _compiled_prepared_train_step( self, task_key: str, label_dict: dict[str, Any], @@ -909,7 +898,7 @@ def _compiled_prepared_energy_train_step( ) -> dict[str, Any]: if task_key not in self._compiled_prepared_train_steps: self._compiled_prepared_train_steps[task_key] = ( - self._make_compiled_prepared_energy_train_step(task_key) + self._make_compiled_prepared_train_step(task_key) ) return self._compiled_prepared_train_steps[task_key]( label_dict, @@ -920,11 +909,11 @@ def _compiled_prepared_energy_train_step( do_virial, ) - def _make_compiled_prepared_energy_train_step(self, task_key: str) -> Any: + def _make_compiled_prepared_train_step(self, task_key: str) -> Any: variables = _unique_variables(self.models[task_key].trainable_variables) @tf.function(reduce_retracing=True, jit_compile=True) - def compiled_prepared_energy_train_step( + def compiled_prepared_train_step( label_dict: dict[str, Any], extended_coord: Any, extended_atype: Any, @@ -941,7 +930,7 @@ def compiled_prepared_energy_train_step( ) -> dict[str, Any]: self._assign_learning_rate(cur_lr) with tf.GradientTape() as tape: - model_pred = self._call_prepared_energy_model( + model_pred = self._call_prepared_model( task_key, extended_coord, extended_atype, @@ -975,7 +964,7 @@ def compiled_prepared_energy_train_step( self.step.assign(next_step) return unwrap_value(more_loss) - return compiled_prepared_energy_train_step + return compiled_prepared_train_step def _compiled_train_step( self, @@ -1318,8 +1307,10 @@ def _call_model( label_dict: dict[str, Any] | None = None, do_virial: bool = True, ) -> dict[str, Any]: - if isinstance(self.losses[task_key], EnergyLoss): - model_ret = self.models[task_key].call_common( + model = self.models[task_key] + call_common = getattr(model, "call_common", None) + if callable(call_common): + model_ret = call_common( input_dict["coord"], input_dict["atype"], box=input_dict.get("box"), @@ -1329,11 +1320,12 @@ def _call_model( do_atomic_virial=False, do_deriv_c=do_virial, ) - return self._energy_model_ret_to_loss_dict( + return self._translate_model_ret_to_loss_dict( + task_key, model_ret, label_dict=label_dict, ) - return self.models[task_key].call( + return model.call( input_dict["coord"], input_dict["atype"], box=input_dict.get("box"), @@ -1342,7 +1334,7 @@ def _call_model( charge_spin=input_dict.get("charge_spin"), ) - def _call_prepared_energy_model( + def _call_prepared_model( self, task_key: str, extended_coord: Any, @@ -1396,31 +1388,85 @@ def _call_prepared_energy_model( do_atomic_virial=False, ) ) - return self._energy_model_ret_to_loss_dict( + return self._translate_model_ret_to_loss_dict( + task_key, model_ret, label_dict=label_dict, ) - @staticmethod - def _energy_model_ret_to_loss_dict( + def _translate_model_ret_to_loss_dict( + self, + task_key: str, model_ret: dict[str, Any], *, label_dict: dict[str, Any] | None = None, ) -> dict[str, Any]: - model_pred = { - "atom_energy": model_ret["energy"], - "energy": model_ret["energy_redu"], - } - if model_ret.get("energy_derv_r") is not None: - model_pred["force"] = model_ret["energy_derv_r"].squeeze(-2) - if model_ret.get("energy_derv_c_redu") is not None: - model_pred["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) - elif label_dict is not None and "virial" in label_dict: + translated_output_def = getattr( + self.models[task_key], + "translated_output_def", + None, + ) + if not callable(translated_output_def): + return model_ret + output_defs = translated_output_def() + model_pred = {} + for output_key, output_def in output_defs.items(): + source_key = output_def.name + if source_key not in model_ret or model_ret[source_key] is None: + continue + model_pred[output_key] = self._match_output_rank( + model_ret[source_key], + output_def, + ) + if ( + label_dict is not None + and "virial" in label_dict + and "virial" in output_defs + and "virial" not in model_pred + ): model_pred["virial"] = label_dict["virial"] - if "mask" in model_ret: - model_pred["mask"] = model_ret["mask"] return model_pred + @classmethod + def _match_output_rank(cls, value: Any, output_def: Any) -> Any: + expected_rank = len(output_def.shape) + (2 if output_def.atomic else 1) + axis = -(len(output_def.shape) + 1) + while True: + rank = cls._shape_rank(value) + if rank is None or rank <= expected_rank: + return value + if cls._shape_dim(value, axis) != 1: + return value + squeeze = getattr(value, "squeeze", None) + if not callable(squeeze): + value = tf.squeeze(value, axis=axis) + else: + value = squeeze(axis) + + @staticmethod + def _shape_rank(value: Any) -> int | None: + shape = getattr(value, "shape", None) + if shape is None: + return None + rank = getattr(shape, "rank", None) + if rank is not None: + return int(rank) + try: + return len(shape) + except TypeError: + return None + + @staticmethod + def _shape_dim(value: Any, axis: int) -> int | None: + shape = getattr(value, "shape", None) + if shape is None: + return None + try: + dim = shape[axis] + except (IndexError, TypeError): + return None + return getattr(dim, "value", dim) + def _batch_needs_virial( self, task_key: str, diff --git a/dpa_adapt/cli.py b/dpa_adapt/cli.py index b5a7b1f104..10d098715f 100644 --- a/dpa_adapt/cli.py +++ b/dpa_adapt/cli.py @@ -274,7 +274,6 @@ def _cmd_evaluate(args: argparse.Namespace) -> int: def _cmd_data_convert(args: argparse.Namespace) -> int: - type_map = _maybe_split_list(args.type_map) from dpa_adapt import ( diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 076002c9a6..08f21441ca 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -2,6 +2,7 @@ """Unit tests for TensorFlow 2 training internals.""" import importlib +import os from types import ( SimpleNamespace, ) @@ -13,6 +14,12 @@ import numpy as np import pytest +if os.environ.get("DP_TEST_TF2_ONLY") != "1": + pytest.skip( + "TF2 tests require DP_TEST_TF2_ONLY=1", + allow_module_level=True, + ) + from deepmd.dpmodel.loss import ( EnergyLoss, ) @@ -524,6 +531,30 @@ def call_common( ), } + def translated_output_def(self) -> dict[str, Any]: + return { + "atom_energy": SimpleNamespace( + name="energy", + shape=[1], + atomic=True, + ), + "energy": SimpleNamespace( + name="energy_redu", + shape=[1], + atomic=False, + ), + "force": SimpleNamespace( + name="energy_derv_r", + shape=[3], + atomic=True, + ), + "virial": SimpleNamespace( + name="energy_derv_c_redu", + shape=[9], + atomic=False, + ), + } + trainer.models = {DEFAULT_TASK_KEY: SpyEnergyModel()} trainer.losses = {DEFAULT_TASK_KEY: EnergyLoss(starter_learning_rate=1.0)} @@ -542,9 +573,7 @@ def call_common( "do_atomic_virial": False, "do_deriv_c": True, } - np.testing.assert_allclose( - to_tf_tensor(result["virial"]).numpy(), np.ones((1, 1, 9)) - ) + np.testing.assert_allclose(to_tf_tensor(result["virial"]).numpy(), np.ones((1, 9))) assert "atom_virial" not in result @@ -765,20 +794,51 @@ def set_enable_compile(self, enable_compile: bool) -> None: assert calls == [("a", True), ("b", True)] -def test_prepared_energy_step_only_uses_validated_descriptors() -> None: +def test_prepared_step_only_depends_on_enable_compile() -> None: trainer = object.__new__(Trainer) trainer.enable_compile = True trainer.losses = { "se": EnergyLoss(starter_learning_rate=1.0), - "dpa3": EnergyLoss(starter_learning_rate=1.0), + "tensor": object(), } - trainer.model_params_by_task = { - "se": {"descriptor": {"type": "se_e2_a"}}, - "dpa3": {"descriptor": {"type": "dpa3"}}, + + assert Trainer._use_prepared_step(trainer, "se") is True + assert Trainer._use_prepared_step(trainer, "tensor") is True + + trainer.enable_compile = False + + assert Trainer._use_prepared_step(trainer, "se") is False + + +def test_model_ret_translation_uses_translated_output_def() -> None: + trainer = object.__new__(Trainer) + trainer.losses = {"tensor": object()} + trainer.models = { + "tensor": SimpleNamespace( + translated_output_def=lambda: { + "dipole": SimpleNamespace(name="dipole", shape=[3], atomic=True), + "global_dipole": SimpleNamespace( + name="dipole_redu", + shape=[3], + atomic=False, + ), + } + ) } - assert Trainer._use_prepared_energy_step(trainer, "se") is True - assert Trainer._use_prepared_energy_step(trainer, "dpa3") is False + model_pred = Trainer._translate_model_ret_to_loss_dict( + trainer, + "tensor", + { + "dipole": "local", + "dipole_redu": "global", + }, + ) + + assert model_pred == { + "dipole": "local", + "global_dipole": "global", + } def test_write_checkpoint_directory_does_not_mutate_training_step( From 2bc8085798d49c0420d3c60ee14c408eef6a74e2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 17:54:16 +0800 Subject: [PATCH 11/14] fix(tf2): avoid lower derivative kwarg mismatch --- deepmd/tf2/model/dp_model.py | 1 - deepmd/tf2/train/trainer.py | 13 ++++++-- source/tests/tf2/test_training.py | 54 +++++++++++++++++++++++++++---- 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/deepmd/tf2/model/dp_model.py b/deepmd/tf2/model/dp_model.py index df8b45b129..87a939a890 100644 --- a/deepmd/tf2/model/dp_model.py +++ b/deepmd/tf2/model/dp_model.py @@ -150,7 +150,6 @@ def call_common_lower( fparam=to_tensorflow_array(fparam), aparam=to_tensorflow_array(aparam), do_atomic_virial=do_atomic_virial, - do_deriv_c=do_deriv_c, extended_coord_corr=to_tensorflow_array(extended_coord_corr), comm_dict=comm_dict, charge_spin=to_tensorflow_array(charge_spin), diff --git a/deepmd/tf2/train/trainer.py b/deepmd/tf2/train/trainer.py index 3d28ee3c93..ada601b077 100644 --- a/deepmd/tf2/train/trainer.py +++ b/deepmd/tf2/train/trainer.py @@ -832,7 +832,12 @@ def train_step(self, task: TrainingTask, step: int) -> TrainStepResult: ) def _use_prepared_step(self, task_key: str) -> bool: - return bool(getattr(self, "enable_compile", False)) + if not bool(getattr(self, "enable_compile", False)): + return False + model = self.models[task_key] + return callable(getattr(model, "call_common_lower", None)) or callable( + getattr(model, "_call_common_lower_formatted", None) + ) def _prepare_lower_batch( self, @@ -1324,6 +1329,7 @@ def _call_model( task_key, model_ret, label_dict=label_dict, + do_virial=do_virial, ) return model.call( input_dict["coord"], @@ -1392,6 +1398,7 @@ def _call_prepared_model( task_key, model_ret, label_dict=label_dict, + do_virial=do_virial, ) def _translate_model_ret_to_loss_dict( @@ -1400,6 +1407,7 @@ def _translate_model_ret_to_loss_dict( model_ret: dict[str, Any], *, label_dict: dict[str, Any] | None = None, + do_virial: bool = True, ) -> dict[str, Any]: translated_output_def = getattr( self.models[task_key], @@ -1419,7 +1427,8 @@ def _translate_model_ret_to_loss_dict( output_def, ) if ( - label_dict is not None + not do_virial + and label_dict is not None and "virial" in label_dict and "virial" in output_defs and "virial" not in model_pred diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 08f21441ca..9965f31e64 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -478,7 +478,7 @@ def fake_helper(**kwargs: Any) -> dict[str, Any]: assert isinstance(to_tf_tensor(captured["coord"]), tf.Tensor) -def test_tf2_dp_model_call_common_lower_forwards_do_deriv_c() -> None: +def test_tf2_dp_model_call_common_lower_does_not_forward_do_deriv_c() -> None: dp_model_module = importlib.import_module("deepmd.tf2.model.dp_model") captured: dict[str, Any] = {} @@ -499,7 +499,7 @@ def call_common_lower(self, *args: Any, **kwargs: Any) -> dict[str, Any]: nlist_is_formatted=False, ) - assert captured["do_deriv_c"] is False + assert "do_deriv_c" not in captured def test_training_energy_call_keeps_atomic_virial_disabled() -> None: @@ -794,16 +794,23 @@ def set_enable_compile(self, enable_compile: bool) -> None: assert calls == [("a", True), ("b", True)] -def test_prepared_step_only_depends_on_enable_compile() -> None: +def test_prepared_step_uses_enable_compile_and_model_capability() -> None: trainer = object.__new__(Trainer) trainer.enable_compile = True trainer.losses = { "se": EnergyLoss(starter_learning_rate=1.0), "tensor": object(), + "plain": object(), + } + trainer.models = { + "se": SimpleNamespace(call_common_lower=lambda: None), + "tensor": SimpleNamespace(_call_common_lower_formatted=lambda: None), + "plain": SimpleNamespace(), } assert Trainer._use_prepared_step(trainer, "se") is True assert Trainer._use_prepared_step(trainer, "tensor") is True + assert Trainer._use_prepared_step(trainer, "plain") is False trainer.enable_compile = False @@ -841,6 +848,41 @@ def test_model_ret_translation_uses_translated_output_def() -> None: } +def test_model_ret_translation_only_uses_label_virial_when_not_requested() -> None: + trainer = object.__new__(Trainer) + trainer.models = { + "energy": SimpleNamespace( + translated_output_def=lambda: { + "energy": SimpleNamespace(name="energy_redu", shape=[1], atomic=False), + "virial": SimpleNamespace( + name="energy_derv_c_redu", + shape=[9], + atomic=False, + ), + } + ) + } + label_dict = {"virial": tf.ones((1, 9), dtype=tf.float64)} + + skipped_virial = Trainer._translate_model_ret_to_loss_dict( + trainer, + "energy", + {"energy_redu": tf.ones((1, 1), dtype=tf.float64)}, + label_dict=label_dict, + do_virial=False, + ) + requested_virial = Trainer._translate_model_ret_to_loss_dict( + trainer, + "energy", + {"energy_redu": tf.ones((1, 1), dtype=tf.float64)}, + label_dict=label_dict, + do_virial=True, + ) + + assert skipped_virial["virial"] is label_dict["virial"] + assert "virial" not in requested_virial + + def test_write_checkpoint_directory_does_not_mutate_training_step( tmp_path: Any, monkeypatch: pytest.MonkeyPatch, @@ -851,7 +893,7 @@ def test_write_checkpoint_directory_does_not_mutate_training_step( trainer.model_container = tf.Module() def fake_write_training_state(directory: Any, *, step: int) -> None: - del directory, step + pass monkeypatch.setattr(trainer, "_write_training_state", fake_write_training_state) @@ -888,8 +930,8 @@ def assert_existing_objects_matched(self) -> None: calls.append("assert_existing_objects_matched") class FakeCheckpoint: - def __init__(self, **kwargs: Any) -> None: - del kwargs + def __init__(self, **_kwargs: Any) -> None: + pass def restore(self, checkpoint_path: str) -> FakeStatus: calls.append(checkpoint_path) From d721c65cbd8912cd65500d601b47dd184675f258 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 18:05:34 +0800 Subject: [PATCH 12/14] fix(tf2): reject empty checkpoint directories --- deepmd/tf2/utils/serialization.py | 5 +++++ source/tests/tf2/test_training.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/deepmd/tf2/utils/serialization.py b/deepmd/tf2/utils/serialization.py index 9b17af4462..5de81a2bdf 100644 --- a/deepmd/tf2/utils/serialization.py +++ b/deepmd/tf2/utils/serialization.py @@ -612,6 +612,11 @@ def _resolve_checkpoint_path(path: Path) -> tuple[str, Path]: latest = tf.train.latest_checkpoint(str(candidate)) if latest is not None: return latest, candidate + for candidate in candidates: + if candidate.is_dir(): + raise FileNotFoundError( + f"Cannot find a latest TF2 checkpoint in directory {str(candidate)!r}." + ) if path.is_file() or Path(f"{path}.index").is_file(): return str(path), path.parent raise FileNotFoundError( diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index 9965f31e64..abe7ff6981 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -915,6 +915,21 @@ def test_serialization_checkpoint_directory_without_latest_is_not_prefix( serialization._resolve_checkpoint_path(checkpoint_dir) +def test_serialization_empty_checkpoint_candidate_directory_is_not_prefix( + tmp_path: Any, + monkeypatch: pytest.MonkeyPatch, +) -> None: + serialization = importlib.import_module("deepmd.tf2.utils.serialization") + checkpoint_prefix = tmp_path / "model" + checkpoint_dir = tmp_path / "model.tf2" + checkpoint_dir.mkdir() + checkpoint_prefix.touch() + monkeypatch.setattr(serialization.tf.train, "latest_checkpoint", lambda path: None) + + with pytest.raises(FileNotFoundError): + serialization._resolve_checkpoint_path(checkpoint_prefix) + + def test_restore_models_from_checkpoint_validates_restore_status( monkeypatch: pytest.MonkeyPatch, ) -> None: From 28a27d35feb569f817c2356e6d7d5fe0f0ab13ef Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 19:24:29 +0800 Subject: [PATCH 13/14] fix(ci): preload paddle before tf in pytest --- source/tests/__init__.py | 6 ------ source/tests/conftest.py | 8 ++++++++ 2 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 source/tests/conftest.py diff --git a/source/tests/__init__.py b/source/tests/__init__.py index 16149c2cd0..6ceb116d85 100644 --- a/source/tests/__init__.py +++ b/source/tests/__init__.py @@ -1,7 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later - -import os - -if os.environ.get("DP_CI_IMPORT_PADDLE_BEFORE_TF", "0") == "1": - import paddle # noqa: F401 - import tensorflow # noqa: F401 diff --git a/source/tests/conftest.py b/source/tests/conftest.py new file mode 100644 index 0000000000..eabaa25cd4 --- /dev/null +++ b/source/tests/conftest.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import os + +if os.environ.get("DP_CI_IMPORT_PADDLE_BEFORE_TF", "0") == "1": + # Paddle must be loaded before TensorFlow in the CI test process. + import paddle # noqa: F401 + import tensorflow # noqa: F401 From 86f73f925907dc1eb854d54dc50bfecb05735dac Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 5 Jul 2026 20:53:37 +0800 Subject: [PATCH 14/14] test(tf2): avoid tensorboard dependency in summary test --- source/tests/tf2/test_training.py | 41 ++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/source/tests/tf2/test_training.py b/source/tests/tf2/test_training.py index abe7ff6981..9d7941ade1 100644 --- a/source/tests/tf2/test_training.py +++ b/source/tests/tf2/test_training.py @@ -3,6 +3,9 @@ import importlib import os +from contextlib import ( + nullcontext, +) from types import ( SimpleNamespace, ) @@ -1057,12 +1060,36 @@ def test_batch_needs_virial_handles_numpy_find_flags() -> None: def test_tensorboard_step_writes_tensors_without_float_sync( - tmp_path: Any, + monkeypatch: pytest.MonkeyPatch, ) -> None: + class FakeSummaryWriter: + def __init__(self) -> None: + self.flushes = 0 + + def as_default(self) -> Any: + return nullcontext() + + def flush(self) -> None: + self.flushes += 1 + + scalar_calls = [] + + def fake_summary_scalar( + name: str, + data: Any, + *, + step: Any = None, + description: Any = None, + ) -> bool: + del description + scalar_calls.append((name, data, step)) + return True + trainer = object.__new__(Trainer) - trainer.summary_writer = tf.summary.create_file_writer(str(tmp_path)) + trainer.summary_writer = FakeSummaryWriter() trainer.tensorboard_freq = 1 trainer.multi_task = False + monkeypatch.setattr(tf.summary, "scalar", fake_summary_scalar) def fail_if_float_sync_is_used(more_loss: dict[str, Any]) -> dict[str, float]: del more_loss @@ -1080,7 +1107,15 @@ def fail_if_float_sync_is_used(more_loss: dict[str, Any]) -> dict[str, float]: "l2_regularization": tf.constant(2.0, dtype=tf.float64), }, ) - trainer.summary_writer.close() + + assert [call[0] for call in scalar_calls] == ["learning_rate", "train/rmse"] + assert scalar_calls[0][1].dtype == tf.float64 + assert scalar_calls[0][1].numpy() == 0.1 + assert scalar_calls[0][2] == 1 + assert scalar_calls[1][1].dtype == tf.float64 + assert scalar_calls[1][1].numpy() == 1.0 + assert scalar_calls[1][2] == 1 + assert trainer.summary_writer.flushes == 1 def test_train_entrypoint_builds_data_without_descriptor_rcut(