Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
8356629
feat(pt_expt): add dos, dipole, polar and property fittings
Feb 22, 2026
292fa72
add make_fx, mv itertools to parameterized
Feb 22, 2026
7289167
feat(pt_expt): full models dipole, polar, dos, property and dp-zbl
Feb 22, 2026
553b91d
rm _forward_lower
Feb 22, 2026
0753cd7
rm register_dpmodel_mapping from fitting
Feb 22, 2026
6d6adfe
remove the atomic model in pt_expt. mv atomic model's output stat tes…
Feb 22, 2026
bf448ac
add translated_output_def
Feb 22, 2026
9572a04
base model registration
Feb 22, 2026
0dcd03b
implement compute_or_load_stat
Feb 23, 2026
28fbd08
fix bug in test_ener
Feb 23, 2026
237e4a8
refact make_model, concrete models from different backends inherit fr…
Feb 23, 2026
2a958ec
Add compute_or_load_stat consistency tests and fix dpmodel backend bugs
Feb 23, 2026
41af959
rm tmp test files
Feb 24, 2026
b2028a8
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
5a4a5d2
remove concrete methods and data from BaseModel
Feb 24, 2026
cab1b35
Merge branch 'master' into feat-other-full-model
Feb 24, 2026
19f9058
rm model_type
Feb 24, 2026
26b0a40
fix spin model
Feb 24, 2026
356a1e6
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
eecd82b
add get_observed_type_list to abstract API and implement in dpmodel
Feb 24, 2026
aba2d71
fix: dpmodel change_type_map drops model_with_new_type_stat and uses …
Feb 24, 2026
21dc4e7
consolidate get_out_bias/set_out_bias into base_atomic_model
Feb 24, 2026
61722b9
change fitting -> fitting_net
Feb 24, 2026
c41515a
fix: dpmodel change_out_bias missing compute_fitting_input_stat for s…
Feb 24, 2026
124eedd
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
3827a9c
fix bug
Feb 24, 2026
9e926bf
fix bug
Feb 24, 2026
f1dbd4f
add missing get_observed_type_list to paddel
Feb 24, 2026
df132d4
add tests for get_model_def_script get_min_nbor_dist and set_case_embd
Feb 24, 2026
0c169cb
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
1da8708
fix hlo
Feb 24, 2026
b4d43f0
add dipole model api tests. mv get_observed_type_list to base
Feb 24, 2026
6ac0cef
fix frozen model
Feb 24, 2026
4b54857
add polar model api tests.
Feb 24, 2026
c35ee54
add property model api tests, fix bugs
Feb 24, 2026
11c0201
add dos test, fix bug
Feb 24, 2026
eb32961
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
fbfd042
add ut for dp-zbl model
Feb 25, 2026
b49a10f
add test_get_model_def_script test_get_min_nbor_dist test_set_case_em…
Feb 25, 2026
d42c8d8
chore(pt): mv the input stat update to model_change_out_bias to keep …
Feb 25, 2026
b7af468
chore(pd): update in the same way as pt
Feb 25, 2026
0ec5748
update test for change out bias
Feb 25, 2026
00f83cc
test the stat is changed
Feb 25, 2026
ff4a27c
rm unused methods
Feb 25, 2026
15f2af8
use deep copy
Feb 25, 2026
35d4cbe
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 25, 2026
4540382
Extracted compare_variables_recursive to source/tests/consistent/mode…
Feb 26, 2026
aa2643e
Merge branch 'master' into feat-other-full-model
Feb 26, 2026
9427d25
fix: remove dead code and redundant assignments in dpmodel atomic models
Feb 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
import math
from collections.abc import (
Callable,
Expand Down Expand Up @@ -52,13 +53,15 @@ def __init__(
pair_exclude_types: list[tuple[int, int]] = [],
rcond: float | None = None,
preset_out_bias: dict[str, Array] | None = None,
data_stat_protect: float = 1e-2,
) -> None:
super().__init__()
self.type_map = type_map
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)
self.rcond = rcond
self.preset_out_bias = preset_out_bias
self.data_stat_protect = data_stat_protect

def init_out_stat(self) -> None:
"""Initialize the output bias."""
Expand All @@ -77,6 +80,14 @@ def init_out_stat(self) -> None:
self.out_bias = out_bias_data
self.out_std = out_std_data

