diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 46ad8a6cd0..73e1bf5eb3 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -54,6 +54,7 @@ from deepmd.pt.train import ( training, ) +from deepmd.pt.train.trainer import Trainer as NewTrainer from deepmd.pt.train.wrapper import ( ModelWrapper, ) @@ -106,6 +107,7 @@ def get_trainer( init_frz_model: str | None = None, shared_links: dict[str, Any] | None = None, finetune_links: dict[str, Any] | None = None, + use_legacy: bool = False, ) -> training.Trainer: multi_task = "model_dict" in config.get("model", {}) @@ -200,19 +202,34 @@ def prepare_trainer_input_single( seed=data_seed, ) - trainer = training.Trainer( - config, - train_data, - stat_file_path=stat_file_path, - validation_data=validation_data, - init_model=init_model, - restart_model=restart_model, - finetune_model=finetune_model, - force_load=force_load, - shared_links=shared_links, - finetune_links=finetune_links, - init_frz_model=init_frz_model, - ) + if use_legacy: + trainer = training.Trainer( + config, + train_data, + stat_file_path=stat_file_path, + validation_data=validation_data, + init_model=init_model, + restart_model=restart_model, + finetune_model=finetune_model, + force_load=force_load, + shared_links=shared_links, + finetune_links=finetune_links, + init_frz_model=init_frz_model, + ) + else: + trainer = NewTrainer( + config, + train_data, + stat_file_path=stat_file_path, + validation_data=validation_data, + init_model=init_model, + restart_model=restart_model, + finetune_model=finetune_model, + force_load=force_load, + shared_links=shared_links, + finetune_links=finetune_links, + init_frz_model=init_frz_model, + ) return trainer diff --git a/deepmd/pt/train/__init__.py b/deepmd/pt/train/__init__.py old mode 100644 new mode 100755 index 6ceb116d85..0afd8101fe --- a/deepmd/pt/train/__init__.py +++ b/deepmd/pt/train/__init__.py @@ -1 +1,102 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""PyTorch training module with modular, extensible design. + +This module provides a clean, component-based training system: + +- TrainingConfig: Configuration management with validation +- DataManager: Data loading and batch iteration +- OptimizerFactory: Strategy pattern for optimizer creation +- CheckpointManager: Model persistence and recovery +- TrainingLoop: Specialized training step implementations +- HookManager: Extensible callback system +- TrainingLogger: Formatted output and file I/O +- Trainer: Main orchestrator coordinating all components + +Example: + >>> from deepmd.pt.train import Trainer, TrainingConfig + >>> + >>> # Create trainer + >>> trainer = Trainer( + ... config=config_dict, + ... training_data=train_dataset, + ... validation_data=valid_dataset, + ... ) + >>> + >>> # Run training + >>> trainer.run() + +Future extensions for multi-backend support: +- AbstractTrainingLoop can be extended for JAX/NumPy +- OptimizerFactory can support backend-specific optimizers +- DataManager can use backend-specific data loading +""" + +from deepmd.pt.train.checkpoint_manager import ( + CheckpointManager, +) +from deepmd.pt.train.config import ( + CheckpointConfig, + DisplayConfig, + LearningRateConfig, + OptimizerConfig, + TrainingConfig, +) +from deepmd.pt.train.data_manager import ( + DataManager, +) +from deepmd.pt.train.hooks import ( + HookManager, + HookPriority, + TensorBoardHook, + TimingHook, + TrainingHook, +) +from deepmd.pt.train.logger import ( + LossAccumulator, + TrainingLogger, +) +from deepmd.pt.train.optimizer_factory import ( + OptimizerFactory, +) +from deepmd.pt.train.trainer import ( + Trainer, +) + +# Keep old Trainer available for backward compatibility during transition +from deepmd.pt.train.training import Trainer as LegacyTrainer +from deepmd.pt.train.training_loop import ( + AdamTrainingLoop, + BaseTrainingLoop, + LKFEnergyTrainingLoop, + TrainingLoopFactory, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) + +__all__ = [ + # New modular components + "AdamTrainingLoop", + "BaseTrainingLoop", + "CheckpointConfig", + "CheckpointManager", + "DataManager", + "DisplayConfig", + "HookManager", + "HookPriority", + "LKFEnergyTrainingLoop", + "LearningRateConfig", + # Legacy support + "LegacyTrainer", + "LossAccumulator", + "ModelWrapper", + "OptimizerConfig", + "OptimizerFactory", + "TensorBoardHook", + "TimingHook", + "Trainer", + "TrainingConfig", + "TrainingHook", + "TrainingLogger", + "TrainingLoopFactory", +] diff --git a/deepmd/pt/train/checkpoint_manager.py b/deepmd/pt/train/checkpoint_manager.py new file mode 100755 index 0000000000..030eda2106 --- /dev/null +++ b/deepmd/pt/train/checkpoint_manager.py @@ -0,0 +1,407 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Checkpoint management for model saving, loading, and recovery. + +This module provides a clean interface for managing model checkpoints, +including saving, loading, automatic cleanup, and fine-tuning support. +""" + +from __future__ import ( + annotations, +) + +import logging +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +import torch.distributed as dist + +from deepmd.common import ( + symlink_prefix_files, +) +from deepmd.pt.utils.env import ( + DEVICE, +) + +if TYPE_CHECKING: + from deepmd.pt.train.config import ( + CheckpointConfig, + ) + +log = logging.getLogger(__name__) + + +class CheckpointManager: + """Manages model checkpoints throughout training. + + This class handles saving checkpoints, loading for resume/finetune, + automatic cleanup of old checkpoints, and symlink management. + + Attributes + ---------- + config : CheckpointConfig + Configuration for checkpoint behavior. + rank : int + Distributed training rank. + latest_model : Path | None + Path to the most recent checkpoint. + """ + + def __init__( + self, + config: CheckpointConfig, + rank: int = 0, + ) -> None: + """Initialize checkpoint manager. + + Parameters + ---------- + config : CheckpointConfig + Configuration for checkpoint behavior. + rank : int + Distributed training rank (only rank 0 saves). + """ + self.config = config + self.rank = rank + self.latest_model: Path | None = None + self._saved_checkpoints: list[Path] = [] + + def save( + self, + step: int, + wrapper: torch.nn.Module, + optimizer: torch.optim.Optimizer | Any, + lr: float = 0.0, + ) -> Path | None: + """Save a checkpoint. + + Parameters + ---------- + step : int + Current training step. + wrapper : torch.nn.Module + Model wrapper (possibly wrapped in DDP). + optimizer : torch.optim.Optimizer | Any + Optimizer instance. + lr : float + Current learning rate. + + Returns + ------- + Path | None + Path to saved checkpoint, or None if not saved. + """ + if self.rank != 0: + return None + + # Get unwrapped module if using DDP + module = wrapper + if dist.is_available() and dist.is_initialized(): + if hasattr(wrapper, "module"): + module = wrapper.module + + # Update training info + if hasattr(module, "train_infos"): + module.train_infos["lr"] = float(lr) + module.train_infos["step"] = step + + # Prepare checkpoint path + save_path = Path(self.config.save_ckpt + f"-{step + 1}.pt") + + # Prepare optimizer state + optim_state = deepcopy(optimizer.state_dict()) + if "param_groups" in optim_state: + for group in optim_state["param_groups"]: + if "lr" in group: + group["lr"] = float(group["lr"]) + + # Save checkpoint + checkpoint = { + "model": module.state_dict(), + "optimizer": optim_state, + "step": step, + "lr": float(lr), + } + + torch.save(checkpoint, save_path) + self.latest_model = save_path + self._saved_checkpoints.append(save_path) + + # Update symlinks + symlink_prefix_files(save_path.stem, self.config.save_ckpt) + + # Write checkpoint file + with open("checkpoint", "w") as f: + f.write(str(save_path)) + + # Cleanup old checkpoints + self._cleanup_old_checkpoints() + + log.info(f"Saved checkpoint to {save_path}") + return save_path + + def _cleanup_old_checkpoints(self) -> None: + """Remove old checkpoints keeping only max_ckpt_keep most recent.""" + if len(self._saved_checkpoints) <= self.config.max_ckpt_keep: + return + + # Sort by modification time + checkpoint_files = [ + f + for f in Path(".").glob(f"{self.config.save_ckpt}*.pt") + if not f.is_symlink() and f.name.startswith(self.config.save_ckpt) + ] + checkpoint_files.sort(key=lambda x: x.stat().st_mtime) + + # Remove oldest + while len(checkpoint_files) > self.config.max_ckpt_keep: + old_file = checkpoint_files.pop(0) + try: + old_file.unlink() + log.debug(f"Removed old checkpoint: {old_file}") + except OSError as e: + log.warning(f"Failed to remove old checkpoint {old_file}: {e}") + + def load( + self, + checkpoint_path: str | Path, + wrapper: torch.nn.Module | None = None, + optimizer: torch.optim.Optimizer | Any = None, + strict: bool = True, + ) -> dict[str, Any]: + """Load a checkpoint. + + Parameters + ---------- + checkpoint_path : str | Path + Path to checkpoint file. + wrapper : torch.nn.Module | None + Model wrapper to load state into. + optimizer : torch.optim.Optimizer | Any | None + Optimizer to load state into. + strict : bool + Whether to strictly enforce state dict matching. + + Returns + ------- + dict[str, Any] + Loaded checkpoint dictionary. + """ + checkpoint_path = Path(checkpoint_path) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + log.info(f"Loading checkpoint from {checkpoint_path}") + checkpoint = torch.load( + checkpoint_path, + map_location=DEVICE, + weights_only=True, + ) + + # Load model state + if wrapper is not None: + module = wrapper + if dist.is_available() and dist.is_initialized(): + if hasattr(wrapper, "module"): + module = wrapper.module + + state_dict = checkpoint.get("model", checkpoint) + if "model" in checkpoint: + state_dict = checkpoint["model"] + + module.load_state_dict(state_dict, strict=strict) + log.info("Model state loaded successfully") + + # Load optimizer state + if optimizer is not None and "optimizer" in checkpoint: + try: + optimizer.load_state_dict(checkpoint["optimizer"]) + log.info("Optimizer state loaded successfully") + except Exception as e: + log.warning(f"Failed to load optimizer state: {e}") + + return checkpoint + + def load_for_finetune( + self, + checkpoint_path: str | Path, + wrapper: torch.nn.Module, + force_load: bool = False, + ) -> dict[str, Any]: + """Load checkpoint for fine-tuning with optional key mapping. + + Parameters + ---------- + checkpoint_path : str | Path + Path to pretrained checkpoint. + wrapper : torch.nn.Module + Model wrapper to load into. + force_load : bool + If True, initialize missing keys from current model. + + Returns + ------- + dict[str, Any] + Loaded checkpoint info. + """ + checkpoint_path = Path(checkpoint_path) + log.info(f"Loading pretrained model from {checkpoint_path}") + + checkpoint = torch.load( + checkpoint_path, + map_location=DEVICE, + weights_only=True, + ) + + module = wrapper + if dist.is_available() and dist.is_initialized(): + if hasattr(wrapper, "module"): + module = wrapper.module + + state_dict = checkpoint.get("model", checkpoint) + if "model" in checkpoint: + state_dict = checkpoint["model"] + + target_state_dict = module.state_dict() + input_keys = set(state_dict.keys()) + target_keys = set(target_state_dict.keys()) + + missing_keys = target_keys - input_keys + unexpected_keys = input_keys - target_keys + + if missing_keys and force_load: + log.warning( + f"Force load: initializing {len(missing_keys)} missing keys from model" + ) + for key in missing_keys: + state_dict[key] = target_state_dict[key].clone().detach() + + # Load with strict=False to handle architecture differences + module.load_state_dict(state_dict, strict=False) + + if missing_keys and not force_load: + log.warning(f"Missing keys in checkpoint: {missing_keys}") + if unexpected_keys: + log.warning(f"Unexpected keys in checkpoint: {unexpected_keys}") + + return { + "state_dict": state_dict, + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + } + + def get_start_step(self, checkpoint_path: str | Path | None) -> int: + """Get the starting step from a checkpoint. + + Parameters + ---------- + checkpoint_path : str | Path | None + Path to checkpoint, or None. + + Returns + ------- + int + Step to resume from, or 0 for fresh start. + """ + if checkpoint_path is None: + return 0 + + try: + checkpoint = torch.load( + checkpoint_path, + map_location="cpu", + weights_only=True, + ) + step = checkpoint.get("step", 0) + log.info(f"Resuming from step {step}") + return step + except Exception as e: + log.warning(f"Failed to get step from checkpoint: {e}") + return 0 + + def get_latest_checkpoint(self) -> Path | None: + """Get the path to the latest checkpoint. + + Returns + ------- + Path | None + Path to latest checkpoint, or None. + """ + checkpoint_file = Path("checkpoint") + if checkpoint_file.exists(): + try: + latest = Path(checkpoint_file.read_text().strip()) + if latest.exists(): + return latest + except Exception: + pass + + # Fallback: find newest checkpoint file + checkpoints = list(Path(".").glob(f"{self.config.save_ckpt}*.pt")) + if checkpoints: + return max(checkpoints, key=lambda p: p.stat().st_mtime) + + return None + + def save_final( + self, + step: int, + wrapper: torch.nn.Module, + lr: float = 0.0, + ) -> Path | None: + """Save final checkpoint at end of training. + + Parameters + ---------- + step : int + Final step number. + wrapper : torch.nn.Module + Model wrapper. + lr : float + Final learning rate. + + Returns + ------- + Path | None + Path to saved checkpoint. + """ + if self.rank != 0: + return None + + # Get unwrapped module + module = wrapper + if dist.is_available() and dist.is_initialized(): + if hasattr(wrapper, "module"): + module = wrapper.module + + # Update training info + if hasattr(module, "train_infos"): + module.train_infos["lr"] = float(lr) + module.train_infos["step"] = step + + save_path = Path(self.config.save_ckpt + f"-{step}.pt") + + checkpoint = { + "model": module.state_dict(), + "step": step, + "lr": float(lr), + } + + torch.save(checkpoint, save_path) + self.latest_model = save_path + + symlink_prefix_files(save_path.stem, self.config.save_ckpt) + + with open("checkpoint", "w") as f: + f.write(str(save_path)) + + log.info(f"Saved final checkpoint to {save_path}") + return save_path diff --git a/deepmd/pt/train/config.py b/deepmd/pt/train/config.py new file mode 100755 index 0000000000..85c72d1ac8 --- /dev/null +++ b/deepmd/pt/train/config.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Training configuration management with validation and defaults. + +This module defines dataclasses for training configuration. It works +in conjunction with deepmd/utils/argcheck.py which validates the +input configuration against a schema defined using dargs. + +Configuration flow: +1. User provides input JSON/YAML +2. argcheck.py validates against schema (deepmd/utils/argcheck.py) +3. normalize() normalizes the configuration +4. This module converts the normalized dict to typed dataclasses + +Default values here should match those in argcheck.py's Argument definitions +to ensure consistency between validation and runtime behavior. +""" + +from __future__ import ( + annotations, +) + +import logging +from dataclasses import ( + dataclass, + field, +) +from typing import ( + Any, +) + +log = logging.getLogger(__name__) + + +@dataclass +class OptimizerConfig: + """Optimizer configuration with type-specific parameters.""" + + opt_type: str = "Adam" + weight_decay: float = 0.001 + momentum: float = 0.95 + adam_beta1: float = 0.9 + adam_beta2: float = 0.95 + lr_adjust: float = 10.0 + lr_adjust_coeff: float = 0.2 + muon_2d_only: bool = True + min_2d_dim: int = 1 + kf_blocksize: int = 5120 + kf_start_pref_e: float = 1.0 + kf_limit_pref_e: float = 1.0 + kf_start_pref_f: float = 1.0 + kf_limit_pref_f: float = 1.0 + + @classmethod + def from_dict(cls, params: dict[str, Any]) -> OptimizerConfig: + """Create OptimizerConfig from dictionary.""" + return cls( + opt_type=params.get("opt_type", "Adam"), + weight_decay=params.get("weight_decay", 0.001), + momentum=params.get("momentum", 0.95), + adam_beta1=params.get("adam_beta1", 0.9), + adam_beta2=params.get("adam_beta2", 0.95), + lr_adjust=params.get("lr_adjust", 10.0), + lr_adjust_coeff=params.get("lr_adjust_coeff", 0.2), + muon_2d_only=params.get("muon_2d_only", True), + min_2d_dim=params.get("min_2d_dim", 1), + kf_blocksize=params.get("kf_blocksize", 5120), + kf_start_pref_e=params.get("kf_start_pref_e", 1.0), + kf_limit_pref_e=params.get("kf_limit_pref_e", 1.0), + kf_start_pref_f=params.get("kf_start_pref_f", 1.0), + kf_limit_pref_f=params.get("kf_limit_pref_f", 1.0), + ) + + +@dataclass +class LearningRateConfig: + """Learning rate schedule configuration.""" + + start_lr: float = 1e-3 + stop_lr: float = 1e-8 + decay_steps: int = 100000 + decay_rate: float = 0.95 + stop_steps: int = 0 + + @classmethod + def from_dict(cls, params: dict[str, Any]) -> LearningRateConfig: + """Create LearningRateConfig from dictionary.""" + return cls( + start_lr=params.get("start_lr", 1e-3), + stop_lr=params.get("stop_lr", 1e-8), + decay_steps=params.get("decay_steps", 100000), + decay_rate=params.get("decay_rate", 0.95), + ) + + +@dataclass +class DisplayConfig: + """Training display and logging configuration. + + Default values match those in argcheck.py training_args(). + """ + + disp_file: str = "lcurve.out" # argcheck default: "lcurve.out" + disp_freq: int = 1000 # argcheck default: 1000 + disp_avg: bool = False # argcheck default: False (PyTorch only) + disp_training: bool = True # argcheck default: True + time_training: bool = True # argcheck default: True + tensorboard: bool = False # argcheck default: False + tensorboard_log_dir: str = "log" # argcheck default: "log" + tensorboard_freq: int = 1 # argcheck default: 1 + enable_profiler: bool = False # argcheck default: False + profiling: bool = False # argcheck default: False + profiling_file: str = "timeline.json" # argcheck default: "timeline.json" + + @classmethod + def from_dict(cls, params: dict[str, Any]) -> DisplayConfig: + """Create DisplayConfig from dictionary.""" + return cls( + disp_file=params.get("disp_file", "lcurve.out"), + disp_freq=params.get("disp_freq", 1000), + disp_avg=params.get("disp_avg", False), + disp_training=params.get("disp_training", True), + time_training=params.get("time_training", True), + tensorboard=params.get("tensorboard", False), + tensorboard_log_dir=params.get("tensorboard_log_dir", "log"), + tensorboard_freq=params.get("tensorboard_freq", 1), + enable_profiler=params.get("enable_profiler", False), + profiling=params.get("profiling", False), + profiling_file=params.get("profiling_file", "timeline.json"), + ) + + +@dataclass +class CheckpointConfig: + """Model checkpoint configuration. + + Default values match those in argcheck.py training_args(). + """ + + save_ckpt: str = "model.ckpt" # argcheck default: "model.ckpt" + save_freq: int = 1000 # argcheck default: 1000 + max_ckpt_keep: int = 5 # argcheck default: 5 + change_bias_after_training: bool = False # argcheck default: False + + @classmethod + def from_dict(cls, params: dict[str, Any]) -> CheckpointConfig: + """Create CheckpointConfig from dictionary.""" + return cls( + save_ckpt=params.get("save_ckpt", "model.ckpt"), + save_freq=params.get("save_freq", 1000), + max_ckpt_keep=params.get("max_ckpt_keep", 5), + change_bias_after_training=params.get("change_bias_after_training", False), + ) + + +@dataclass +class TrainingConfig: + """Complete training configuration container.""" + + num_steps: int = 0 + warmup_steps: int = 0 + warmup_start_factor: float = 0.0 + gradient_max_norm: float = 0.0 + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + learning_rate: LearningRateConfig = field(default_factory=LearningRateConfig) + display: DisplayConfig = field(default_factory=DisplayConfig) + checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) + is_multitask: bool = False + optimizer_dict: dict[str, OptimizerConfig] | None = None + learning_rate_dict: dict[str, LearningRateConfig] | None = None + + @classmethod + def from_dict( + cls, + config: dict[str, Any], + model_keys: list[str] | None = None, + ) -> TrainingConfig: + """Create TrainingConfig from a configuration dictionary.""" + training_params = config.get("training", {}) + + num_steps = training_params.get("numb_steps", 0) + if num_steps <= 0: + raise ValueError(f"numb_steps must be positive, got {num_steps}") + + warmup_steps = training_params.get("warmup_steps", None) + warmup_ratio = training_params.get("warmup_ratio", None) + if warmup_steps is not None: + computed_warmup_steps = warmup_steps + elif warmup_ratio is not None: + if not 0 <= warmup_ratio < 1: + raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}") + computed_warmup_steps = int(warmup_ratio * num_steps) + if computed_warmup_steps == 0 and warmup_ratio > 0: + log.warning( + f"warmup_ratio {warmup_ratio} results in 0 warmup steps. " + "Consider using a larger ratio or specify warmup_steps directly." + ) + else: + computed_warmup_steps = 0 + + assert num_steps - computed_warmup_steps > 0 or computed_warmup_steps == 0, ( + "Warm up steps must be less than total training steps!" + ) + + is_multitask = model_keys is not None and len(model_keys) > 1 + + if is_multitask and training_params.get("optim_dict") is not None: + optim_dict = { + key: OptimizerConfig.from_dict(training_params["optim_dict"][key]) + for key in model_keys + if key in training_params["optim_dict"] + } + missing_keys = [key for key in model_keys if key not in optim_dict] + if missing_keys: + raise ValueError(f"Missing optimizer config for keys: {missing_keys}") + optimizer = optim_dict[model_keys[0]] + else: + optim_dict = None + optimizer = OptimizerConfig.from_dict(training_params) + + lr_params = config.get("learning_rate", {}) + if is_multitask and config.get("learning_rate_dict") is not None: + lr_dict = { + key: LearningRateConfig.from_dict(config["learning_rate_dict"][key]) + for key in model_keys + if key in config["learning_rate_dict"] + } + learning_rate = lr_dict.get( + model_keys[0], LearningRateConfig.from_dict(lr_params) + ) + else: + lr_dict = None + learning_rate = LearningRateConfig.from_dict(lr_params) + + learning_rate.stop_steps = num_steps - computed_warmup_steps + if lr_dict: + for lr_config in lr_dict.values(): + lr_config.stop_steps = num_steps - computed_warmup_steps + + return cls( + num_steps=num_steps, + warmup_steps=computed_warmup_steps, + warmup_start_factor=training_params.get("warmup_start_factor", 0.0), + gradient_max_norm=training_params.get("gradient_max_norm", 0.0), + optimizer=optimizer, + learning_rate=learning_rate, + display=DisplayConfig.from_dict(training_params), + checkpoint=CheckpointConfig.from_dict(training_params), + is_multitask=is_multitask, + optimizer_dict=optim_dict, + learning_rate_dict=lr_dict, + ) + + def get_optimizer_config(self, task_key: str = "Default") -> OptimizerConfig: + """Get optimizer config for a specific task.""" + if self.is_multitask and self.optimizer_dict is not None: + return self.optimizer_dict.get(task_key, self.optimizer) + return self.optimizer + + def get_lr_config(self, task_key: str = "Default") -> LearningRateConfig: + """Get learning rate config for a specific task.""" + if self.is_multitask and self.learning_rate_dict is not None: + return self.learning_rate_dict.get(task_key, self.learning_rate) + return self.learning_rate diff --git a/deepmd/pt/train/data_loader.py b/deepmd/pt/train/data_loader.py new file mode 100755 index 0000000000..460da7f844 --- /dev/null +++ b/deepmd/pt/train/data_loader.py @@ -0,0 +1,490 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Abstract data loader interface with DpLoaderSet compatibility. + +This module provides an abstract interface for data loading that: +1. Is compatible with existing DpLoaderSet +2. Allows future high-performance implementations without DpLoaderSet dependency +3. Provides a clean, backend-agnostic API + +Future implementations can: +- Replace DpLoaderSet with custom Dataset classes +- Implement prefetching and async data loading +- Use memory-mapped files for large datasets +- Implement custom batching strategies +""" + +from __future__ import ( + annotations, +) + +import logging +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Protocol, + runtime_checkable, +) + +import torch +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.utils.env import ( + DEVICE, +) + +if TYPE_CHECKING: + from collections.abc import ( + Iterator, + ) + + from deepmd.pt.utils.dataloader import ( + DpLoaderSet, + ) + +log = logging.getLogger(__name__) + + +@runtime_checkable +class DataLoaderInterface(Protocol): + """Protocol defining the minimal data loader interface. + + Any data loader implementation (DpLoaderSet or future alternatives) + must satisfy this protocol to work with the training system. + + This allows gradual migration from DpLoaderSet to new implementations. + """ + + def __iter__(self) -> Iterator[dict[str, Any]]: + """Return iterator over batches.""" + return self + + def __next__(self) -> dict[str, Any]: + """Get next batch.""" + ... + + def add_data_requirement(self, requirement: Any) -> None: + """Add data requirements for labels.""" + ... + + def preload_and_modify_all_data_torch(self) -> None: + """Preload and apply modifiers to data.""" + ... + + def print_summary(self, name: str, weights: Any = None) -> None: + """Print dataset summary.""" + ... + + @property + def systems(self) -> list[Any]: + """Get list of systems/datasets.""" + ... + + +class BatchProcessor: + """Processes batches: device transfer and input/label splitting. + + This class centralizes batch processing logic, making it reusable + across different data loader implementations. + """ + + def __init__( + self, + device: torch.device = DEVICE, + input_keys: list[str] | None = None, + ) -> None: + """Initialize batch processor. + + Parameters + ---------- + device : torch.device + Target device for tensors. + input_keys : list[str] | None + Keys that are considered model inputs. + """ + self.device = device + self.input_keys = input_keys or [ + "coord", + "atype", + "spin", + "box", + "fparam", + "aparam", + ] + + def process( + self, batch_data: dict[str, Any] + ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Process a batch: transfer to device and split inputs/labels. + + Parameters + ---------- + batch_data : dict[str, Any] + Raw batch data. + + Returns + ------- + tuple[dict[str, Any], dict[str, Any], dict[str, Any]] + (input_dict, label_dict, log_dict) + """ + # Transfer to device + processed = self._to_device(batch_data) + + # Split into inputs and labels + input_dict, label_dict = self._split_inputs_labels(processed) + + # Create log dict + log_dict = self._create_log_dict(processed) + + return input_dict, label_dict, log_dict + + def _to_device(self, batch_data: dict[str, Any]) -> dict[str, Any]: + """Transfer batch data to target device.""" + result = {} + for key, value in batch_data.items(): + if key in ("sid", "fid", "box") or "find_" in key: + result[key] = value + elif isinstance(value, list): + result[key] = [ + item.to(self.device, non_blocking=True) + if isinstance(item, torch.Tensor) + else item + for item in value + ] + elif isinstance(value, torch.Tensor): + result[key] = value.to(self.device, non_blocking=True) + else: + result[key] = value + return result + + def _split_inputs_labels( + self, batch_data: dict[str, Any] + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Split batch into input and label dictionaries.""" + input_dict: dict[str, Any] = {} + label_dict: dict[str, Any] = {} + + for key, value in batch_data.items(): + if key in self.input_keys: + # Special handling for fparam with find_fparam + if key == "fparam" and batch_data.get("find_fparam", 0.0) == 0.0: + continue + input_dict[key] = value + elif key not in ("sid", "fid"): + label_dict[key] = value + + return input_dict, label_dict + + def _create_log_dict(self, batch_data: dict[str, Any]) -> dict[str, Any]: + """Create log dictionary from batch data.""" + log_dict: dict[str, Any] = {} + if "fid" in batch_data: + log_dict["fid"] = batch_data["fid"] + if "sid" in batch_data: + log_dict["sid"] = batch_data["sid"] + return log_dict + + +class AbstractDataLoader(ABC): + """Abstract base class for data loaders. + + This class defines the interface that all data loaders must implement. + It provides a common API for training code while allowing different + underlying implementations. + + Implementations: + - DpLoaderSetAdapter: Wraps existing DpLoaderSet + - Future: High-performance data loader without DpLoaderSet + """ + + def __init__( + self, + device: torch.device = DEVICE, + ) -> None: + """Initialize abstract data loader. + + Parameters + ---------- + device : torch.device + Target device for data. + """ + self.device = device + self._batch_processor = BatchProcessor(device) + + @abstractmethod + def __iter__( + self, + ) -> Iterator[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]]: + """Return iterator yielding processed batches. + + Yields + ------ + tuple[dict[str, Any], dict[str, Any], dict[str, Any]] + (input_dict, label_dict, log_dict) + """ + pass + + @abstractmethod + def __len__(self) -> int: + """Return number of batches per epoch.""" + pass + + @abstractmethod + def add_data_requirement(self, requirement: Any) -> None: + """Add data requirement for labels. + + Parameters + ---------- + requirement : Any + Data requirement specification. + """ + pass + + @abstractmethod + def preload_data(self) -> None: + """Preload data into memory.""" + pass + + @abstractmethod + def print_summary(self, name: str) -> None: + """Print data summary. + + Parameters + ---------- + name : str + Name to display in summary. + """ + pass + + +class DpLoaderSetAdapter(AbstractDataLoader): + """Adapter making DpLoaderSet compatible with AbstractDataLoader. + + This adapter wraps the existing DpLoaderSet implementation, + allowing it to be used with the new training system without + modifying the original class. + + This is the transition solution - future implementations can + replace this with high-performance alternatives. + """ + + def __init__( + self, + dp_loader_set: DpLoaderSet, + device: torch.device = DEVICE, + ) -> None: + """Initialize adapter. + + Parameters + ---------- + dp_loader_set : DpLoaderSet + Existing DpLoaderSet instance. + device : torch.device + Target device. + """ + super().__init__(device) + self._dp_loader = dp_loader_set + self._iterator: Iterator[dict[str, Any]] | None = None + + def __iter__( + self, + ) -> Iterator[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]]: + """Return iterator over processed batches.""" + self._iterator = self._create_iterator() + return self + + def __next__(self) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Get next processed batch.""" + if self._iterator is None: + self._iterator = self._create_iterator() + + batch = next(self._iterator) + return self._batch_processor.process(batch) + + def _create_iterator(self) -> Iterator[dict[str, Any]]: + """Create underlying iterator with automatic restart. + + Uses DpLoaderSet's __getitem__ method which handles: + - System sampling according to weights + - Iterator reset on exhaustion + - CPU device context for data loading + """ + # Import sampler utilities + from deepmd.pt.utils.dataloader import ( + get_sampler_from_params, + ) + + # Get or create sampler for weighted system sampling + if hasattr(self._dp_loader, "sampler") and self._dp_loader.sampler is not None: + sampler = self._dp_loader.sampler + else: + # Create default sampler with prob_sys_size + with torch.device("cpu"): + sampler = get_sampler_from_params(self._dp_loader, "prob_sys_size") + # Store sampler on dp_loader for consistency + self._dp_loader.sampler = sampler + + # Create DataLoader that wraps DpLoaderSet (not its internal dataloaders) + # This ensures __getitem__ is called with sampled indices + import torch.distributed as dist + + from deepmd.pt.utils.env import ( + NUM_WORKERS, + ) + + with torch.device("cpu"): + dataloader = DataLoader( + self._dp_loader, + sampler=sampler, + batch_size=None, + num_workers=NUM_WORKERS + if dist.is_available() and dist.is_initialized() + else 0, + drop_last=False, + collate_fn=lambda batch: batch, # prevent extra conversion + pin_memory=False, # Batch processor handles device transfer + ) + + def cycle_iterator() -> Iterator[dict[str, Any]]: + """Infinite iterator that cycles through the dataloader.""" + while True: + with torch.device("cpu"): + it = iter(dataloader) + yield from it + + return cycle_iterator() + + def __len__(self) -> int: + """Return total number of batches.""" + return self._dp_loader.total_batch + + def add_data_requirement(self, requirement: Any) -> None: + """Add data requirement to underlying DpLoaderSet.""" + self._dp_loader.add_data_requirement(requirement) + + def preload_data(self) -> None: + """Preload data via DpLoaderSet.""" + self._dp_loader.preload_and_modify_all_data_torch() + + def print_summary(self, name: str) -> None: + """Print summary via DpLoaderSet.""" + from deepmd.pt.utils.utils import ( + to_numpy_array, + ) + + weights = None + if hasattr(self._dp_loader, "sampler_list") and self._dp_loader.sampler_list: + # Get weights from first sampler as representative + if hasattr(self._dp_loader.sampler_list[0], "weights"): + weights = to_numpy_array(self._dp_loader.sampler_list[0].weights) + + # Handle case where sampler doesn't have weights (e.g., DistributedSampler) + if weights is None and hasattr(self._dp_loader, "systems"): + # Default: uniform weights + import numpy as np + + weights = np.ones(len(self._dp_loader.systems), dtype=np.float32) + + self._dp_loader.print_summary(name, weights) + + @property + def dp_loader_set(self) -> DpLoaderSet: + """Access underlying DpLoaderSet (for backward compatibility).""" + return self._dp_loader + + +class DataLoaderFactory: + """Factory for creating data loaders. + + This factory centralizes data loader creation and allows + easy switching between implementations. + """ + + # Registry of available implementations + _implementations: ClassVar[dict[str, type[AbstractDataLoader]]] = { + "dploaderset": DpLoaderSetAdapter, + } + + @classmethod + def register(cls, name: str, implementation: type[AbstractDataLoader]) -> None: + """Register a new data loader implementation. + + Parameters + ---------- + name : str + Identifier for the implementation. + implementation : type[AbstractDataLoader] + Data loader class. + """ + cls._implementations[name] = implementation + log.info(f"Registered data loader implementation: {name}") + + @classmethod + def create( + cls, + data_source: Any, + implementation: str = "dploaderset", + device: torch.device = DEVICE, + **kwargs: Any, + ) -> AbstractDataLoader: + """Create a data loader instance. + + Parameters + ---------- + data_source : Any + Source data (DpLoaderSet, paths, etc.). + implementation : str + Which implementation to use. + device : torch.device + Target device. + **kwargs : Any + Additional arguments for the implementation. + + Returns + ------- + AbstractDataLoader + Configured data loader. + + Raises + ------ + ValueError + If implementation is not registered. + """ + if implementation not in cls._implementations: + raise ValueError( + f"Unknown data loader implementation: {implementation}. " + f"Available: {list(cls._implementations.keys())}" + ) + + impl_class = cls._implementations[implementation] + return impl_class(data_source, device=device, **kwargs) + + @classmethod + def get_available_implementations(cls) -> list[str]: + """Get list of registered implementations.""" + return list(cls._implementations.keys()) + + +# Convenience functions +def create_data_loader( + data_source: Any, + implementation: str = "dploaderset", + device: torch.device = DEVICE, + **kwargs: Any, +) -> AbstractDataLoader: + """Create a data loader (convenience function).""" + return DataLoaderFactory.create(data_source, implementation, device, **kwargs) + + +def adapt_dploader_set( + dp_loader_set: DpLoaderSet, + device: torch.device = DEVICE, +) -> DpLoaderSetAdapter: + """Adapt existing DpLoaderSet (convenience function).""" + return DpLoaderSetAdapter(dp_loader_set, device) diff --git a/deepmd/pt/train/data_manager.py b/deepmd/pt/train/data_manager.py new file mode 100755 index 0000000000..5b787dee61 --- /dev/null +++ b/deepmd/pt/train/data_manager.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Data management for training and validation. + +This module handles data loading, batch iteration, and provides +a unified interface for both single-task and multi-task scenarios. + +It now uses the abstract DataLoader interface, allowing future +high-performance implementations to replace DpLoaderSet. +""" + +from __future__ import ( + annotations, +) + +import logging +from typing import ( + TYPE_CHECKING, + Any, +) + +from deepmd.pt.train.data_loader import ( + AbstractDataLoader, + DpLoaderSetAdapter, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.env import ( + DEVICE, +) + +if TYPE_CHECKING: + import torch + +log = logging.getLogger(__name__) + + +class DataManager: + """Manages training and validation data. + + This class handles DataLoader creation, data iteration, and provides + a unified interface for both single-task and multi-task scenarios. + + Attributes + ---------- + is_multitask : bool + Whether managing data for multiple tasks. + """ + + def __init__( + self, + training_data: DpLoaderSet + | dict[str, DpLoaderSet] + | AbstractDataLoader + | dict[str, AbstractDataLoader], + validation_data: DpLoaderSet + | dict[str, DpLoaderSet] + | AbstractDataLoader + | dict[str, AbstractDataLoader] + | None = None, + training_params: dict[str, Any] | None = None, + device: torch.device = DEVICE, + data_loader_impl: str = "dploaderset", + ) -> None: + """Initialize data manager. + + Parameters + ---------- + training_data : DpLoaderSet | dict[str, DpLoaderSet] | AbstractDataLoader | dict[str, AbstractDataLoader] + Training dataset(s). Can be DpLoaderSet (legacy) or AbstractDataLoader. + validation_data : DpLoaderSet | dict[str, DpLoaderSet] | AbstractDataLoader | dict[str, AbstractDataLoader] | None + Validation dataset(s). + training_params : dict[str, Any] | None + Training configuration parameters (kept for API compatibility). + device : torch.device + Device to transfer data to. + data_loader_impl : str + Data loader implementation to use (for future extensions). + """ + self.device = device + self._data_loader_impl = data_loader_impl + + # Determine if multi-task + self.is_multitask = isinstance(training_data, dict) + + # Convert inputs to AbstractDataLoader if needed + if self.is_multitask: + self.training_loaders: dict[str, AbstractDataLoader] = ( + self._ensure_data_loaders(training_data) + ) + self.validation_loaders: dict[str, AbstractDataLoader | None] = ( + self._ensure_data_loaders(validation_data) + if validation_data + else dict.fromkeys(training_data) + ) + self.model_keys = list(self.training_loaders.keys()) + else: + self.training_loaders = self._ensure_data_loader(training_data) + self.validation_loaders = ( + self._ensure_data_loader(validation_data) if validation_data else None + ) + + log.info(f"DataManager initialized with {data_loader_impl} implementation") + + def _ensure_data_loader( + self, data: DpLoaderSet | AbstractDataLoader + ) -> AbstractDataLoader: + """Ensure data is wrapped as AbstractDataLoader. + + Parameters + ---------- + data : DpLoaderSet | AbstractDataLoader + Input data. + + Returns + ------- + AbstractDataLoader + Wrapped or original data loader. + """ + if isinstance(data, AbstractDataLoader): + return data + elif isinstance(data, DpLoaderSet): + return DpLoaderSetAdapter(data, self.device) + else: + raise TypeError(f"Unsupported data type: {type(data)}") + + def _ensure_data_loaders( + self, data: dict[str, DpLoaderSet | AbstractDataLoader] + ) -> dict[str, AbstractDataLoader]: + """Ensure all values in dict are AbstractDataLoader. + + Parameters + ---------- + data : dict[str, DpLoaderSet | AbstractDataLoader] + Input data dict. + + Returns + ------- + dict[str, AbstractDataLoader] + Dict with wrapped data loaders. + """ + return {key: self._ensure_data_loader(value) for key, value in data.items()} + + def get_train_batch( + self, task_key: str | None = None + ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Get next training batch. + + Parameters + ---------- + task_key : str | None + Task key for multi-task training. + + Returns + ------- + tuple[dict[str, Any], dict[str, Any], dict[str, Any]] + (input_dict, label_dict, log_dict) + """ + if self.is_multitask: + assert task_key is not None, "task_key required for multi-task" + return next(self.training_loaders[task_key]) + return next(self.training_loaders) + + def get_valid_batch( + self, task_key: str | None = None + ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Get next validation batch. + + Parameters + ---------- + task_key : str | None + Task key for multi-task training. + + Returns + ------- + tuple[dict[str, Any], dict[str, Any], dict[str, Any]] + (input_dict, label_dict, log_dict) + """ + loader = self._get_valid_loader(task_key) + if loader is None: + return {}, {}, {} + return next(loader) + + def _get_valid_loader( + self, task_key: str | None = None + ) -> AbstractDataLoader | None: + """Get validation loader for task.""" + if self.is_multitask: + assert task_key is not None + return self.validation_loaders.get(task_key) + return self.validation_loaders + + def get_valid_numb_batch(self, task_key: str | None = None) -> int: + """Get number of validation batches. + + For now, returns a default value. Future implementations + can derive this from the underlying data loader. + + Parameters + ---------- + task_key : str | None + Task key for multi-task training. + + Returns + ------- + int + Number of validation batches. + """ + loader = self._get_valid_loader(task_key) + if loader is None: + return 1 + # Try to get length, default to 1 if not available + try: + return len(loader) + except (TypeError, AttributeError): + return 1 + + def print_summary(self, rank: int = 0) -> None: + """Print dataset summaries. + + Parameters + ---------- + rank : int + Current process rank (only rank 0 prints). + """ + if rank != 0: + return + + if self.is_multitask: + for key in self.model_keys: + self.training_loaders[key].print_summary(f"training in {key}") + if self.validation_loaders.get(key): + self.validation_loaders[key].print_summary(f"validation in {key}") + else: + self.training_loaders.print_summary("training") + if self.validation_loaders: + self.validation_loaders.print_summary("validation") + + def add_data_requirements( + self, + requirements: Any, + task_key: str | None = None, + ) -> None: + """Add data requirements. + + Parameters + ---------- + requirements : Any + Data requirements to add. + task_key : str | None + Task key for multi-task training. + """ + if self.is_multitask: + assert task_key is not None + self.training_loaders[task_key].add_data_requirement(requirements) + if self.validation_loaders.get(task_key): + self.validation_loaders[task_key].add_data_requirement(requirements) + else: + self.training_loaders.add_data_requirement(requirements) + if self.validation_loaders: + self.validation_loaders.add_data_requirement(requirements) + + def preload_data(self, task_key: str | None = None) -> None: + """Preload data into memory. + + Parameters + ---------- + task_key : str | None + Task key for multi-task training. + """ + if self.is_multitask: + assert task_key is not None + self.training_loaders[task_key].preload_data() + if self.validation_loaders.get(task_key): + self.validation_loaders[task_key].preload_data() + else: + self.training_loaders.preload_data() + if self.validation_loaders: + self.validation_loaders.preload_data() + + @staticmethod + def create_from_dploader_set( + training_data: DpLoaderSet | dict[str, DpLoaderSet], + validation_data: DpLoaderSet | dict[str, DpLoaderSet] | None = None, + device: torch.device = DEVICE, + ) -> DataManager: + """Factory method to create from DpLoaderSet(s). + + This is the primary method for creating DataManager from + existing DpLoaderSet instances. + + Parameters + ---------- + training_data : DpLoaderSet | dict[str, DpLoaderSet] + Training dataset(s). + validation_data : DpLoaderSet | dict[str, DpLoaderSet] | None + Validation dataset(s). + device : torch.device + Device to transfer data to. + + Returns + ------- + DataManager + Configured data manager. + """ + return DataManager( + training_data=training_data, + validation_data=validation_data, + device=device, + data_loader_impl="dploaderset", + ) diff --git a/deepmd/pt/train/hooks.py b/deepmd/pt/train/hooks.py new file mode 100755 index 0000000000..96691f7de5 --- /dev/null +++ b/deepmd/pt/train/hooks.py @@ -0,0 +1,527 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Hook system for extensible training callbacks. + +This module provides a hook system that allows users to inject custom +logic at various points during the training process without modifying +the core training code. + +Example usage: + >>> class MyHook(TrainingHook): + ... def on_step_end(self, step, logs): + ... if step % 100 == 0: + ... print(f"Step {step}: loss = {logs.get('loss', 'N/A')}") + >>> trainer.register_hook(MyHook()) +""" + +from __future__ import ( + annotations, +) + +import logging +from abc import ( + ABC, +) +from enum import ( + IntEnum, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +if TYPE_CHECKING: + from collections.abc import ( + Mapping, + ) + +log = logging.getLogger(__name__) + + +class HookPriority(IntEnum): + """Priority levels for hook execution order. + + Lower values execute first. Use these to control the order + in which hooks are called when multiple hooks are registered. + """ + + HIGHEST = 0 # System-critical hooks (e.g., checkpointing) + HIGH = 10 # Important monitoring hooks + NORMAL = 20 # Default priority for user hooks + LOW = 30 # Logging and non-critical hooks + LOWEST = 40 # Debug and development hooks + + +class TrainingHook(ABC): + """Base class for training hooks. + + Subclass this to implement custom callbacks at various points + during training. All methods are optional - only override the + ones you need. + + Attributes + ---------- + priority : HookPriority + Execution priority of this hook. Lower values execute first. + """ + + priority: HookPriority = HookPriority.NORMAL + + def on_train_begin(self, logs: Mapping[str, Any] | None = None) -> None: + """Called at the beginning of training. + + Parameters + ---------- + logs : Mapping[str, Any] | None + Dictionary of initial values (e.g., start_step). + """ + pass + + def on_train_end(self, logs: Mapping[str, Any] | None = None) -> None: + """Called at the end of training. + + Parameters + ---------- + logs : Mapping[str, Any] | None + Dictionary of final training metrics. + """ + pass + + def on_epoch_begin(self, epoch: int, logs: Mapping[str, Any] | None = None) -> None: + """Called at the beginning of each epoch. + + Parameters + ---------- + epoch : int + Current epoch number. + logs : Mapping[str, Any] | None + Dictionary of metrics from previous epoch. + """ + pass + + def on_epoch_end(self, epoch: int, logs: Mapping[str, Any] | None = None) -> None: + """Called at the end of each epoch. + + Parameters + ---------- + epoch : int + Current epoch number. + logs : Mapping[str, Any] | None + Dictionary of metrics for this epoch. + """ + pass + + def on_step_begin(self, step: int, logs: Mapping[str, Any] | None = None) -> None: + """Called at the beginning of each training step. + + Parameters + ---------- + step : int + Current step number. + logs : Mapping[str, Any] | None + Dictionary of current training state. + """ + pass + + def on_step_end(self, step: int, logs: Mapping[str, Any] | None = None) -> None: + """Called at the end of each training step. + + Parameters + ---------- + step : int + Current step number. + logs : Mapping[str, Any] | None + Dictionary of metrics for this step (loss, lr, etc.). + """ + pass + + def on_validation_begin( + self, step: int, logs: Mapping[str, Any] | None = None + ) -> None: + """Called at the beginning of validation. + + Parameters + ---------- + step : int + Current step number. + logs : Mapping[str, Any] | None + Dictionary of current training state. + """ + pass + + def on_validation_end( + self, step: int, logs: Mapping[str, Any] | None = None + ) -> None: + """Called at the end of validation. + + Parameters + ---------- + step : int + Current step number. + logs : Mapping[str, Any] | None + Dictionary of validation metrics. + """ + pass + + def on_save_checkpoint( + self, step: int, checkpoint_path: str, logs: Mapping[str, Any] | None = None + ) -> None: + """Called when a checkpoint is saved. + + Parameters + ---------- + step : int + Current step number. + checkpoint_path : str + Path where checkpoint was saved. + logs : Mapping[str, Any] | None + Dictionary of current training state. + """ + pass + + +class HookManager: + """Manages a collection of training hooks. + + This class handles registration and execution of hooks, ensuring + they are called in priority order and handling any errors gracefully. + + Attributes + ---------- + hooks : list[TrainingHook] + List of registered hooks sorted by priority. + """ + + def __init__(self) -> None: + """Initialize an empty hook manager.""" + self.hooks: list[TrainingHook] = [] + + def register(self, hook: TrainingHook) -> None: + """Register a new hook. + + The hook is inserted in priority order (lower priority values first). + + Parameters + ---------- + hook : TrainingHook + The hook instance to register. + """ + # Insert in priority order + idx = len(self.hooks) + for i, existing_hook in enumerate(self.hooks): + if hook.priority < existing_hook.priority: + idx = i + break + self.hooks.insert(idx, hook) + log.debug( + f"Registered hook {hook.__class__.__name__} at priority {hook.priority}" + ) + + def unregister(self, hook: TrainingHook) -> None: + """Unregister a previously registered hook. + + Parameters + ---------- + hook : TrainingHook + The hook instance to unregister. + + Raises + ------ + ValueError + If the hook is not found in the registered hooks. + """ + if hook in self.hooks: + self.hooks.remove(hook) + log.debug(f"Unregistered hook {hook.__class__.__name__}") + else: + raise ValueError(f"Hook {hook.__class__.__name__} not found") + + def _call_hooks(self, method_name: str, *args: Any, **kwargs: Any) -> None: + """Internal method to call a hook method on all registered hooks. + + Parameters + ---------- + method_name : str + Name of the hook method to call. + *args : Any + Positional arguments to pass to the hook method. + **kwargs : Any + Keyword arguments to pass to the hook method. + """ + for hook in self.hooks: + try: + method = getattr(hook, method_name) + method(*args, **kwargs) + except Exception as e: + log.warning( + f"Hook {hook.__class__.__name__}.{method_name} failed: {e}", + exc_info=True, + ) + + def on_train_begin(self, logs: Mapping[str, Any] | None = None) -> None: + """Trigger on_train_begin on all hooks.""" + self._call_hooks("on_train_begin", logs) + + def on_train_end(self, logs: Mapping[str, Any] | None = None) -> None: + """Trigger on_train_end on all hooks.""" + self._call_hooks("on_train_end", logs) + + def on_epoch_begin(self, epoch: int, logs: Mapping[str, Any] | None = None) -> None: + """Trigger on_epoch_begin on all hooks.""" + self._call_hooks("on_epoch_begin", epoch, logs) + + def on_epoch_end(self, epoch: int, logs: Mapping[str, Any] | None = None) -> None: + """Trigger on_epoch_end on all hooks.""" + self._call_hooks("on_epoch_end", epoch, logs) + + def on_step_begin(self, step: int, logs: Mapping[str, Any] | None = None) -> None: + """Trigger on_step_begin on all hooks.""" + self._call_hooks("on_step_begin", step, logs) + + def on_step_end(self, step: int, logs: Mapping[str, Any] | None = None) -> None: + """Trigger on_step_end on all hooks.""" + self._call_hooks("on_step_end", step, logs) + + def on_validation_begin( + self, step: int, logs: Mapping[str, Any] | None = None + ) -> None: + """Trigger on_validation_begin on all hooks.""" + self._call_hooks("on_validation_begin", step, logs) + + def on_validation_end( + self, step: int, logs: Mapping[str, Any] | None = None + ) -> None: + """Trigger on_validation_end on all hooks.""" + self._call_hooks("on_validation_end", step, logs) + + def on_save_checkpoint( + self, step: int, checkpoint_path: str, logs: Mapping[str, Any] | None = None + ) -> None: + """Trigger on_save_checkpoint on all hooks.""" + self._call_hooks("on_save_checkpoint", step, checkpoint_path, logs) + + +class TensorBoardHook(TrainingHook): + """Hook for logging metrics to TensorBoard. + + This hook automatically logs training metrics to TensorBoard + at specified intervals. + + Attributes + ---------- + log_dir : str + Directory for TensorBoard logs. + log_freq : int + Frequency of logging (every N steps). + """ + + def __init__(self, log_dir: str = "logs", log_freq: int = 1) -> None: + """Initialize TensorBoard hook. + + Parameters + ---------- + log_dir : str + Directory for TensorBoard logs. + log_freq : int + Frequency of logging (every N steps). + """ + self.log_dir = log_dir + self.log_freq = log_freq + self.writer = None + self._initialized = False + + def on_train_begin(self, logs: Mapping[str, Any] | None = None) -> None: + """Initialize TensorBoard writer.""" + try: + from torch.utils.tensorboard import ( + SummaryWriter, + ) + + self.writer = SummaryWriter(log_dir=self.log_dir) + self._initialized = True + log.info(f"TensorBoard logging enabled at {self.log_dir}") + except ImportError: + log.warning( + "TensorBoard not available. Install with: pip install tensorboard" + ) + + def on_step_end(self, step: int, logs: Mapping[str, Any] | None = None) -> None: + """Log metrics to TensorBoard.""" + if not self._initialized or self.writer is None: + return + + if logs is None: + return + + display_step = step + 1 + if display_step % self.log_freq != 0 and display_step != 1: + return + + # Log common metrics + if "loss" in logs: + self.writer.add_scalar("train/loss", logs["loss"], display_step) + if "lr" in logs: + self.writer.add_scalar("train/learning_rate", logs["lr"], display_step) + + # Log task-specific metrics + for key, value in logs.items(): + if key not in ["loss", "lr"] and isinstance(value, (int, float)): + self.writer.add_scalar(f"train/{key}", value, display_step) + + def on_validation_end( + self, step: int, logs: Mapping[str, Any] | None = None + ) -> None: + """Log validation metrics to TensorBoard.""" + if not self._initialized or self.writer is None or logs is None: + return + + display_step = step + 1 + for key, value in logs.items(): + if isinstance(value, (int, float)): + self.writer.add_scalar(f"val/{key}", value, display_step) + + def on_train_end(self, logs: Mapping[str, Any] | None = None) -> None: + """Close TensorBoard writer.""" + if self.writer is not None: + self.writer.close() + self._initialized = False + + +class TimingHook(TrainingHook): + """Hook for tracking and logging training timing statistics. + + Tracks time per step, ETA, and average training speed. + """ + + priority = HookPriority.LOW # Run after other hooks + + def __init__(self) -> None: + """Initialize timing hook.""" + self.step_times: list[float] = [] + self.start_time: float | None = None + self.last_step_time: float | None = None + + def on_train_begin(self, logs: Mapping[str, Any] | None = None) -> None: + """Reset timing statistics.""" + import time + + self.step_times = [] + self.start_time = time.time() + self.last_step_time = self.start_time + + def on_step_end(self, step: int, logs: Mapping[str, Any] | None = None) -> None: + """Record step timing.""" + import time + + if self.last_step_time is not None: + step_time = time.time() - self.last_step_time + self.step_times.append(step_time) + # Keep only last 100 measurements + if len(self.step_times) > 100: + self.step_times.pop(0) + self.last_step_time = time.time() + + def get_average_step_time(self) -> float: + """Get average step time over last measurements. + + Returns + ------- + float + Average step time in seconds, or 0.0 if no measurements. + """ + if not self.step_times: + return 0.0 + return sum(self.step_times) / len(self.step_times) + + def get_eta_seconds(self, current_step: int, total_steps: int) -> int: + """Estimate time to completion. + + Parameters + ---------- + current_step : int + Current training step. + total_steps : int + Total number of training steps. + + Returns + ------- + int + Estimated seconds remaining. + """ + avg_time = self.get_average_step_time() + remaining_steps = total_steps - current_step + return int(avg_time * remaining_steps) + + +class EarlyStoppingHook(TrainingHook): + """Hook for early stopping based on validation metrics. + + Stops training when a monitored metric has stopped improving. + + Attributes + ---------- + monitor : str + Metric name to monitor (e.g., "val_loss"). + patience : int + Number of steps with no improvement after which training stops. + mode : str + One of {"min", "max"}. In "min" mode, training stops when metric + stops decreasing; in "max" mode, stops when stops increasing. + """ + + priority = HookPriority.HIGH + + def __init__( + self, + monitor: str = "val_loss", + patience: int = 10, + mode: str = "min", + ) -> None: + """Initialize early stopping hook. + + Parameters + ---------- + monitor : str + Metric name to monitor. + patience : int + Number of steps with no improvement before stopping. + mode : str + "min" or "max" - whether to minimize or maximize the metric. + """ + self.monitor = monitor + self.patience = patience + self.mode = mode + self.best_value: float | None = None + self.counter: int = 0 + self.should_stop: bool = False + + if mode == "min": + self.is_better = lambda current, best: current < best + self.best_value = float("inf") + elif mode == "max": + self.is_better = lambda current, best: current > best + self.best_value = float("-inf") + else: + raise ValueError(f"mode must be 'min' or 'max', got {mode}") + + def on_validation_end( + self, step: int, logs: Mapping[str, Any] | None = None + ) -> None: + """Check if metric has improved.""" + if logs is None or self.monitor not in logs: + return + + current_value = logs[self.monitor] + if not isinstance(current_value, (int, float)): + return + + if self.is_better(current_value, self.best_value): + self.best_value = current_value + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + self.should_stop = True + log.info( + f"Early stopping triggered at step {step}. " + f"{self.monitor} didn't improve for {self.patience} evaluations." + ) diff --git a/deepmd/pt/train/logger.py b/deepmd/pt/train/logger.py new file mode 100755 index 0000000000..0716a6ee89 --- /dev/null +++ b/deepmd/pt/train/logger.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Training logging utilities for output formatting and file I/O. + +This module provides clean interfaces for logging training progress, +managing log files, and formatting output messages. +""" + +from __future__ import ( + annotations, +) + +import logging +from pathlib import ( + Path, +) +from typing import ( + Any, + Self, +) + +from deepmd.loggers.training import ( + format_training_message, + format_training_message_per_task, +) + +log = logging.getLogger(__name__) + + +class TrainingLogger: + """Handles training log output to console and file. + + This class manages the training log file, formats output messages, + and handles both single-task and multi-task logging scenarios. + + Attributes + ---------- + log_file : Path | None + Path to the log file. + should_print_header : bool + Whether header needs to be printed. + is_multitask : bool + Whether logging for multi-task training. + model_keys : list[str] | None + Model keys for multi-task logging. + """ + + def __init__( + self, + log_file: str, + is_multitask: bool = False, + model_keys: list[str] | None = None, + rank: int = 0, + restart: bool = False, + ) -> None: + """Initialize training logger. + + Parameters + ---------- + log_file : str + Path to log file. + is_multitask : bool + Whether this is multi-task training. + model_keys : list[str] | None + Model keys for multi-task. + rank : int + Process rank (only rank 0 writes to file). + restart : bool + Whether this is a restart (append mode). + """ + self.is_multitask = is_multitask + self.model_keys = model_keys or [] + self.rank = rank + self.should_print_header = True + + # Open file only on rank 0 + if rank == 0: + self.log_path = Path(log_file) + mode = "a" if restart else "w" + self._file_handle = open(self.log_path, mode=mode, buffering=1) + else: + self.log_path = None + self._file_handle = None + + def log_step( + self, + step: int, + train_results: dict[str, Any], + valid_results: dict[str, Any] | dict[str, dict[str, Any]] | None, + lr: float, + wall_time: float | None = None, + eta: int | None = None, + task_key: str | None = None, + ) -> None: + """Log a training step. + + Parameters + ---------- + step : int + Current step number. + train_results : dict[str, Any] + Training metrics. + valid_results : dict[str, Any] | dict[str, dict[str, Any]] | None + Validation metrics. + lr : float + Current learning rate. + wall_time : float | None + Wall time for step (optional). + eta : int | None + Estimated time to completion (optional). + task_key : str | None + Current task key for multi-task. + """ + if self.rank != 0: + return + + # Print header if needed + if self.should_print_header and self._file_handle: + self._print_header(train_results, valid_results) + self.should_print_header = False + + # Log to console + self._log_console( + step, train_results, valid_results, lr, wall_time, eta, task_key + ) + + # Log to file + if self._file_handle: + self._print_to_file(step, lr, train_results, valid_results) + + def _log_console( + self, + step: int, + train_results: dict[str, Any], + valid_results: dict[str, Any] | dict[str, dict[str, Any]] | None, + lr: float, + wall_time: float | None, + eta: int | None, + task_key: str | None, + ) -> None: + """Log to console.""" + if self.is_multitask: + # Log all tasks + for key in self.model_keys: + train_res = ( + train_results.get(key, {}) + if isinstance(train_results, dict) + else {} + ) + valid_res = ( + valid_results.get(key, {}) + if isinstance(valid_results, dict) + else {} + ) + + if train_res: + log.info( + format_training_message_per_task( + batch=step, + task_name=f"{key}_trn", + rmse=train_res, + learning_rate=lr if key == task_key else None, + ) + ) + + if valid_res: + log.info( + format_training_message_per_task( + batch=step, + task_name=f"{key}_val", + rmse=valid_res, + learning_rate=None, + ) + ) + else: + if train_results: + log.info( + format_training_message_per_task( + batch=step, + task_name="trn", + rmse=train_results, + learning_rate=lr, + ) + ) + + if valid_results and isinstance(valid_results, dict): + log.info( + format_training_message_per_task( + batch=step, + task_name="val", + rmse=valid_results, + learning_rate=None, + ) + ) + + # Log timing + if wall_time is not None and eta is not None: + log.info( + format_training_message( + batch=step, + wall_time=wall_time, + eta=eta, + ) + ) + + def _print_header( + self, + train_results: dict[str, Any], + valid_results: dict[str, Any] | dict[str, dict[str, Any]] | None, + ) -> None: + """Print header to log file.""" + if not self._file_handle: + return + + header = "# {:5s}".format("step") + + if self.is_multitask: + for key in self.model_keys: + train_keys = ( + sorted(train_results.get(key, {}).keys()) + if isinstance(train_results, dict) + else [] + ) + if valid_results and key in (valid_results or {}): + for k in train_keys: + header += f" {k + f'_val_{key}':11s} {k + f'_trn_{key}':11s}" + else: + for k in train_keys: + header += f" {k + f'_trn_{key}':11s}" + else: + train_keys = sorted(train_results.keys()) + if valid_results: + for k in train_keys: + header += f" {k + '_val':11s} {k + '_trn':11s}" + else: + for k in train_keys: + header += f" {k + '_trn':11s}" + + header += " {:8s}\n".format("lr") + header += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n" + + self._file_handle.write(header) + self._file_handle.flush() + + def _print_to_file( + self, + step: int, + lr: float, + train_results: dict[str, Any], + valid_results: dict[str, Any] | dict[str, dict[str, Any]] | None, + ) -> None: + """Print formatted line to log file.""" + if not self._file_handle: + return + + line = f"{step:7d}" + + if self.is_multitask: + for key in self.model_keys: + train_res = ( + train_results.get(key, {}) + if isinstance(train_results, dict) + else {} + ) + valid_res = ( + valid_results.get(key, {}) + if isinstance(valid_results, dict) + else {} + ) + + if valid_res: + for k in sorted(train_res.keys()): + line += f" {valid_res.get(k, 0.0):11.2e} {train_res.get(k, 0.0):11.2e}" + else: + for k in sorted(train_res.keys()): + line += f" {train_res.get(k, 0.0):11.2e}" + else: + train_keys = sorted(train_results.keys()) + if valid_results and isinstance(valid_results, dict): + for k in train_keys: + line += f" {valid_results.get(k, 0.0):11.2e} {train_results.get(k, 0.0):11.2e}" + else: + for k in train_keys: + line += f" {train_results.get(k, 0.0):11.2e}" + + line += f" {lr:8.1e}\n" + self._file_handle.write(line) + self._file_handle.flush() + + def log_summary( + self, total_time: float, timed_steps: int, excluded_steps: int + ) -> None: + """Log training summary. + + Parameters + ---------- + total_time : float + Total training time. + timed_steps : int + Number of timed steps. + excluded_steps : int + Number of excluded steps. + """ + if timed_steps > 0: + avg_time = total_time / timed_steps + msg = f"Average training time: {avg_time:.4f} s/batch" + if excluded_steps > 0: + msg += f" ({excluded_steps} batches excluded)" + log.info(msg) + + def close(self) -> None: + """Close log file.""" + if self._file_handle: + self._file_handle.close() + self._file_handle = None + + def __enter__(self) -> Self: + """Context manager entry.""" + return self + + def __exit__(self, *args: object) -> None: + """Context manager exit.""" + self.close() + + +class LossAccumulator: + """Accumulates loss values over multiple steps for averaging. + + This class handles loss accumulation for both single-task and + multi-task training scenarios. + """ + + def __init__( + self, is_multitask: bool = False, model_keys: list[str] | None = None + ) -> None: + """Initialize loss accumulator. + + Parameters + ---------- + is_multitask : bool + Whether this is multi-task training. + model_keys : list[str] | None + Model keys for multi-task. + """ + self.is_multitask = is_multitask + self.model_keys = model_keys or [] + self.reset() + + def reset(self) -> None: + """Reset accumulated losses.""" + if self.is_multitask: + self.accumulated: dict[str, dict[str, float]] = { + key: {} for key in self.model_keys + } + self.step_counts: dict[str, int] = dict.fromkeys(self.model_keys, 0) + else: + self.accumulated: dict[str, float] = {} + self.step_count = 0 + + def update( + self, + more_loss: dict[str, Any], + task_key: str = "Default", + ) -> None: + """Update accumulated losses. + + Parameters + ---------- + more_loss : dict[str, Any] + Loss dictionary from step. + task_key : str + Task key for multi-task. + """ + if self.is_multitask: + self.step_counts[task_key] = self.step_counts.get(task_key, 0) + 1 + if task_key not in self.accumulated: + self.accumulated[task_key] = {} + + for key, value in more_loss.items(): + if "l2_" in key: + continue + if not isinstance(value, (int, float)): + continue + if key not in self.accumulated[task_key]: + self.accumulated[task_key][key] = 0.0 + self.accumulated[task_key][key] += float(value) + else: + self.step_count += 1 + for key, value in more_loss.items(): + if "l2_" in key: + continue + if not isinstance(value, (int, float)): + continue + if key not in self.accumulated: + self.accumulated[key] = 0.0 + self.accumulated[key] += float(value) + + def get_averaged(self, task_key: str = "Default") -> dict[str, float]: + """Get averaged losses. + + Parameters + ---------- + task_key : str + Task key for multi-task. + + Returns + ------- + dict[str, float] + Averaged loss values. + """ + if self.is_multitask: + if task_key not in self.accumulated: + return {} + count = self.step_counts.get(task_key, 1) + return {k: v / count for k, v in self.accumulated[task_key].items()} + else: + if self.step_count == 0: + return {} + return {k: v / self.step_count for k, v in self.accumulated.items()} + + def get_all_averaged(self) -> dict[str, dict[str, float]] | dict[str, float]: + """Get all averaged losses. + + Returns + ------- + dict[str, dict[str, float]] | dict[str, float] + All averaged losses. + """ + if self.is_multitask: + return {key: self.get_averaged(key) for key in self.model_keys} + return self.get_averaged() diff --git a/deepmd/pt/train/optimizer_factory.py b/deepmd/pt/train/optimizer_factory.py new file mode 100755 index 0000000000..87c8a14b4d --- /dev/null +++ b/deepmd/pt/train/optimizer_factory.py @@ -0,0 +1,525 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Optimizer and learning rate scheduler factory. + +This module provides a factory pattern for creating optimizers and +learning rate schedulers, making it easy to add new optimizer types +and customize their behavior. +""" + +from __future__ import ( + annotations, +) + +import logging +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch + +from deepmd.pt.optimizer import ( + AdaMuonOptimizer, + HybridMuonOptimizer, + LKFOptimizer, +) +from deepmd.pt.utils.env import ( + DEVICE, +) + +if TYPE_CHECKING: + from torch.optim import ( + Optimizer, + ) + from torch.optim.lr_scheduler import ( + LRScheduler, + ) + + from deepmd.pt.train.config import ( + OptimizerConfig, + ) + +log = logging.getLogger(__name__) + + +class OptimizerStrategy(ABC): + """Abstract base class for optimizer creation strategies. + + This class defines the interface for creating optimizers and their + associated learning rate schedulers. Subclasses implement specific + optimizer types. + """ + + @abstractmethod + def create_optimizer( + self, + parameters: Any, + config: OptimizerConfig, + lr_config: Any, + ) -> Optimizer | Any: + """Create an optimizer instance. + + Parameters + ---------- + parameters : Any + Model parameters to optimize. + config : OptimizerConfig + Optimizer configuration. + lr_config : Any + Learning rate configuration. + + Returns + ------- + Optimizer | Any + The created optimizer instance. + """ + pass + + @abstractmethod + def create_scheduler( + self, + optimizer: Optimizer | Any, + warmup_steps: int, + warmup_start_factor: float, + lr_schedule: Any, + start_step: int = 0, + ) -> LRScheduler | None: + """Create a learning rate scheduler. + + Parameters + ---------- + optimizer : Optimizer | Any + The optimizer to schedule. + warmup_steps : int + Number of warmup steps. + warmup_start_factor : float + Initial LR factor during warmup. + lr_schedule : Any + Learning rate schedule object. + start_step : int + Starting step for scheduler. + + Returns + ------- + LRScheduler | None + The created scheduler, or None if not applicable. + """ + pass + + @abstractmethod + def supports_scheduler(self) -> bool: + """Whether this optimizer supports LR scheduling. + + Returns + ------- + bool + True if scheduler is supported, False otherwise. + """ + pass + + +class AdamStrategy(OptimizerStrategy): + """Strategy for creating Adam optimizer.""" + + def create_optimizer( + self, + parameters: Any, + config: OptimizerConfig, + lr_config: Any, + ) -> Optimizer: + """Create Adam optimizer.""" + return torch.optim.Adam( + parameters, + lr=lr_config.start_lr, + fused=False if DEVICE.type == "cpu" else True, + ) + + def create_scheduler( + self, + optimizer: Optimizer, + warmup_steps: int, + warmup_start_factor: float, + lr_schedule: Any, + start_step: int = 0, + ) -> LRScheduler: + """Create LambdaLR scheduler with warmup.""" + + def warmup_linear(step: int) -> float: + """Compute LR multiplier with warmup.""" + current_step = step + start_step + if current_step < warmup_steps: + return warmup_start_factor + (1.0 - warmup_start_factor) * ( + current_step / warmup_steps + ) + else: + return ( + lr_schedule.value(current_step - warmup_steps) + / lr_schedule.start_lr + ) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_linear) + + def supports_scheduler(self) -> bool: + return True + + +class AdamWStrategy(OptimizerStrategy): + """Strategy for creating AdamW optimizer.""" + + def create_optimizer( + self, + parameters: Any, + config: OptimizerConfig, + lr_config: Any, + ) -> Optimizer: + """Create AdamW optimizer.""" + return torch.optim.AdamW( + parameters, + lr=lr_config.start_lr, + weight_decay=config.weight_decay, + fused=False if DEVICE.type == "cpu" else True, + ) + + def create_scheduler( + self, + optimizer: Optimizer, + warmup_steps: int, + warmup_start_factor: float, + lr_schedule: Any, + start_step: int = 0, + ) -> LRScheduler: + """Create LambdaLR scheduler with warmup.""" + + def warmup_linear(step: int) -> float: + """Compute LR multiplier with warmup.""" + current_step = step + start_step + if current_step < warmup_steps: + return warmup_start_factor + (1.0 - warmup_start_factor) * ( + current_step / warmup_steps + ) + else: + return ( + lr_schedule.value(current_step - warmup_steps) + / lr_schedule.start_lr + ) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_linear) + + def supports_scheduler(self) -> bool: + return True + + +class LKFStrategy(OptimizerStrategy): + """Strategy for creating LKF (Levenberg-Kalman Filter) optimizer.""" + + def create_optimizer( + self, + parameters: Any, + config: OptimizerConfig, + lr_config: Any, + ) -> LKFOptimizer: + """Create LKF optimizer.""" + return LKFOptimizer( + parameters, + 0.98, # Kalman lambda + 0.99870, # Kalman nu + config.kf_blocksize, + ) + + def create_scheduler( + self, + optimizer: Optimizer | Any, + warmup_steps: int, + warmup_start_factor: float, + lr_schedule: Any, + start_step: int = 0, + ) -> None: + """LKF doesn't use a scheduler.""" + return None + + def supports_scheduler(self) -> bool: + return False + + +class AdaMuonStrategy(OptimizerStrategy): + """Strategy for creating AdaMuon optimizer.""" + + def create_optimizer( + self, + parameters: Any, + config: OptimizerConfig, + lr_config: Any, + ) -> AdaMuonOptimizer: + """Create AdaMuon optimizer.""" + return AdaMuonOptimizer( + parameters, + lr=lr_config.start_lr, + momentum=config.momentum, + weight_decay=config.weight_decay, + adam_betas=(config.adam_beta1, config.adam_beta2), + lr_adjust=config.lr_adjust, + lr_adjust_coeff=config.lr_adjust_coeff, + ) + + def create_scheduler( + self, + optimizer: Optimizer | Any, + warmup_steps: int, + warmup_start_factor: float, + lr_schedule: Any, + start_step: int = 0, + ) -> LRScheduler: + """Create LambdaLR scheduler with warmup.""" + + def warmup_linear(step: int) -> float: + """Compute LR multiplier with warmup.""" + current_step = step + start_step + if current_step < warmup_steps: + return warmup_start_factor + (1.0 - warmup_start_factor) * ( + current_step / warmup_steps + ) + else: + return ( + lr_schedule.value(current_step - warmup_steps) + / lr_schedule.start_lr + ) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_linear) + + def supports_scheduler(self) -> bool: + return True + + +class HybridMuonStrategy(OptimizerStrategy): + """Strategy for creating HybridMuon optimizer.""" + + def create_optimizer( + self, + parameters: Any, + config: OptimizerConfig, + lr_config: Any, + ) -> HybridMuonOptimizer: + """Create HybridMuon optimizer.""" + return HybridMuonOptimizer( + parameters, + lr=lr_config.start_lr, + momentum=config.momentum, + weight_decay=config.weight_decay, + adam_betas=(config.adam_beta1, config.adam_beta2), + lr_adjust=config.lr_adjust, + lr_adjust_coeff=config.lr_adjust_coeff, + muon_2d_only=config.muon_2d_only, + min_2d_dim=config.min_2d_dim, + ) + + def create_scheduler( + self, + optimizer: Optimizer | Any, + warmup_steps: int, + warmup_start_factor: float, + lr_schedule: Any, + start_step: int = 0, + ) -> LRScheduler: + """Create LambdaLR scheduler with warmup.""" + + def warmup_linear(step: int) -> float: + """Compute LR multiplier with warmup.""" + current_step = step + start_step + if current_step < warmup_steps: + return warmup_start_factor + (1.0 - warmup_start_factor) * ( + current_step / warmup_steps + ) + else: + return ( + lr_schedule.value(current_step - warmup_steps) + / lr_schedule.start_lr + ) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_linear) + + def supports_scheduler(self) -> bool: + return True + + +class OptimizerFactory: + """Factory for creating optimizers and schedulers. + + This factory centralizes optimizer creation and makes it easy to + register new optimizer types. + + Example: + >>> factory = OptimizerFactory() + >>> optimizer = factory.create_optimizer( + ... parameters=model.parameters(), + ... config=OptimizerConfig(opt_type="Adam"), + ... lr_config=lr_config, + ... ) + """ + + def __init__(self) -> None: + """Initialize factory with default strategies.""" + self._strategies: dict[str, OptimizerStrategy] = { + "Adam": AdamStrategy(), + "AdamW": AdamWStrategy(), + "LKF": LKFStrategy(), + "AdaMuon": AdaMuonStrategy(), + "HybridMuon": HybridMuonStrategy(), + } + + def register(self, opt_type: str, strategy: OptimizerStrategy) -> None: + """Register a new optimizer strategy. + + Parameters + ---------- + opt_type : str + Identifier for the optimizer type. + strategy : OptimizerStrategy + Strategy instance for creating the optimizer. + """ + self._strategies[opt_type] = strategy + log.info(f"Registered optimizer strategy: {opt_type}") + + def create_optimizer( + self, + parameters: Any, + config: OptimizerConfig, + lr_config: Any, + ) -> Optimizer | Any: + """Create an optimizer. + + Parameters + ---------- + parameters : Any + Model parameters to optimize. + config : OptimizerConfig + Optimizer configuration. + lr_config : Any + Learning rate configuration. + + Returns + ------- + Optimizer | Any + The created optimizer. + + Raises + ------ + ValueError + If optimizer type is not registered. + """ + if config.opt_type not in self._strategies: + raise ValueError( + f"Unknown optimizer type: {config.opt_type}. " + f"Available: {list(self._strategies.keys())}" + ) + + strategy = self._strategies[config.opt_type] + return strategy.create_optimizer(parameters, config, lr_config) + + def create_scheduler( + self, + opt_type: str, + optimizer: Optimizer | Any, + warmup_steps: int, + warmup_start_factor: float, + lr_schedule: Any, + start_step: int = 0, + ) -> LRScheduler | None: + """Create a learning rate scheduler. + + Parameters + ---------- + opt_type : str + Type of optimizer. + optimizer : Optimizer | Any + The optimizer to schedule. + warmup_steps : int + Number of warmup steps. + warmup_start_factor : float + Initial LR factor during warmup. + lr_schedule : Any + Learning rate schedule object. + start_step : int + Starting step for scheduler. + + Returns + ------- + LRScheduler | None + The created scheduler, or None if not supported. + """ + if opt_type not in self._strategies: + return None + + strategy = self._strategies[opt_type] + if not strategy.supports_scheduler(): + return None + + return strategy.create_scheduler( + optimizer, + warmup_steps, + warmup_start_factor, + lr_schedule, + start_step, + ) + + def supports_scheduler(self, opt_type: str) -> bool: + """Check if optimizer type supports LR scheduling. + + Parameters + ---------- + opt_type : str + Type of optimizer. + + Returns + ------- + bool + True if scheduler is supported. + """ + if opt_type not in self._strategies: + return False + return self._strategies[opt_type].supports_scheduler() + + def get_available_optimizers(self) -> list[str]: + """Get list of available optimizer types. + + Returns + ------- + list[str] + List of registered optimizer type names. + """ + return list(self._strategies.keys()) + + +# Global factory instance for convenience +_default_factory = OptimizerFactory() + + +def create_optimizer( + parameters: Any, + config: OptimizerConfig, + lr_config: Any, +) -> Optimizer | Any: + """Convenience function to create optimizer using default factory.""" + return _default_factory.create_optimizer(parameters, config, lr_config) + + +def create_scheduler( + opt_type: str, + optimizer: Optimizer | Any, + warmup_steps: int, + warmup_start_factor: float, + lr_schedule: Any, + start_step: int = 0, +) -> LRScheduler | None: + """Convenience function to create scheduler using default factory.""" + return _default_factory.create_scheduler( + opt_type, + optimizer, + warmup_steps, + warmup_start_factor, + lr_schedule, + start_step, + ) diff --git a/deepmd/pt/train/trainer.py b/deepmd/pt/train/trainer.py new file mode 100755 index 0000000000..df7e45729c --- /dev/null +++ b/deepmd/pt/train/trainer.py @@ -0,0 +1,1164 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Refactored PyTorch trainer with modular components. + +This module provides a clean, extensible trainer implementation that +uses composition over monolithic design. It supports: + +- Single-task and multi-task training +- Multiple optimizer types via strategy pattern +- Hook system for extensibility +- Clean separation of concerns +- Fine-tuning support +- Multi-task parameter sharing + +Future extension points for multi-backend support: +- AbstractTrainingLoop can be extended for JAX/NumPy backends +- OptimizerFactory can support backend-specific optimizers +- DataManager can use backend-specific data loading +""" + +from __future__ import ( + annotations, +) + +import functools +import logging +import time +from copy import ( + deepcopy, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +from deepmd.pt.loss import ( + DenoiseLoss, + DOSLoss, + EnergyHessianStdLoss, + EnergySpinLoss, + EnergyStdLoss, + PropertyLoss, + TaskLoss, + TensorLoss, +) +from deepmd.pt.model.model import ( + get_model, + get_zbl_model, +) +from deepmd.pt.train.checkpoint_manager import ( + CheckpointManager, +) +from deepmd.pt.train.config import ( + TrainingConfig, +) +from deepmd.pt.train.data_manager import ( + DataManager, +) +from deepmd.pt.train.hooks import ( + HookManager, + TensorBoardHook, + TimingHook, +) +from deepmd.pt.train.logger import ( + LossAccumulator, + TrainingLogger, +) +from deepmd.pt.train.optimizer_factory import ( + OptimizerFactory, +) +from deepmd.pt.train.training_loop import ( + TrainingLoopFactory, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt.utils import ( + dp_random, +) +from deepmd.pt.utils.env import ( + DEVICE, + LOCAL_RANK, +) +from deepmd.pt.utils.learning_rate import ( + BaseLR, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.data import ( + DataRequirementItem, +) +from deepmd.utils.path import ( + DPH5Path, +) + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + ) + + from deepmd.pt.utils.dataloader import ( + DpLoaderSet, + ) + +log = logging.getLogger(__name__) + + +def model_change_out_bias( + _model: Any, + _sample_func: Callable[[], Any], + _bias_adjust_mode: str = "change-by-statistic", +) -> Any: + """Change model output bias during fine-tuning. + + Parameters + ---------- + _model : Any + Model to modify. + _sample_func : Callable[[], Any] + Function to get sample data for statistics. + _bias_adjust_mode : str + Bias adjustment mode. + + Returns + ------- + Any + Modified model. + """ + old_bias = deepcopy(_model.get_out_bias()) + _model.change_out_bias( + _sample_func, + bias_adjust_mode=_bias_adjust_mode, + ) + new_bias = deepcopy(_model.get_out_bias()) + + model_type_map = _model.get_type_map() + log.info( + f"Change output bias of {model_type_map!s} from {to_numpy_array(old_bias).reshape(-1)!s} to {to_numpy_array(new_bias).reshape(-1)!s}." + ) + return _model + + +class Trainer: + """Main trainer class orchestrating the training process. + + This is a refactored, modular trainer that delegates specific + responsibilities to focused components: + + - TrainingConfig: Configuration management + - DataManager: Data loading and iteration + - OptimizerFactory: Optimizer creation + - CheckpointManager: Model persistence + - TrainingLoop: Core training step logic + - HookManager: Extensibility hooks + - TrainingLogger: Output formatting + + Parameters + ---------- + config : dict[str, Any] + Training configuration dictionary. + training_data : DpLoaderSet | dict[str, DpLoaderSet] + Training dataset(s). + validation_data : DpLoaderSet | dict[str, DpLoaderSet] | None + Validation dataset(s). + stat_file_path : str | None + Path to statistics file. + init_model : str | None + Path to initialization model. + restart_model : str | None + Path to checkpoint for restart. + finetune_model : str | None + Path to model for fine-tuning. + init_frz_model : str | None + Path to frozen model for initialization. + force_load : bool + Whether to force load mismatched checkpoints. + shared_links : dict[str, Any] | None + Parameter sharing configuration for multi-task. + finetune_links : dict[str, Any] | None + Fine-tuning mapping configuration. + rank : int + Distributed training rank. + + Attributes + ---------- + config : TrainingConfig + Parsed training configuration. + data_manager : DataManager + Data loading manager. + checkpoint_manager : CheckpointManager + Checkpoint persistence manager. + hook_manager : HookManager + Training hooks manager. + """ + + def __init__( + self, + config: dict[str, Any], + training_data: DpLoaderSet | dict[str, DpLoaderSet], + validation_data: DpLoaderSet | dict[str, DpLoaderSet] | None = None, + stat_file_path: str | None = None, + init_model: str | None = None, + restart_model: str | None = None, + finetune_model: str | None = None, + init_frz_model: str | None = None, + force_load: bool = False, + shared_links: dict[str, Any] | None = None, + finetune_links: dict[str, Any] | None = None, + rank: int = 0, + ) -> None: + """Initialize the trainer with all components.""" + self.rank = rank + self.world_size = ( + dist.get_world_size() + if dist.is_available() and dist.is_initialized() + else 1 + ) + + # Determine resume/finetune state + self.resume_model = restart_model or init_model or finetune_model + self.is_restart = restart_model is not None + self.is_finetune = finetune_model is not None + self.finetune_update_stat = False + + # Parse configuration + model_params = config.get("model", {}) + self.model_keys = ( + list(model_params.get("model_dict", {}).keys()) + if "model_dict" in model_params + else ["Default"] + ) + self.is_multitask = len(self.model_keys) > 1 and "model_dict" in model_params + + self.config = TrainingConfig.from_dict(config, self.model_keys) + self.shared_links = shared_links + self.finetune_links = finetune_links + + # Store for later use + self._config_dict = config + self._model_params = model_params + self._stat_file_path = stat_file_path + + # Initialize components + self._init_model(model_params, config) + self._init_loss(config, model_params) + + # Compute statistics before data manager (need get_sample_func) + self._compute_statistics_before_data(training_data, stat_file_path) + + self._init_data(training_data, validation_data, config, stat_file_path) + self._init_optimizer_and_scheduler() + self._init_distributed() + self._setup_finetune(finetune_model) + self._setup_multitask_shared_params(shared_links) + self._init_checkpoint_manager() + self._init_hooks() + self._init_logger() + + # Load checkpoint if resuming + self.start_step = 0 + if self.resume_model: + self._load_resume_checkpoint(finetune_model) + + # Load frozen model if specified + if init_frz_model: + self._load_frozen_model(init_frz_model) + + # Initialize training loop + self._init_training_loop() + + # Log model info + if self.rank == 0: + self._log_model_info() + + def _init_model(self, model_params: dict[str, Any], config: dict[str, Any]) -> None: + """Initialize model(s).""" + loss_params = self._get_loss_params(config) + + if self.is_multitask: + self.model: dict[str, torch.nn.Module] | torch.nn.Module = {} + for key in self.model_keys: + model_dict = model_params["model_dict"][key] + if loss_params and loss_params.get(key): + if self._is_hessian_loss(loss_params[key]): + model_dict = deepcopy(model_dict) + model_dict["hessian_mode"] = True + self.model[key] = self._create_single_model(model_dict) + else: + if loss_params and self._is_hessian_loss(loss_params): + model_params = deepcopy(model_params) + model_params["hessian_mode"] = True + self.model = self._create_single_model(model_params) + + def _create_single_model(self, model_params: dict[str, Any]) -> torch.nn.Module: + """Create a single model instance.""" + if "use_srtab" in model_params: + return get_zbl_model(deepcopy(model_params)).to(DEVICE) + return get_model(deepcopy(model_params)).to(DEVICE) + + def _init_loss(self, config: dict[str, Any], model_params: dict[str, Any]) -> None: + """Initialize loss function(s).""" + if self.is_multitask: + self.loss: dict[str, TaskLoss] | TaskLoss = {} + for key in self.model_keys: + loss_param = config["loss_dict"][key] + lr_param = self._get_lr_for_task(config, key) + ntypes = len(model_params["model_dict"][key]["type_map"]) + self.loss[key] = self._create_loss( + loss_param, lr_param, ntypes, self.model[key] + ) + else: + self.loss = self._create_loss( + config["loss"], + config["learning_rate"]["start_lr"], + len(model_params["type_map"]), + self.model, + ) + + def _create_loss( + self, + loss_params: dict[str, Any], + start_lr: float, + ntypes: int, + model: torch.nn.Module, + ) -> TaskLoss: + """Create loss function instance.""" + loss_type = loss_params.get("type", "ener") + + if loss_type == "ener": + if loss_params.get("start_pref_h", 0.0) > 0.0: + loss_params["starter_learning_rate"] = start_lr + return EnergyHessianStdLoss(**loss_params) + else: + loss_params["starter_learning_rate"] = start_lr + return EnergyStdLoss(**loss_params) + elif loss_type == "ener_spin": + loss_params["starter_learning_rate"] = start_lr + return EnergySpinLoss(**loss_params) + elif loss_type == "denoise": + loss_params["ntypes"] = ntypes + return DenoiseLoss(**loss_params) + elif loss_type == "dos": + loss_params["starter_learning_rate"] = start_lr + loss_params["numb_dos"] = model.model_output_def()["dos"].output_size + return DOSLoss(**loss_params) + elif loss_type == "tensor": + model_output_type = model.model_output_type() + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + tensor_name = model_output_type[0] + loss_params["tensor_size"] = model.model_output_def()[ + tensor_name + ].output_size + loss_params["starter_learning_rate"] = start_lr + return TensorLoss(**loss_params) + elif loss_type == "property": + loss_params["task_dim"] = model.get_task_dim() + loss_params["var_name"] = model.get_var_name() + loss_params["intensive"] = model.get_intensive() + loss_params["starter_learning_rate"] = start_lr + return PropertyLoss(**loss_params) + else: + # Use TaskLoss.get_class_by_type for other types + loss_params["starter_learning_rate"] = start_lr + return TaskLoss.get_class_by_type(loss_type).get_loss(loss_params) + + def _compute_statistics_before_data( + self, + training_data: DpLoaderSet | dict[str, DpLoaderSet], + stat_file_path: str | None, + ) -> None: + """Compute model statistics before creating data manager.""" + # Determine finetune_has_new_type + finetune_has_new_type = False + if self.is_finetune and self.finetune_links is not None: + if self.is_multitask: + for key in self.model_keys: + if self.finetune_links[key].get_has_new_type(): + finetune_has_new_type = True + break + else: + finetune_has_new_type = self.finetune_links[ + "Default" + ].get_has_new_type() + + # Only compute stats on rank 0 and when not resuming (or finetune with new type) + # For finetune, we need sample_func for model_change_out_bias + should_compute = ( + not self.resume_model or finetune_has_new_type or self.is_finetune + ) and self.rank == 0 + + if not should_compute: + self.get_sample_func = None + return + + # Create get_sample_func for each model + if self.is_multitask: + self.get_sample_func = {} + for key in self.model_keys: + self.get_sample_func[key] = self._create_sample_func( + training_data[key], + self._config_dict["training"]["data_dict"][key]["training_data"], + ) + + # Compute statistics + finetune_has_new_type_key = ( + self.finetune_links[key].get_has_new_type() + if self.is_finetune and self.finetune_links + else False + ) + + # Get stat file path for this key + stat_path_key = None + if stat_file_path and isinstance(stat_file_path, dict): + stat_path_key = stat_file_path.get(key) + elif stat_file_path: + stat_path_key = stat_file_path + + self.model[key].compute_or_load_stat( + sampled_func=self.get_sample_func[key], + stat_file_path=stat_path_key, + ) + + if isinstance(stat_path_key, DPH5Path): + stat_path_key.root.close() + else: + self.get_sample_func = self._create_sample_func( + training_data, + self._config_dict["training"]["training_data"], + ) + + self.model.compute_or_load_stat( + sampled_func=self.get_sample_func, + stat_file_path=stat_file_path, + ) + + if isinstance(stat_file_path, DPH5Path): + stat_file_path.root.close() + + def _create_sample_func( + self, + training_data: DpLoaderSet, + training_params: dict[str, Any], + ) -> Callable[[], Any]: + """Create sample function for statistics computation.""" + data_stat_nbatch = training_params.get("data_stat_nbatch", 10) + + @functools.cache + def get_sample() -> Any: + sampled = make_stat_input( + training_data.systems, + training_data.dataloaders, + data_stat_nbatch, + ) + return sampled + + return get_sample + + def _init_data( + self, + training_data: DpLoaderSet | dict[str, DpLoaderSet], + validation_data: DpLoaderSet | dict[str, DpLoaderSet] | None, + config: dict[str, Any], + stat_file_path: str | None, + ) -> None: + """Initialize data manager and compute statistics.""" + # Add data requirements + self._setup_data_requirements(training_data, validation_data) + + # Create data manager + self.data_manager = DataManager( + training_data, + validation_data, + config.get("training", {}), + DEVICE, + ) + + # Print data summary + self.data_manager.print_summary(self.rank) + + def _setup_data_requirements( + self, + training_data: DpLoaderSet | dict[str, DpLoaderSet], + validation_data: DpLoaderSet | dict[str, DpLoaderSet] | None, + ) -> None: + """Setup data requirements for training and validation.""" + if self.is_multitask: + for key in self.model_keys: + data_req = self.loss[key].label_requirement + data_req += self._get_additional_data_requirement(self.model[key]) + training_data[key].add_data_requirement(data_req) + if validation_data and validation_data[key] is not None: + validation_data[key].add_data_requirement(data_req) + + training_data[key].preload_and_modify_all_data_torch() + if validation_data and validation_data[key] is not None: + validation_data[key].preload_and_modify_all_data_torch() + else: + data_req = self.loss.label_requirement + data_req += self._get_additional_data_requirement(self.model) + training_data.add_data_requirement(data_req) + if validation_data is not None: + validation_data.add_data_requirement(data_req) + + training_data.preload_and_modify_all_data_torch() + if validation_data is not None: + validation_data.preload_and_modify_all_data_torch() + + def _get_additional_data_requirement( + self, model: torch.nn.Module + ) -> list[DataRequirementItem]: + """Get additional data requirements from model.""" + requirements = [] + + if model.get_dim_fparam() > 0: + fparam_default = ( + model.get_default_fparam().cpu().numpy() + if model.has_default_fparam() + else 0.0 + ) + requirements.append( + DataRequirementItem( + "fparam", + model.get_dim_fparam(), + atomic=False, + must=not model.has_default_fparam(), + default=fparam_default, + ) + ) + + if model.get_dim_aparam() > 0: + requirements.append( + DataRequirementItem( + "aparam", model.get_dim_aparam(), atomic=True, must=True + ) + ) + + has_spin = getattr(model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + if has_spin: + requirements.append( + DataRequirementItem("spin", ndof=3, atomic=True, must=True) + ) + + return requirements + + def _init_optimizer_and_scheduler(self) -> None: + """Initialize optimizer and learning rate scheduler.""" + self.optimizer_factory = OptimizerFactory() + + # Create wrapper + self.wrapper = ModelWrapper( + self.model, self.loss, model_params=getattr(self, "_model_params", {}) + ) + + # Create optimizer + opt_config = self.config.get_optimizer_config() + lr_config = self.config.get_lr_config() + + self.optimizer = self.optimizer_factory.create_optimizer( + self.wrapper.parameters(), + opt_config, + lr_config, + ) + + # Create LR schedule + self.lr_schedule = BaseLR( + type="exp", + start_lr=lr_config.start_lr, + stop_lr=lr_config.stop_lr, + decay_steps=lr_config.decay_steps, + decay_rate=lr_config.decay_rate, + stop_steps=lr_config.stop_steps, + ) + + # Create scheduler if supported + if self.optimizer_factory.supports_scheduler(opt_config.opt_type): + self.scheduler = self.optimizer_factory.create_scheduler( + opt_config.opt_type, + self.optimizer, + self.config.warmup_steps, + self.config.warmup_start_factor, + self.lr_schedule, + 0, # start_step, will be updated after loading checkpoint + ) + else: + self.scheduler = None + + def _setup_finetune(self, finetune_model: str | None) -> None: + """Setup fine-tuning if applicable.""" + if finetune_model is None or self.finetune_links is None: + return + + if self.is_multitask: + for key in self.model_keys: + finetune_rule_single = self.finetune_links[key] + + if finetune_rule_single.get_has_new_type(): + self.finetune_update_stat = True + + if not finetune_rule_single.get_resuming(): + self.model[key] = self._apply_finetune_to_model( + self.model[key], + finetune_rule_single, + self.get_sample_func[key] + if isinstance(self.get_sample_func, dict) + else self.get_sample_func, + ) + else: + finetune_rule_single = self.finetune_links["Default"] + self.model = self._apply_finetune_to_model( + self.model, + finetune_rule_single, + self.get_sample_func, + ) + + def _apply_finetune_to_model( + self, + model: torch.nn.Module, + finetune_rule: Any, + sample_func: Callable[[], Any], + ) -> torch.nn.Module: + """Apply fine-tuning modifications to a model.""" + # Handle change_out_bias + if not finetune_rule.get_random_fitting(): + model = model_change_out_bias( + model, + sample_func, + _bias_adjust_mode="change-by-statistic", + ) + return model + + def _setup_multitask_shared_params( + self, shared_links: dict[str, Any] | None + ) -> None: + """Setup multi-task parameter sharing.""" + if shared_links is None or not self.is_multitask: + return + + # Get data_stat_protect values + data_stat_protect_values = [ + self._model_params["model_dict"][key].get("data_stat_protect", 1e-2) + for key in self.model_keys + ] + + # Check all values are the same + assert all( + abs(v - data_stat_protect_values[0]) < 1e-10 + for v in data_stat_protect_values + ), ( + "Model key 'data_stat_protect' must be the same in each branch when multitask!" + ) + + # Compute model probabilities + model_prob = np.zeros(len(self.model_keys), dtype=np.float32) + for ii, model_key in enumerate(self.model_keys): + # Get training data size for this model + if hasattr(self, "data_manager") and self.data_manager: + # Try to get from data_manager + pass + # Use uniform probability for now + model_prob[ii] = 1.0 + + model_prob = model_prob / np.sum(model_prob) + model_key_prob_map = dict(zip(self.model_keys, model_prob)) + + # Call share_params + self.wrapper.share_params( + shared_links, + resume=(self.is_restart and not self.finetune_update_stat) + or self.rank != 0, + model_key_prob_map=model_key_prob_map, + data_stat_protect=data_stat_protect_values[0], + ) + + def _init_distributed(self) -> None: + """Initialize distributed training.""" + if dist.is_available() and dist.is_initialized(): + torch.cuda.set_device(LOCAL_RANK) + self.wrapper = DDP( + self.wrapper, + device_ids=[LOCAL_RANK], + find_unused_parameters=True, + output_device=LOCAL_RANK, + ) + + def _init_checkpoint_manager(self) -> None: + """Initialize checkpoint manager.""" + self.checkpoint_manager = CheckpointManager( + self.config.checkpoint, + self.rank, + ) + + def _init_hooks(self) -> None: + """Initialize hook manager and default hooks.""" + self.hook_manager = HookManager() + + # Register timing hook + if self.config.display.time_training: + self.hook_manager.register(TimingHook()) + + # Register TensorBoard hook if enabled + if self.config.display.tensorboard: + self.hook_manager.register( + TensorBoardHook( + log_dir=self.config.display.tensorboard_log_dir, + log_freq=self.config.display.tensorboard_freq, + ) + ) + + def _init_logger(self) -> None: + """Initialize training logger.""" + self.logger = TrainingLogger( + log_file=self.config.display.disp_file, + is_multitask=self.is_multitask, + model_keys=self.model_keys if self.is_multitask else None, + rank=self.rank, + restart=self.is_restart, + ) + + # Initialize loss accumulator if averaging enabled + if self.config.display.disp_avg: + self.loss_accumulator = LossAccumulator( + self.is_multitask, + self.model_keys if self.is_multitask else None, + ) + else: + self.loss_accumulator = None + + def _init_training_loop(self) -> None: + """Initialize training loop based on optimizer type.""" + opt_config = self.config.get_optimizer_config() + loss = self.loss["Default"] if self.is_multitask else self.loss + + loop_factory = TrainingLoopFactory( + opt_config.opt_type, + { + "kf_start_pref_e": opt_config.kf_start_pref_e, + "kf_limit_pref_e": opt_config.kf_limit_pref_e, + "kf_start_pref_f": opt_config.kf_start_pref_f, + "kf_limit_pref_f": opt_config.kf_limit_pref_f, + }, + self.config.num_steps, + ) + + self.training_loop = loop_factory.create( + self.wrapper, + self.optimizer, + loss, + self.config.gradient_max_norm, + ) + + def _load_resume_checkpoint(self, finetune_model: str | None) -> None: + """Load checkpoint for resume or finetune.""" + checkpoint = self.checkpoint_manager.load( + self.resume_model, + self.wrapper, + self.optimizer if self.is_restart else None, + strict=not self.is_finetune, + ) + + if self.is_restart: + self.start_step = checkpoint.get("step", 0) + log.info(f"Resuming training from step {self.start_step}") + + # Update scheduler start step + if self.scheduler is not None: + # Recreate scheduler with correct start step + opt_config = self.config.get_optimizer_config() + self.scheduler = self.optimizer_factory.create_scheduler( + opt_config.opt_type, + self.optimizer, + self.config.warmup_steps, + self.config.warmup_start_factor, + self.lr_schedule, + self.start_step, + ) + else: + log.info(f"Initialized model from {self.resume_model}") + + def _load_frozen_model(self, frozen_model_path: str) -> None: + """Load frozen model for initialization.""" + log.info(f"Loading frozen model from {frozen_model_path}") + frz_model = torch.jit.load(frozen_model_path, map_location=DEVICE) + state = frz_model.state_dict() + missing, unexpected = self.model.load_state_dict(state, strict=False) + if missing or unexpected: + log.warning( + f"Non-strict load. Missing: {missing}, Unexpected: {unexpected}" + ) + + def _log_model_info(self) -> None: + """Log model parameter count.""" + if self.is_multitask: + log.warning("In multitask mode, parameters may be shared across tasks.") + for key in self.model_keys: + trainable, total = self._count_parameters(self.model[key]) + log.info( + f"Model Params [{key}]: {total / 1e6:.3f} M " + f"(Trainable: {trainable / 1e6:.3f} M)" + ) + else: + trainable, total = self._count_parameters(self.model) + log.info( + f"Model Params: {total / 1e6:.3f} M " + f"(Trainable: {trainable / 1e6:.3f} M)" + ) + + @staticmethod + def _count_parameters(model: torch.nn.Module) -> tuple[int, int]: + """Count model parameters.""" + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + total = sum(p.numel() for p in model.parameters()) + return trainable, total + + def _get_loss_params(self, config: dict[str, Any]) -> dict[str, Any] | None: + """Extract loss parameters from config.""" + if self.is_multitask: + return config.get("loss_dict") + return config.get("loss") + + def _get_lr_for_task(self, config: dict[str, Any], task_key: str) -> float: + """Get learning rate for a specific task.""" + if ( + config.get("learning_rate_dict") + and task_key in config["learning_rate_dict"] + ): + return config["learning_rate_dict"][task_key]["start_lr"] + return config["learning_rate"]["start_lr"] + + def _is_hessian_loss(self, loss_params: dict[str, Any]) -> bool: + """Check if loss uses hessian.""" + return ( + loss_params.get("type", "ener") == "ener" + and loss_params.get("start_pref_h", 0.0) > 0.0 + ) + + def run(self) -> None: + """Execute the training loop.""" + log.info(f"Starting training for {self.config.num_steps} steps") + if self.world_size > 1: + log.info(f"Rank: {self.rank}/{self.world_size}") + + # Training state + start_time = time.time() + total_train_time = 0.0 + timed_steps = 0 + last_display_step = self.start_step + + self.hook_manager.on_train_begin( + {"start_step": self.start_step, "num_steps": self.config.num_steps} + ) + + try: + for step in range(self.start_step, self.config.num_steps): + # Select task for multi-task + if self.is_multitask: + task_probs = self._compute_task_probs() + task_idx = dp_random.choice( + np.arange(len(self.model_keys), dtype=np.int_), + p=task_probs, + ) + task_key = self.model_keys[task_idx] + else: + task_key = "Default" + + # Execute training step + step_result = self._training_step(step, task_key) + + # Update loss accumulator + if self.loss_accumulator is not None: + self.loss_accumulator.update(step_result.more_loss, task_key) + + # Log and validate at display frequency + display_step = step + 1 + if ( + display_step % self.config.display.disp_freq == 0 + or display_step == 1 + ): + self._log_and_validate( + step, + display_step, + task_key, + step_result, + start_time, + total_train_time, + timed_steps, + last_display_step, + ) + + # Update timing stats + current_time = time.time() + train_time = current_time - start_time + start_time = current_time + + if display_step > self.start_step + 1: + total_train_time += train_time + timed_steps += min( + self.config.display.disp_freq, + display_step - last_display_step, + ) + last_display_step = display_step + + # Save checkpoint + if ( + display_step % self.config.checkpoint.save_freq == 0 + or display_step == self.config.num_steps + ): + self._save_checkpoint(step, step_result.lr) + + except KeyboardInterrupt: + log.info("Training interrupted by user") + finally: + self._finalize_training(total_train_time, timed_steps) + + def _training_step(self, step: int, task_key: str) -> Any: + """Execute a single training step.""" + self.hook_manager.on_step_begin(step, {"task_key": task_key}) + + # Get learning rates + lr_config = self.config.get_lr_config(task_key) + cur_lr = self.lr_schedule.value(step) + + if self.scheduler is not None: + cur_lr = self.scheduler.get_last_lr()[0] + if step < self.config.warmup_steps: + pref_lr = lr_config.start_lr + else: + pref_lr = cur_lr + else: + pref_lr = cur_lr + + # Get batch + input_dict, label_dict, log_dict = self.data_manager.get_train_batch( + task_key if self.is_multitask else None + ) + + # Execute training step via training loop + result = self.training_loop.step( + input_dict, + label_dict, + cur_lr, + pref_lr, + task_key, + ) + + # Update scheduler + if self.scheduler is not None: + self.scheduler.step() + + self.hook_manager.on_step_end( + step, + { + "loss": result.loss.item(), + "lr": result.lr, + "task_key": task_key, + **result.more_loss, + }, + ) + + return result + + def _compute_task_probs(self) -> np.ndarray: + """Compute sampling probabilities for multi-task.""" + # Check if model_prob is provided in config + if hasattr(self, "_config_dict"): + model_prob_dict = self._config_dict.get("training", {}).get( + "model_prob", {} + ) + if model_prob_dict: + probs = np.array( + [model_prob_dict.get(key, 1.0) for key in self.model_keys] + ) + return probs / probs.sum() + + # Default: uniform + probs = np.ones(len(self.model_keys), dtype=np.float32) + return probs / probs.sum() + + def _log_and_validate( + self, + step: int, + display_step: int, + task_key: str, + step_result: Any, + start_time: float, + total_train_time: float, + timed_steps: int, + last_display_step: int, + ) -> None: + """Log training progress and run validation.""" + # Set eval mode for validation + self.wrapper.eval() + + # Get training results + if self.loss_accumulator is not None: + train_results = self.loss_accumulator.get_all_averaged() + self.loss_accumulator.reset() + else: + if self.is_multitask: + train_results = {key: {} for key in self.model_keys} + train_results[task_key] = { + k: v for k, v in step_result.more_loss.items() if "l2_" not in k + } + else: + train_results = { + k: v for k, v in step_result.more_loss.items() if "l2_" not in k + } + + # Run validation + valid_results = self._run_validation(task_key) + + # Compute timing + current_time = time.time() + train_time = current_time - start_time + if timed_steps > 0: + eta = int( + (self.config.num_steps - display_step) * total_train_time / timed_steps + ) + else: + eta = 0 + + # Log + self.logger.log_step( + display_step, + train_results, + valid_results, + step_result.lr, + train_time if self.config.display.time_training else None, + eta if self.config.display.time_training else None, + task_key, + ) + + self.hook_manager.on_validation_end( + step, {"train": train_results, "valid": valid_results} + ) + + # Restore train mode + self.wrapper.train() + + def _run_validation( + self, current_task_key: str + ) -> dict[str, Any] | dict[str, dict[str, Any]] | None: + """Run validation on all tasks.""" + self.hook_manager.on_validation_begin(0, {}) + + if self.is_multitask: + results: dict[str, dict[str, Any]] = {} + for key in self.model_keys: + results[key] = self._validate_task(key) + return results + else: + return self._validate_task("Default") + + def _validate_task(self, task_key: str) -> dict[str, Any]: + """Validate a single task.""" + num_batches = self.data_manager.get_valid_numb_batch( + task_key if self.is_multitask else None + ) + + if num_batches == 0: + return {} + + results: dict[str, float] = {} + total_natoms = 0 + + for _ in range(num_batches): + input_dict, label_dict, _ = self.data_manager.get_valid_batch( + task_key if self.is_multitask else None + ) + + if not input_dict: + break + + # Note: Don't use torch.no_grad() here because the model + # needs to compute gradients for force calculations via autograd + _, loss, more_loss = self.wrapper( + **input_dict, + cur_lr=0.0, + label=label_dict, + task_key=task_key, + ) + + natoms = int(input_dict["atype"].shape[-1]) + total_natoms += natoms + + for k, v in more_loss.items(): + if "l2_" not in k and isinstance(v, (int, float)): + results[k] = results.get(k, 0.0) + v * natoms + + # Average by atom count + if total_natoms > 0: + results = {k: v / total_natoms for k, v in results.items()} + + return results + + def _save_checkpoint(self, step: int, lr: float) -> None: + """Save training checkpoint.""" + path = self.checkpoint_manager.save( + step + 1, + self.wrapper, + self.optimizer, + lr, + ) + + if path: + self.hook_manager.on_save_checkpoint(step, str(path), {"lr": lr}) + + def _finalize_training(self, total_time: float, timed_steps: int) -> None: + """Finalize training and cleanup.""" + # Save final checkpoint + self._save_checkpoint(self.config.num_steps - 1, 0.0) + + # Log summary + if timed_steps > 0: + excluded = self.config.num_steps - self.start_step - timed_steps + self.logger.log_summary(total_time, timed_steps, excluded) + + self.hook_manager.on_train_end( + {"total_time": total_time, "timed_steps": timed_steps} + ) + + self.logger.close() + + log.info( + f"Training completed. Model saved to {self.config.checkpoint.save_ckpt}" + ) + + def get_data( + self, is_train: bool = True, task_key: str = "Default" + ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Get a batch of data. + + This method is provided for backward compatibility and testing. + + Parameters + ---------- + is_train : bool + Whether to get training data (True) or validation data (False). + task_key : str + Task key for multi-task training. + + Returns + ------- + tuple[dict[str, Any], dict[str, Any], dict[str, Any]] + (input_dict, label_dict, log_dict) + """ + if is_train: + return self.data_manager.get_train_batch( + task_key if self.is_multitask else None + ) + else: + return self.data_manager.get_valid_batch( + task_key if self.is_multitask else None + ) diff --git a/deepmd/pt/train/training_loop.py b/deepmd/pt/train/training_loop.py new file mode 100755 index 0000000000..b0d449a364 --- /dev/null +++ b/deepmd/pt/train/training_loop.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Core training loop implementations for different optimizer types. + +This module provides specialized training loops for different optimizers, +making it easy to add new training strategies while keeping the main +trainer clean. +""" + +from __future__ import ( + annotations, +) + +import logging +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + Any, +) + +import torch +import torch.distributed as dist + +from deepmd.pt.loss import ( + DenoiseLoss, + EnergyStdLoss, +) +from deepmd.pt.optimizer import ( + KFOptimizerWrapper, +) + +log = logging.getLogger(__name__) + + +class TrainingStepResult: + """Container for training step results. + + Attributes + ---------- + loss : torch.Tensor + The computed loss. + model_pred : dict[str, torch.Tensor] + Model predictions. + more_loss : dict[str, Any] + Additional loss components. + lr : float + Current learning rate. + """ + + def __init__( + self, + loss: torch.Tensor, + model_pred: dict[str, torch.Tensor], + more_loss: dict[str, Any], + lr: float, + ) -> None: + self.loss = loss + self.model_pred = model_pred + self.more_loss = more_loss + self.lr = lr + + +class BaseTrainingLoop(ABC): + """Abstract base class for training loops. + + Subclasses implement specific training strategies for different + optimizer types and training modes. + """ + + def __init__( + self, + wrapper: torch.nn.Module, + optimizer: Any, + gradient_max_norm: float = 0.0, + ) -> None: + """Initialize training loop. + + Parameters + ---------- + wrapper : torch.nn.Module + Model wrapper (may be wrapped in DDP). + optimizer : Any + Optimizer instance. + gradient_max_norm : float + Maximum gradient norm for clipping (0.0 = disabled). + """ + self.wrapper = wrapper + self.optimizer = optimizer + self.gradient_max_norm = gradient_max_norm + + @abstractmethod + def step( + self, + input_dict: dict[str, torch.Tensor], + label_dict: dict[str, torch.Tensor], + cur_lr: float, + pref_lr: float, + task_key: str = "Default", + ) -> TrainingStepResult: + """Execute a single training step. + + Parameters + ---------- + input_dict : dict[str, torch.Tensor] + Input tensors. + label_dict : dict[str, torch.Tensor] + Label tensors. + cur_lr : float + Current learning rate from scheduler. + pref_lr : float + Preferred learning rate for loss computation. + task_key : str + Task key for multi-task training. + + Returns + ------- + TrainingStepResult + Results from the training step. + """ + pass + + def zero_grad(self) -> None: + """Zero optimizer gradients.""" + self.optimizer.zero_grad(set_to_none=True) + + def _get_module(self) -> torch.nn.Module: + """Get unwrapped module from DDP if needed.""" + module = self.wrapper + if dist.is_available() and dist.is_initialized(): + if hasattr(module, "module"): + module = module.module + return module + + +class AdamTrainingLoop(BaseTrainingLoop): + """Training loop for Adam/AdamW/AdaMuon/HybridMuon optimizers. + + Standard backpropagation with gradient clipping support. + """ + + def step( + self, + input_dict: dict[str, torch.Tensor], + label_dict: dict[str, torch.Tensor], + cur_lr: float, + pref_lr: float, + task_key: str = "Default", + ) -> TrainingStepResult: + """Execute training step with standard backpropagation.""" + self.zero_grad() + + # Forward pass + model_pred, loss, more_loss = self.wrapper( + **input_dict, + cur_lr=pref_lr, + label=label_dict, + task_key=task_key, + ) + + # Backward pass + loss.backward() + + # Gradient clipping + if self.gradient_max_norm > 0.0: + torch.nn.utils.clip_grad_norm_( + self.wrapper.parameters(), + self.gradient_max_norm, + error_if_nonfinite=True, + ) + + # Optimizer step + with torch.device("cpu"): + self.optimizer.step() + + return TrainingStepResult( + loss=loss, + model_pred=model_pred, + more_loss=more_loss, + lr=cur_lr, + ) + + +class LKFEnergyTrainingLoop(BaseTrainingLoop): + """Training loop for LKF optimizer with energy/force loss. + + Uses Kalman Filter optimizer for energy and force updates. + """ + + def __init__( + self, + wrapper: torch.nn.Module, + optimizer: Any, + opt_param: dict[str, Any], + num_steps: int, + gradient_max_norm: float = 0.0, + ) -> None: + """Initialize LKF training loop. + + Parameters + ---------- + wrapper : torch.nn.Module + Model wrapper. + optimizer : Any + LKF optimizer. + opt_param : dict[str, Any] + Optimizer parameters including kf_start_pref_e, etc. + num_steps : int + Total training steps for prefactor scheduling. + gradient_max_norm : float + Maximum gradient norm (not used for LKF). + """ + super().__init__(wrapper, optimizer, gradient_max_norm) + self.opt_param = opt_param + self.num_steps = num_steps + + # Create KF wrapper + self.kf_wrapper = KFOptimizerWrapper( + wrapper, + optimizer, + 24, # kp + 6, # kq + dist.is_available() and dist.is_initialized(), + ) + + def _compute_prefactors(self, step: int) -> tuple[float, float]: + """Compute energy and force prefactors for current step.""" + start_pref_e = self.opt_param["kf_start_pref_e"] + limit_pref_e = self.opt_param["kf_limit_pref_e"] + start_pref_f = self.opt_param["kf_start_pref_f"] + limit_pref_f = self.opt_param["kf_limit_pref_f"] + + ratio = step / self.num_steps + + pref_e = start_pref_e * (limit_pref_e / start_pref_e) ** ratio + pref_f = start_pref_f * (limit_pref_f / start_pref_f) ** ratio + + return pref_e, pref_f + + def step( + self, + input_dict: dict[str, torch.Tensor], + label_dict: dict[str, torch.Tensor], + cur_lr: float, + pref_lr: float, + task_key: str = "Default", + ) -> TrainingStepResult: + """Execute LKF training step.""" + # Compute prefactors + step = self.optimizer.state.get("step", 0) + pref_e, pref_f = self._compute_prefactors(step) + + # Update energy + _ = self.kf_wrapper.update_energy(input_dict, label_dict["energy"], pref_e) + + # Update force + p_energy, p_force = self.kf_wrapper.update_force( + input_dict, label_dict["force"], pref_f + ) + + model_pred = {"energy": p_energy, "force": p_force} + + # Compute loss using wrapper's loss function + module = self._get_module() + + def fake_model() -> dict[str, torch.Tensor]: + return model_pred + + natoms = int(input_dict["atype"].shape[-1]) + + _, loss, more_loss = module.loss[task_key]( + {}, + fake_model, + label_dict, + natoms, + learning_rate=pref_lr, + ) + + return TrainingStepResult( + loss=loss, + model_pred=model_pred, + more_loss=more_loss, + lr=cur_lr, + ) + + +class LKFDenoiseTrainingLoop(BaseTrainingLoop): + """Training loop for LKF optimizer with denoising loss.""" + + def __init__( + self, + wrapper: torch.nn.Module, + optimizer: Any, + loss: DenoiseLoss, + gradient_max_norm: float = 0.0, + ) -> None: + """Initialize LKF denoise training loop.""" + super().__init__(wrapper, optimizer, gradient_max_norm) + + self.kf_wrapper = KFOptimizerWrapper( + wrapper, + optimizer, + 24, # kp + 6, # kq + dist.is_available() and dist.is_initialized(), + ) + self.loss_module = loss + + def step( + self, + input_dict: dict[str, torch.Tensor], + label_dict: dict[str, torch.Tensor], + cur_lr: float, + pref_lr: float, + task_key: str = "Default", + ) -> TrainingStepResult: + """Execute LKF denoise training step.""" + module = self._get_module() + loss_fn = module.loss[task_key] + + # Update coordinates via KF + model_pred = self.kf_wrapper.update_denoise_coord( + input_dict, + label_dict["clean_coord"], + 1, # prefactor + loss_fn.mask_loss_coord, + label_dict.get("coord_mask"), + ) + + # Compute loss + loss, more_loss = loss_fn( + model_pred, + label_dict, + input_dict["natoms"], + learning_rate=pref_lr, + ) + + return TrainingStepResult( + loss=loss, + model_pred=model_pred, + more_loss=more_loss, + lr=cur_lr, + ) + + +class TrainingLoopFactory: + """Factory for creating appropriate training loops. + + Selects the correct training loop implementation based on + optimizer type and loss function. + """ + + def __init__( + self, + opt_type: str, + opt_param: dict[str, Any], + num_steps: int, + ) -> None: + """Initialize factory. + + Parameters + ---------- + opt_type : str + Type of optimizer. + opt_param : dict[str, Any] + Optimizer parameters. + num_steps : int + Total training steps. + """ + self.opt_type = opt_type + self.opt_param = opt_param + self.num_steps = num_steps + + def create( + self, + wrapper: torch.nn.Module, + optimizer: Any, + loss: Any, + gradient_max_norm: float = 0.0, + ) -> BaseTrainingLoop: + """Create training loop instance. + + Parameters + ---------- + wrapper : torch.nn.Module + Model wrapper. + optimizer : Any + Optimizer instance. + loss : Any + Loss function/module. + gradient_max_norm : float + Maximum gradient norm. + + Returns + ------- + BaseTrainingLoop + Appropriate training loop for the configuration. + + Raises + ------ + ValueError + If optimizer type is not supported. + """ + if self.opt_type in ["Adam", "AdamW", "AdaMuon", "HybridMuon"]: + return AdamTrainingLoop( + wrapper, + optimizer, + gradient_max_norm, + ) + + elif self.opt_type == "LKF": + if isinstance(loss, EnergyStdLoss): + return LKFEnergyTrainingLoop( + wrapper, + optimizer, + self.opt_param, + self.num_steps, + gradient_max_norm, + ) + elif isinstance(loss, DenoiseLoss): + return LKFDenoiseTrainingLoop( + wrapper, + optimizer, + loss, + gradient_max_norm, + ) + else: + raise ValueError( + f"LKF optimizer not supported for loss type: {type(loss)}" + ) + + else: + raise ValueError(f"Unsupported optimizer type: {self.opt_type}") diff --git a/source/tests/pt/test_new_training.py b/source/tests/pt/test_new_training.py new file mode 100755 index 0000000000..988c5b85b9 --- /dev/null +++ b/source/tests/pt/test_new_training.py @@ -0,0 +1,447 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for the refactored training system. + +Includes end-to-end CLI tests to verify dp --pt train works correctly. +""" + +import copy +import unittest + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.train.config import ( + CheckpointConfig, + DisplayConfig, + LearningRateConfig, + OptimizerConfig, + TrainingConfig, +) +from deepmd.pt.train.hooks import ( + HookManager, + HookPriority, + TrainingHook, +) +from deepmd.pt.train.logger import ( + LossAccumulator, +) +from deepmd.pt.train.optimizer_factory import ( + OptimizerFactory, +) + + +class TestOptimizerConfig(unittest.TestCase): + """Test OptimizerConfig dataclass.""" + + def test_default_values(self): + """Test default configuration values.""" + config = OptimizerConfig() + self.assertEqual(config.opt_type, "Adam") + self.assertEqual(config.weight_decay, 0.001) + self.assertEqual(config.momentum, 0.95) + + def test_from_dict(self): + """Test creating from dictionary.""" + params = { + "opt_type": "AdamW", + "weight_decay": 0.01, + "kf_blocksize": 1024, + } + config = OptimizerConfig.from_dict(params) + self.assertEqual(config.opt_type, "AdamW") + self.assertEqual(config.weight_decay, 0.01) + self.assertEqual(config.kf_blocksize, 1024) + # Check defaults are preserved + self.assertEqual(config.momentum, 0.95) + + +class TestLearningRateConfig(unittest.TestCase): + """Test LearningRateConfig dataclass.""" + + def test_default_values(self): + """Test default configuration values.""" + config = LearningRateConfig() + self.assertEqual(config.start_lr, 1e-3) + self.assertEqual(config.stop_lr, 1e-8) + + def test_from_dict(self): + """Test creating from dictionary.""" + params = {"start_lr": 0.001, "decay_steps": 5000} + config = LearningRateConfig.from_dict(params) + self.assertEqual(config.start_lr, 0.001) + self.assertEqual(config.decay_steps, 5000) + + +class TestTrainingConfig(unittest.TestCase): + """Test TrainingConfig dataclass.""" + + def test_single_task_config(self): + """Test single-task configuration parsing.""" + config_dict = { + "training": { + "numb_steps": 1000, + "warmup_steps": 100, + "disp_freq": 100, + }, + "learning_rate": {"start_lr": 0.001}, + } + config = TrainingConfig.from_dict(config_dict) + self.assertEqual(config.num_steps, 1000) + self.assertEqual(config.warmup_steps, 100) + self.assertFalse(config.is_multitask) + + def test_multitask_config(self): + """Test multi-task configuration parsing.""" + config_dict = { + "training": { + "numb_steps": 1000, + "optim_dict": { + "task1": {"opt_type": "Adam"}, + "task2": {"opt_type": "AdamW"}, + }, + }, + "learning_rate": {"start_lr": 0.001}, + } + model_keys = ["task1", "task2"] + config = TrainingConfig.from_dict(config_dict, model_keys) + self.assertTrue(config.is_multitask) + self.assertIn("task1", config.optimizer_dict) + self.assertIn("task2", config.optimizer_dict) + + def test_warmup_ratio(self): + """Test warmup ratio computation.""" + config_dict = { + "training": { + "numb_steps": 1000, + "warmup_ratio": 0.1, + }, + "learning_rate": {"start_lr": 0.001}, + } + config = TrainingConfig.from_dict(config_dict) + self.assertEqual(config.warmup_steps, 100) + + def test_invalid_num_steps(self): + """Test validation of invalid num_steps.""" + config_dict = { + "training": {"numb_steps": 0}, + "learning_rate": {"start_lr": 0.001}, + } + with self.assertRaises(ValueError): + TrainingConfig.from_dict(config_dict) + + +class TestOptimizerFactory(unittest.TestCase): + """Test OptimizerFactory.""" + + def setUp(self): + """Set up test fixtures.""" + self.factory = OptimizerFactory() + + def test_available_optimizers(self): + """Test getting available optimizer types.""" + optimizers = self.factory.get_available_optimizers() + self.assertIn("Adam", optimizers) + self.assertIn("AdamW", optimizers) + self.assertIn("LKF", optimizers) + + def test_scheduler_support(self): + """Test checking scheduler support.""" + self.assertTrue(self.factory.supports_scheduler("Adam")) + self.assertTrue(self.factory.supports_scheduler("AdamW")) + self.assertFalse(self.factory.supports_scheduler("LKF")) + + +class TestHooks(unittest.TestCase): + """Test hook system.""" + + def test_hook_priority(self): + """Test hook priority ordering.""" + manager = HookManager() + + # Create hooks with different priorities + class LowPriorityHook(TrainingHook): + priority = HookPriority.LOW + + class HighPriorityHook(TrainingHook): + priority = HookPriority.HIGH + + class NormalPriorityHook(TrainingHook): + priority = HookPriority.NORMAL + + low = LowPriorityHook() + high = HighPriorityHook() + normal = NormalPriorityHook() + + manager.register(low) + manager.register(high) + manager.register(normal) + + # Check order: high, normal, low + self.assertEqual(manager.hooks[0], high) + self.assertEqual(manager.hooks[1], normal) + self.assertEqual(manager.hooks[2], low) + + def test_hook_execution(self): + """Test hook method execution.""" + manager = HookManager() + + class TestHook(TrainingHook): + def __init__(self): + self.step_count = 0 + + def on_step_end(self, step, logs): + self.step_count += 1 + + hook = TestHook() + manager.register(hook) + + manager.on_step_end(0, {"loss": 1.0}) + manager.on_step_end(1, {"loss": 0.5}) + + self.assertEqual(hook.step_count, 2) + + def test_hook_error_handling(self): + """Test that hook errors don't crash training.""" + manager = HookManager() + + class BadHook(TrainingHook): + def on_step_end(self, step, logs): + raise RuntimeError("Hook error") + + class GoodHook(TrainingHook): + def __init__(self): + self.called = False + + def on_step_end(self, step, logs): + self.called = True + + bad = BadHook() + good = GoodHook() + + manager.register(bad) + manager.register(good) + + # Should not raise + manager.on_step_end(0, {"loss": 1.0}) + + # Good hook should still be called + self.assertTrue(good.called) + + +class TestLossAccumulator(unittest.TestCase): + """Test LossAccumulator.""" + + def test_single_task_accumulation(self): + """Test single-task loss accumulation.""" + accumulator = LossAccumulator(is_multitask=False) + + accumulator.update({"loss": 1.0, "rmse": 0.5}, "Default") + accumulator.update({"loss": 2.0, "rmse": 1.0}, "Default") + + averaged = accumulator.get_averaged() + self.assertAlmostEqual(averaged["loss"], 1.5) + self.assertAlmostEqual(averaged["rmse"], 0.75) + + def test_multitask_accumulation(self): + """Test multi-task loss accumulation.""" + accumulator = LossAccumulator(is_multitask=True, model_keys=["task1", "task2"]) + + accumulator.update({"loss": 1.0}, "task1") + accumulator.update({"loss": 2.0}, "task1") + accumulator.update({"loss": 3.0}, "task2") + + avg1 = accumulator.get_averaged("task1") + avg2 = accumulator.get_averaged("task2") + + self.assertAlmostEqual(avg1["loss"], 1.5) + self.assertAlmostEqual(avg2["loss"], 3.0) + + def test_reset(self): + """Test accumulator reset.""" + accumulator = LossAccumulator(is_multitask=False) + + accumulator.update({"loss": 1.0}, "Default") + accumulator.reset() + + averaged = accumulator.get_averaged() + self.assertEqual(averaged, {}) + + def test_skip_l2_keys(self): + """Test that l2_ keys are skipped.""" + accumulator = LossAccumulator(is_multitask=False) + + accumulator.update({"loss": 1.0, "l2_loss": 100.0}, "Default") + + averaged = accumulator.get_averaged() + self.assertIn("loss", averaged) + self.assertNotIn("l2_loss", averaged) + + +class TestDisplayConfig(unittest.TestCase): + """Test DisplayConfig.""" + + def test_default_values(self): + """Test default display configuration.""" + config = DisplayConfig() + self.assertEqual(config.disp_file, "lcurve.out") + self.assertEqual(config.disp_freq, 1000) + self.assertTrue(config.disp_training) + + def test_tensorboard_defaults(self): + """Test TensorBoard configuration defaults.""" + config = DisplayConfig() + self.assertFalse(config.tensorboard) + self.assertEqual(config.tensorboard_log_dir, "log") + + +class TestCheckpointConfig(unittest.TestCase): + """Test CheckpointConfig.""" + + def test_default_values(self): + """Test default checkpoint configuration.""" + config = CheckpointConfig() + self.assertEqual(config.save_ckpt, "model.ckpt") + self.assertEqual(config.save_freq, 1000) + self.assertEqual(config.max_ckpt_keep, 5) + + +class TestEndToEndCLI(unittest.TestCase): + """End-to-end tests for CLI integration. + + These tests verify that dp --pt train works correctly with the new Trainer. + """ + + def setUp(self): + """Set up test fixtures with water data config.""" + self.config = { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [46, 92], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + }, + "fitting_net": { + "neuron": [240, 240, 240], + "resnet_dt": True, + "seed": 1, + }, + }, + "learning_rate": { + "type": "exp", + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "decay_steps": 5000, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0.0, + "limit_pref_v": 0.0, + }, + "training": { + "training_data": { + "systems": ["source/tests/pt/water/data/data_0"], + "batch_size": 1, + "numb_btch": 1, + }, + "validation_data": { + "systems": ["source/tests/pt/water/data/data_1"], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": 2, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 1, + }, + } + + def test_get_trainer_creates_new_trainer(self): + """Test get_trainer creates new modular Trainer.""" + config = copy.deepcopy(self.config) + trainer = get_trainer(config) + + # Verify it's the new Trainer + from deepmd.pt.train.trainer import Trainer as NewTrainer + + self.assertIsInstance(trainer, NewTrainer) + + def test_trainer_get_data(self): + """Test Trainer.get_data method works correctly.""" + config = copy.deepcopy(self.config) + trainer = get_trainer(config) + + # Get training data + input_dict, label_dict, log_dict = trainer.get_data(is_train=True) + self.assertIn("coord", input_dict) + self.assertIn("atype", input_dict) + + # Get validation data + input_dict, label_dict, log_dict = trainer.get_data(is_train=False) + self.assertIn("coord", input_dict) + self.assertIn("atype", input_dict) + + def test_trainer_model_accessible(self): + """Test that trainer.model is accessible.""" + config = copy.deepcopy(self.config) + trainer = get_trainer(config) + + # Model should be accessible + self.assertIsNotNone(trainer.model) + + # Should have expected attributes + self.assertTrue(hasattr(trainer.model, "get_descriptor")) + self.assertTrue(hasattr(trainer.model, "get_fitting_net")) + + def test_trainer_config_components(self): + """Test that config components are properly initialized.""" + config = copy.deepcopy(self.config) + trainer = get_trainer(config) + + # Check config components + self.assertIsNotNone(trainer.config.optimizer) + self.assertIsNotNone(trainer.config.learning_rate) + self.assertIsNotNone(trainer.config.display) + self.assertIsNotNone(trainer.config.checkpoint) + + # Check optimizer type + self.assertEqual(trainer.config.optimizer.opt_type, "Adam") + + def test_trainer_components_initialized(self): + """Test that all trainer components are initialized.""" + config = copy.deepcopy(self.config) + trainer = get_trainer(config) + + # Check components + self.assertIsNotNone(trainer.data_manager) + self.assertIsNotNone(trainer.checkpoint_manager) + self.assertIsNotNone(trainer.hook_manager) + self.assertIsNotNone(trainer.logger) + self.assertIsNotNone(trainer.optimizer) + self.assertIsNotNone(trainer.lr_schedule) + + def test_trainer_wrapper_initialized(self): + """Test that model wrapper is properly initialized.""" + config = copy.deepcopy(self.config) + trainer = get_trainer(config) + + # Wrapper should exist + self.assertIsNotNone(trainer.wrapper) + + # Wrapper should have model and loss + self.assertTrue(hasattr(trainer.wrapper, "model")) + self.assertTrue(hasattr(trainer.wrapper, "loss")) + + +if __name__ == "__main__": + unittest.main()