43 changes: 30 additions & 13 deletions deepmd/pt/entrypoints/main.py
@@ -54,6 +54,7 @@
 from deepmd.pt.train import (
     training,
 )
+from deepmd.pt.train.trainer import Trainer as NewTrainer
 from deepmd.pt.train.wrapper import (
     ModelWrapper,
 )
@@ -106,6 +107,7 @@ def get_trainer(
     init_frz_model: str | None = None,
     shared_links: dict[str, Any] | None = None,
     finetune_links: dict[str, Any] | None = None,
+    use_legacy: bool = False,
 ) -> training.Trainer:
Copilot AI Feb 11, 2026

The return type annotation of get_trainer() is training.Trainer, but the function can now return the new modular deepmd.pt.train.trainer.Trainer when use_legacy=False. Update the annotation to a union/protocol (or a common base type) to keep typing accurate for downstream users.

Suggested change:
-) -> training.Trainer:
+) -> training.Trainer | NewTrainer:
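If a structural type is preferred over the union, a protocol along the following lines would also satisfy the comment. This is only a sketch: the TrainerLike name and its single run() method are illustrative assumptions, not part of this PR (run() is grounded in the module docstring's trainer.run() example below).

    from typing import Protocol


    class TrainerLike(Protocol):
        """Minimal interface assumed to be shared by the legacy and new trainers (hypothetical)."""

        def run(self) -> None:
            """Train the model to completion."""
            ...

get_trainer() could then be annotated as -> TrainerLike, decoupling its return type from the two concrete trainer classes.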

     multi_task = "model_dict" in config.get("model", {})

@@ -200,19 +202,34 @@ def prepare_trainer_input_single(
             seed=data_seed,
         )
 
-    trainer = training.Trainer(
-        config,
-        train_data,
-        stat_file_path=stat_file_path,
-        validation_data=validation_data,
-        init_model=init_model,
-        restart_model=restart_model,
-        finetune_model=finetune_model,
-        force_load=force_load,
-        shared_links=shared_links,
-        finetune_links=finetune_links,
-        init_frz_model=init_frz_model,
-    )
+    if use_legacy:
+        trainer = training.Trainer(
+            config,
+            train_data,
+            stat_file_path=stat_file_path,
+            validation_data=validation_data,
+            init_model=init_model,
+            restart_model=restart_model,
+            finetune_model=finetune_model,
+            force_load=force_load,
+            shared_links=shared_links,
+            finetune_links=finetune_links,
+            init_frz_model=init_frz_model,
+        )
+    else:
+        trainer = NewTrainer(
+            config,
+            train_data,
+            stat_file_path=stat_file_path,
+            validation_data=validation_data,
+            init_model=init_model,
+            restart_model=restart_model,
+            finetune_model=finetune_model,
+            force_load=force_load,
+            shared_links=shared_links,
+            finetune_links=finetune_links,
+            init_frz_model=init_frz_model,
+        )
Comment on lines +205 to +232
Contributor

medium

The trainer-construction code is duplicated across the use_legacy and else branches. Selecting the class first and instantiating it once removes the duplication and improves maintainability:

    trainer_class = training.Trainer if use_legacy else NewTrainer
    trainer = trainer_class(
        config,
        train_data,
        stat_file_path=stat_file_path,
        validation_data=validation_data,
        init_model=init_model,
        restart_model=restart_model,
        finetune_model=finetune_model,
        force_load=force_load,
        shared_links=shared_links,
        finetune_links=finetune_links,
        init_frz_model=init_frz_model,
    )
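Both branches pass the identical keyword arguments, so dispatching on the class object as sketched above removes the duplication without changing behaviour, provided the two constructors keep accepting the same signature.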

     return trainer


101 changes: 101 additions & 0 deletions deepmd/pt/train/__init__.py
100644 → 100755
@@ -1 +1,102 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""PyTorch training module with modular, extensible design.

This module provides a clean, component-based training system:

- TrainingConfig: Configuration management with validation
- DataManager: Data loading and batch iteration
- OptimizerFactory: Strategy pattern for optimizer creation
- CheckpointManager: Model persistence and recovery
- TrainingLoop: Specialized training step implementations
- HookManager: Extensible callback system
- TrainingLogger: Formatted output and file I/O
- Trainer: Main orchestrator coordinating all components

Example:
    >>> from deepmd.pt.train import Trainer, TrainingConfig
    >>>
    >>> # Create trainer
    >>> trainer = Trainer(
    ...     config=config_dict,
    ...     training_data=train_dataset,
    ...     validation_data=valid_dataset,
    ... )
    >>>
    >>> # Run training
    >>> trainer.run()

Future extensions for multi-backend support:
- AbstractTrainingLoop can be extended for JAX/NumPy
- OptimizerFactory can support backend-specific optimizers
- DataManager can use backend-specific data loading
"""

from deepmd.pt.train.checkpoint_manager import (
    CheckpointManager,
)
from deepmd.pt.train.config import (
    CheckpointConfig,
    DisplayConfig,
    LearningRateConfig,
    OptimizerConfig,
    TrainingConfig,
)
from deepmd.pt.train.data_manager import (
    DataManager,
)
from deepmd.pt.train.hooks import (
    HookManager,
    HookPriority,
    TensorBoardHook,
    TimingHook,
    TrainingHook,
)
from deepmd.pt.train.logger import (
    LossAccumulator,
    TrainingLogger,
)
from deepmd.pt.train.optimizer_factory import (
    OptimizerFactory,
)
from deepmd.pt.train.trainer import (
    Trainer,
)

# Keep old Trainer available for backward compatibility during transition
from deepmd.pt.train.training import Trainer as LegacyTrainer
from deepmd.pt.train.training_loop import (
    AdamTrainingLoop,
    BaseTrainingLoop,
    LKFEnergyTrainingLoop,
    TrainingLoopFactory,
)
from deepmd.pt.train.wrapper import (
    ModelWrapper,
)

__all__ = [
    # New modular components
    "AdamTrainingLoop",
    "BaseTrainingLoop",
    "CheckpointConfig",
    "CheckpointManager",
    "DataManager",
    "DisplayConfig",
    "HookManager",
    "HookPriority",
    "LKFEnergyTrainingLoop",
    "LearningRateConfig",
    # Legacy support
    "LegacyTrainer",
    "LossAccumulator",
    "ModelWrapper",
    "OptimizerConfig",
    "OptimizerFactory",
    "TensorBoardHook",
    "TimingHook",
    "Trainer",
    "TrainingConfig",
    "TrainingHook",
    "TrainingLogger",
    "TrainingLoopFactory",
]
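For reference, a minimal sketch of how the use_legacy switch added to get_trainer() above would be exercised. It assumes input.json is a standard DeePMD-kit training input that get_trainer can consume directly (the real train entry point normalizes the dict first), and that both trainers expose run() as the module docstring shows.

    import json

    from deepmd.pt.entrypoints.main import get_trainer

    # Assumption: a valid DeePMD-kit training input file on disk.
    with open("input.json") as f:
        config = json.load(f)

    # Default path: the new modular Trainer (use_legacy defaults to False).
    trainer = get_trainer(config)
    trainer.run()

    # Opt back into the legacy monolithic Trainer during the transition.
    legacy_trainer = get_trainer(config, use_legacy=True)
    legacy_trainer.run()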