def get_out_bias(self) -> Array:
    """Get the output bias.

    Returns
    -------
    Array
        The stored output bias array (``self.out_bias``), returned as-is.
    """
    return self.out_bias

def set_out_bias(self, out_bias: Array) -> None:
    """Set the output bias.

    Parameters
    ----------
    out_bias : Array
        The new output bias; stored directly as ``self.out_bias``
        (no shape validation is performed here).
    """
    self.out_bias = out_bias

def __setitem__(self, key: str, value: Array) -> None:
if key in ["out_bias"]:
self.out_bias = value
Expand Down Expand Up @@ -287,6 +298,57 @@ def compute_or_load_out_stat(
bias_adjust_mode="set-by-statistic",
)

def _make_wrapped_sampler(
self,
sampled_func: Callable[[], list[dict]],
) -> Callable[[], list[dict]]:
"""Wrap the sampled function with exclusion types and default fparam.

The returned callable is cached so that the sampling (which may be
expensive) is performed at most once.

Parameters
----------
sampled_func
The lazy sampled function to get data frames from different data
systems.

Returns
-------
Callable[[], list[dict]]
A cached wrapper around *sampled_func* that additionally sets
``pair_exclude_types``, ``atom_exclude_types`` and default
``fparam`` on every sample dict when applicable.
"""

@functools.lru_cache
def wrapped_sampler() -> list[dict]:
sampled = sampled_func()
if self.pair_excl is not None:
pair_exclude_types = self.pair_excl.get_exclude_types()
for sample in sampled:
sample["pair_exclude_types"] = list(pair_exclude_types)
if self.atom_excl is not None:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
if (
"find_fparam" not in sampled[0]
and "fparam" not in sampled[0]
and self.has_default_fparam()
):
default_fparam = self.get_default_fparam()
if default_fparam is not None:
default_fparam_np = np.array(default_fparam)
for sample in sampled:
nframe = sample["atype"].shape[0]
sample["fparam"] = np.tile(
default_fparam_np.reshape(1, -1), (nframe, 1)
)
return sampled

return wrapped_sampler

def change_out_bias(
self,
sample_merged: Callable[[], list[dict]] | list[dict],
Expand Down
88 changes: 74 additions & 14 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Callable,
)
from typing import (
Any,
)
Expand All @@ -15,6 +18,9 @@
from deepmd.dpmodel.output_def import (
FittingOutputDef,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -48,17 +54,16 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(type_map, **kwargs)
self.type_map = type_map
self.descriptor = descriptor
self.fitting = fitting
if hasattr(self.fitting, "reinit_exclude"):
self.fitting.reinit_exclude(self.atom_exclude_types)
self.fitting_net = fitting
if hasattr(self.fitting_net, "reinit_exclude"):
self.fitting_net.reinit_exclude(self.atom_exclude_types)
self.type_map = type_map
super().init_out_stat()

def fitting_output_def(self) -> FittingOutputDef:
    """Get the output def of the fitting net.

    Returns
    -------
    FittingOutputDef
        The output definition produced by ``self.fitting_net.output_def()``.
    """
    return self.fitting_net.output_def()

def get_rcut(self) -> float:
"""Get the cut-off radius."""
Expand All @@ -73,7 +78,7 @@ def set_case_embd(self, case_idx: int) -> None:
Set the case embedding of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
self.fitting.set_case_embd(case_idx)
self.fitting_net.set_case_embd(case_idx)

def mixed_types(self) -> bool:
"""If true, the model
Expand Down Expand Up @@ -166,7 +171,7 @@ def forward_atomic(
nlist,
mapping=mapping,
)
ret = self.fitting(
ret = self.fitting_net(
descriptor,
atype,
gr=rot_mat,
Expand All @@ -177,6 +182,37 @@ def forward_atomic(
)
return ret

def compute_or_load_stat(
self,
sampled_func: Callable[[], list[dict]],
stat_file_path: DPPath | None = None,
compute_or_load_out_stat: bool = True,
) -> None:
"""Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.

Parameters
----------
sampled_func
The lazy sampled function to get data frames from different data systems.
stat_file_path
The path to the stat file.
compute_or_load_out_stat : bool
Whether to compute the output statistics.
If False, it will only compute the input statistics
(e.g. mean and standard deviation of descriptors).
"""
if stat_file_path is not None and self.type_map is not None:
stat_file_path /= " ".join(self.type_map)

wrapped_sampler = self._make_wrapped_sampler(sampled_func)
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.fitting_net.compute_input_stats(
wrapped_sampler, stat_file_path=stat_file_path
)
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def change_type_map(
self, type_map: list[str], model_with_new_type_stat: Any | None = None
) -> None:
Expand All @@ -193,7 +229,31 @@ def change_type_map(
if model_with_new_type_stat is not None
else None,
)
self.fitting.change_type_map(type_map=type_map)
self.fitting_net.change_type_map(type_map=type_map)

def compute_fitting_input_stat(
    self,
    sample_merged: Callable[[], list[dict]] | list[dict],
    stat_file_path: DPPath | None = None,
) -> None:
    """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
          Each element, ``merged[i]``, is a data dictionary containing
          ``keys``: ``np.ndarray`` originating from the ``i``-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples
          in the above format only when needed.
    stat_file_path : Optional[DPPath]
        The path to the stat file.
    """
    # forwards the model-level data_stat_protect value (set in __init__,
    # default 1e-2) as the protection argument of the fitting stats
    self.fitting_net.compute_input_stats(
        sample_merged,
        protection=self.data_stat_protect,
        stat_file_path=stat_file_path,
    )

def serialize(self) -> dict:
dd = super().serialize()
Expand All @@ -204,7 +264,7 @@ def serialize(self) -> dict:
"@version": 2,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
"fitting": self.fitting_net.serialize(),
}
)
return dd
Expand All @@ -230,19 +290,19 @@ def deserialize(cls, data: dict[str, Any]) -> "DPAtomicModel":

def get_dim_fparam(self) -> int:
    """Get the number (dimension) of frame parameters of this atomic model.

    Delegates to the fitting net.
    """
    return self.fitting_net.get_dim_fparam()

def get_dim_aparam(self) -> int:
    """Get the number (dimension) of atomic parameters of this atomic model.

    Delegates to the fitting net.
    """
    return self.fitting_net.get_dim_aparam()

def has_default_fparam(self) -> bool:
    """Check if the model has default frame parameters.

    Delegates to the fitting net.
    """
    return self.fitting_net.has_default_fparam()

def get_default_fparam(self) -> list[float] | None:
    """Get the default frame parameters.

    Returns
    -------
    list[float] | None
        The fitting net's default frame parameters, or None if not set.
    """
    return self.fitting_net.get_default_fparam()

def get_sel_type(self) -> list[int]:
"""Get the selected atom types of this model.
Expand All @@ -251,7 +311,7 @@ def get_sel_type(self) -> list[int]:
to the result of the model.
If returning an empty list, all atom types are selected.
"""
return self.fitting.get_sel_type()
return self.fitting_net.get_sel_type()

def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
Expand Down
40 changes: 39 additions & 1 deletion deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Callable,
)
from typing import (
Any,
)
Expand All @@ -17,6 +20,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -327,6 +333,38 @@ def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
data["models"] = models
return super().deserialize(data)

def compute_or_load_stat(
self,
sampled_func: Callable[[], list[dict]],
stat_file_path: DPPath | None = None,
compute_or_load_out_stat: bool = True,
) -> None:
"""Compute or load the statistics parameters of the model.

For LinearEnergyAtomicModel, this first computes input stats for each
sub-model (without output stats), then computes its own output stats.

Parameters
----------
sampled_func
The lazy sampled function to get data frames from different data systems.
stat_file_path
The path to the stat file.
compute_or_load_out_stat : bool
Whether to compute the output statistics.
"""
for md in self.models:
md.compute_or_load_stat(
sampled_func, stat_file_path, compute_or_load_out_stat=False
)

if stat_file_path is not None and self.type_map is not None:
stat_file_path /= " ".join(self.type_map)

if compute_or_load_out_stat:
wrapped_sampler = self._make_wrapped_sampler(sampled_func)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def _compute_weight(
self,
extended_coord: Array,
Expand Down Expand Up @@ -512,4 +550,4 @@ def _compute_weight(
# to handle masked atoms
coef = xp.where(sigma != 0, coef, xp.zeros_like(coef))
self.zbl_weight = coef
return [1 - xp.expand_dims(coef, -1), xp.expand_dims(coef, -1)]
return [1 - xp.expand_dims(coef, axis=-1), xp.expand_dims(coef, axis=-1)]
Loading