diff --git a/.gitignore b/.gitignore index 6382ecedd2..298cd1d90c 100644 --- a/.gitignore +++ b/.gitignore @@ -71,3 +71,4 @@ frozen_model.* # Test system directories system/ +tests/ diff --git a/deepmd/__about__.py b/deepmd/__about__.py new file mode 100644 index 0000000000..0d2f7d41b3 --- /dev/null +++ b/deepmd/__about__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# Auto-generated stub for development use +__version__ = "dev" diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 8f96e965b0..220d1f4464 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -104,13 +104,14 @@ def get_standard_model(data: dict) -> EnergyModel: else: raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") - model = modelcls( - descriptor=descriptor, - fitting=fitting, - type_map=data["type_map"], - atom_exclude_types=atom_exclude_types, - pair_exclude_types=pair_exclude_types, - ) + model_kwargs: dict = { + "descriptor": descriptor, + "fitting": fitting, + "type_map": data["type_map"], + "atom_exclude_types": atom_exclude_types, + "pair_exclude_types": pair_exclude_types, + } + model = modelcls(**model_kwargs) return model diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 4a0cb27cb1..6676fbf122 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -887,6 +887,8 @@ def test_property( high_prec=True, ) + is_xas = var_name == "xas" + if dp.get_dim_fparam() > 0: data.add( "fparam", dp.get_dim_fparam(), atomic=False, must=True, high_prec=False @@ -894,6 +896,10 @@ def test_property( if dp.get_dim_aparam() > 0: data.add("aparam", dp.get_dim_aparam(), atomic=True, must=True, high_prec=False) + # XAS requires sel_type.npy (per-frame absorbing element type index) + if is_xas: + data.add("sel_type", 1, atomic=False, must=True, high_prec=False) + test_data = data.get_test() mixed_type = data.mixed_type natoms = len(test_data["type"][0]) @@ -918,21 +924,79 @@ def 
test_property( else: aparam = None + # XAS: per-atom outputs are needed to sum over absorbing-element atoms + eval_atomic = has_atom_property or is_xas ret = dp.eval( coord, box, atype, fparam=fparam, aparam=aparam, - atomic=has_atom_property, + atomic=eval_atomic, mixed_type=mixed_type, ) - property = ret[0] + if is_xas: + # ret[1]: per-atom property [numb_test, natoms, task_dim] + atom_prop = ret[1].reshape([numb_test, natoms, dp.task_dim]) + if mixed_type: + atype_frames = atype # [numb_test, natoms] + else: + atype_frames = np.tile(atype, (numb_test, 1)) # [numb_test, natoms] + sel_type_int = test_data["sel_type"][:numb_test, 0].astype(int) + property = np.zeros([numb_test, dp.task_dim], dtype=atom_prop.dtype) + for i in range(numb_test): + t = sel_type_int[i] + mask = atype_frames[i] == t # [natoms] + property[i] = atom_prop[i][mask].sum(axis=0) # sum, not mean + + # Add back the per-(type, edge) energy reference so output is in + # absolute eV (matching label format). xas_e_ref is computed by + # DPXASAtomicModel.compute_or_load_out_stat and saved in the checkpoint. + try: + # dp is DeepProperty (wrapper); the PT backend is dp.deep_eval, + # and its ModelWrapper is dp.deep_eval.dp.
+ xas_e_ref = dp.deep_eval.dp.model["Default"].atomic_model.xas_e_ref + except AttributeError: + xas_e_ref = None + if xas_e_ref is not None and fparam is not None: + import torch as _torch + + edge_idx_all = ( + _torch.tensor(fparam.reshape(numb_test, -1)).argmax(dim=-1).numpy() + ) + e_ref_np = xas_e_ref.cpu().numpy() # [ntypes, nfparam, 2] + for i in range(numb_test): + t = sel_type_int[i] + e = int(edge_idx_all[i]) + property[i, :2] += e_ref_np[t, e] + + # Restore intensity dims: pred_abs = pred * intensity_std + intensity_ref + try: + am = dp.deep_eval.dp.model["Default"].atomic_model + xas_intensity_ref = getattr(am, "xas_intensity_ref", None) + xas_intensity_std = getattr(am, "xas_intensity_std", None) + except AttributeError: + xas_intensity_ref = None + xas_intensity_std = None + if xas_intensity_ref is not None and xas_intensity_std is not None and fparam is not None: + import torch as _torch + + edge_idx_all = ( + _torch.tensor(fparam.reshape(numb_test, -1)).argmax(dim=-1).numpy() + ) + int_ref_np = xas_intensity_ref.cpu().numpy() # [ntypes, nfparam, n_pts] + int_std_np = xas_intensity_std.cpu().numpy() # [ntypes, nfparam, n_pts] + for i in range(numb_test): + t = sel_type_int[i] + e = int(edge_idx_all[i]) + property[i, 2:] = property[i, 2:] * int_std_np[t, e] + int_ref_np[t, e] + else: + property = ret[0] property = property.reshape([numb_test, dp.task_dim]) - if has_atom_property: + if has_atom_property and not is_xas: aproperty = ret[1] aproperty = aproperty.reshape([numb_test, natoms * dp.task_dim]) diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 11a877040d..73b746971d 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -170,6 +170,12 @@ def __init__( if not self.input_param.get("hessian_mode") and not no_jit: model = torch.jit.script(model) self.dp = ModelWrapper(model) + # Filter out loss-related keys that may be present in old training checkpoints. 
+ # This is for backward compatibility with checkpoints saved before the + # XASLoss refactor that removed persistent buffers from the loss module. + state_dict = { + k: v for k, v in state_dict.items() if not k.startswith("loss.") + } self.dp.load_state_dict(state_dict) elif str(self.model_path).endswith(".pth"): extra_files = {"data_modifier.pth": ""} diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py index b026cd54c5..d77023222a 100644 --- a/deepmd/pt/infer/inference.py +++ b/deepmd/pt/infer/inference.py @@ -73,4 +73,7 @@ def __init__( self.wrapper = ModelWrapper(self.model) # inference only if JIT: self.wrapper = torch.jit.script(self.wrapper) + # Drop loss-related keys (e.g. loss buffers like XASLoss.e_ref) that + # are not part of the inference-only wrapper. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("loss.")} self.wrapper.load_state_dict(state_dict) diff --git a/deepmd/pt/loss/__init__.py b/deepmd/pt/loss/__init__.py index 1d25c1e52f..17b2cd37c3 100644 --- a/deepmd/pt/loss/__init__.py +++ b/deepmd/pt/loss/__init__.py @@ -21,6 +21,9 @@ from .tensor import ( TensorLoss, ) +from .xas import ( + XASLoss, +) __all__ = [ "DOSLoss", @@ -31,4 +34,5 @@ "PropertyLoss", "TaskLoss", "TensorLoss", + "XASLoss", ] diff --git a/deepmd/pt/loss/xas.py b/deepmd/pt/loss/xas.py new file mode 100644 index 0000000000..21e90e5af1 --- /dev/null +++ b/deepmd/pt/loss/xas.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Any, +) + +import torch +import torch.nn.functional as F + +from deepmd.pt.loss.loss import ( + TaskLoss, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + +log = logging.getLogger(__name__) + + +class XASLoss(TaskLoss): + """Loss for XAS spectrum fitting via property fitting + sel_type reduction. + + The model outputs per-atom property vectors (atom_xas). 
For each frame + this loss sums the contributions of atoms matching ``sel_type`` (read from + ``sel_type.npy`` per system) and computes a loss against the per-frame XAS + label. + + Normalisation statistics (``xas_e_ref``, ``xas_intensity_ref/std``, + ``out_bias``, ``out_std``) are computed once before training by + :meth:`DPXASAtomicModel.compute_or_load_out_stat` via the standard + :meth:`compute_or_load_stat` pipeline and stored as model buffers. + + Parameters + ---------- + task_dim : int + Output dimension of the fitting net (e.g. 102 = E_min + E_max + 100 pts). + nfparam : int + Length of the fparam one-hot vector (= number of edge types). + var_name : str + Property name, must match ``property_name`` in the fitting config. + loss_func : str + One of ``smooth_mae``, ``mae``, ``mse``, ``rmse``. + metric : list[str] + Metrics to display during training (absolute scale). + beta : float + Beta parameter for smooth_l1 loss. + pref_energy : float + Weight multiplier for the two energy dimensions (E_min, E_max). + pref_spectrum : float + Weight multiplier for the intensity dimensions (index 2 onward). + smooth_reg : float + Coefficient of the second-order smoothness regulariser applied to the + predicted intensity dimensions in standardised space. 0.0 disables (default). 
+ """ + + def __init__( + self, + task_dim: int, + nfparam: int, + var_name: str = "xas", + loss_func: str = "smooth_mae", + metric: list[str] = ["mae"], + beta: float = 1.0, + pref_energy: float = 1.0, + pref_spectrum: float = 1.0, + smooth_reg: float = 0.0, + **kwargs: Any, + ) -> None: + super().__init__() + self.task_dim = task_dim + self.nfparam = nfparam + self.var_name = var_name + self.loss_func = loss_func + self.metric = metric + self.beta = beta + self.pref_energy = pref_energy + self.pref_spectrum = pref_spectrum + self.smooth_reg = smooth_reg + + def forward( + self, + input_dict: dict[str, torch.Tensor], + model: torch.nn.Module, + label: dict[str, torch.Tensor], + natoms: int, + learning_rate: float = 0.0, + mae: bool = False, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]: + model_pred = model(**input_dict) + + # per-atom outputs: [nf, nloc, task_dim] + atom_prop = model_pred[f"atom_{self.var_name}"] + atype = input_dict["atype"] # [nf, nloc] + + sel_type = label["sel_type"][:, 0].long() # [nf] + + nf, nloc, td = atom_prop.shape + mask_3d = atype.unsqueeze(-1) == sel_type.view(nf, 1, 1) # [nf, nloc, 1] + pred = (atom_prop * mask_3d).sum(dim=1) # [nf, task_dim] + + label_xas = label[self.var_name] # [nf, task_dim] + + # --- per-(type, edge) stat lookup from model buffers --- + fparam = input_dict.get("fparam") + if fparam is not None and fparam.numel() > 0: + edge_idx = fparam.reshape(nf, -1).argmax(dim=-1).clamp(0, self.nfparam - 1) + else: + edge_idx = torch.zeros(nf, dtype=torch.long, device=pred.device) + + am = model.atomic_model + e_ref = am.xas_e_ref # [ntypes, nfparam, 2] + intensity_ref = am.xas_intensity_ref # [ntypes, nfparam, n_pts] + intensity_std = am.xas_intensity_std # [ntypes, nfparam, n_pts] + + _dev = e_ref.device + _sel = sel_type.to(_dev) + _eidx = edge_idx.to(_dev) + + e_ref_frame = e_ref[_sel, _eidx].to(pred.device) # [nf, 2] + intensity_ref_frame = intensity_ref[_sel, _eidx].to(pred.device) # [nf, 
n_pts] + intensity_std_frame = intensity_std[_sel, _eidx].to(pred.device) # [nf, n_pts] + + # Normalised targets: + # energy dims → chemical shift: label - e_ref + # intensity dims → standardised: (label - ref) / std + label_energy_norm = label_xas[:, :2] - e_ref_frame + label_intens_norm = (label_xas[:, 2:] - intensity_ref_frame) / intensity_std_frame + + def _elem_loss(p: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + if self.loss_func == "smooth_mae": + return F.smooth_l1_loss(p, t, reduction="sum", beta=self.beta) + elif self.loss_func == "mae": + return F.l1_loss(p, t, reduction="sum") + elif self.loss_func == "mse": + return F.mse_loss(p, t, reduction="sum") + elif self.loss_func == "rmse": + return torch.sqrt(F.mse_loss(p, t, reduction="mean")) + else: + raise RuntimeError(f"Unknown loss function: {self.loss_func}") + + loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] + loss += self.pref_energy * _elem_loss(pred[:, :2], label_energy_norm) + loss += self.pref_spectrum * _elem_loss(pred[:, 2:], label_intens_norm) + + # Smoothness regulariser on standardised intensity dims (scale-invariant). 
+ n_pts = self.task_dim - 2 + if self.smooth_reg > 0.0 and n_pts >= 3: + pi = pred[:, 2:] # [nf, n_pts] in standardised space + curv = pi[:, 2:] - 2.0 * pi[:, 1:-1] + pi[:, :-2] + loss += self.smooth_reg * (curv**2).mean() + + # --- metrics (reported on absolute scale) --- + pred_abs = pred.clone() + pred_abs[:, :2] = pred[:, :2] + e_ref_frame + pred_abs[:, 2:] = pred[:, 2:] * intensity_std_frame + intensity_ref_frame + + more_loss: dict[str, torch.Tensor] = {} + if "mae" in self.metric: + more_loss["mae"] = F.l1_loss( + pred_abs, label_xas, reduction="mean" + ).detach() + if "rmse" in self.metric: + more_loss["rmse"] = torch.sqrt( + F.mse_loss(pred_abs, label_xas, reduction="mean") + ).detach() + + model_pred[self.var_name] = pred_abs + return model_pred, loss, more_loss + + @property + def label_requirement(self) -> list[DataRequirementItem]: + """Declare required data files: xas label + sel_type.""" + return [ + DataRequirementItem( + self.var_name, + ndof=self.task_dim, + atomic=False, + must=True, + high_prec=True, + ), + DataRequirementItem( + "sel_type", + ndof=1, + atomic=False, + must=True, + high_prec=False, + ), + ] diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index 4da9bf781b..fbf7478778 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -41,6 +41,7 @@ ) from .property_atomic_model import ( DPPropertyAtomicModel, + DPXASAtomicModel, ) __all__ = [ @@ -51,6 +52,7 @@ "DPEnergyAtomicModel", "DPPolarAtomicModel", "DPPropertyAtomicModel", + "DPXASAtomicModel", "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", "PairTabAtomicModel", diff --git a/deepmd/pt/model/atomic_model/property_atomic_model.py b/deepmd/pt/model/atomic_model/property_atomic_model.py index baf9c5b7fc..a4cd4f1483 100644 --- a/deepmd/pt/model/atomic_model/property_atomic_model.py +++ b/deepmd/pt/model/atomic_model/property_atomic_model.py @@ -1,18 +1,31 @@ # 
SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from collections import ( + defaultdict, +) +from collections.abc import ( + Callable, +) from typing import ( Any, ) +import numpy as np import torch from deepmd.pt.model.task.property import ( PropertyFittingNet, ) +from deepmd.pt.utils import ( + env, +) from .dp_atomic_model import ( DPAtomicModel, ) +log = logging.getLogger(__name__) + class DPPropertyAtomicModel(DPAtomicModel): def __init__( @@ -52,3 +65,160 @@ def apply_out_stat( for kk in self.bias_keys: ret[kk] = ret[kk] * out_std[kk][0] + out_bias[kk][0] return ret + + +class DPXASAtomicModel(DPPropertyAtomicModel): + """Atomic model for XAS spectrum fitting. + + Extends :class:`DPPropertyAtomicModel` with per-(absorbing_type, edge) + statistics buffers: ``xas_e_ref`` [ntypes, nfparam, 2], + ``xas_intensity_ref`` and ``xas_intensity_std`` [ntypes, nfparam, n_pts]. + + These buffers are computed by :meth:`compute_or_load_out_stat` (called via + the standard :meth:`compute_or_load_stat` pipeline before training starts) + and saved in the checkpoint so that absolute edge energies and intensity + scales are available at inference time. 
+ """ + + def __init__( + self, descriptor: Any, fitting: Any, type_map: Any, **kwargs: Any + ) -> None: + super().__init__(descriptor, fitting, type_map, **kwargs) + nfparam: int = getattr(fitting, "numb_fparam", 0) + if nfparam > 0: + ntypes: int = len(type_map) + n_pts: int = max(getattr(fitting, "dim_out", 2) - 2, 0) + self.register_buffer( + "xas_e_ref", + torch.zeros(ntypes, nfparam, 2, dtype=torch.float64), + ) + # maps edge_idx (argmax of fparam one-hot) → absorbing atom type index + self.register_buffer( + "xas_edge_to_seltype", + torch.zeros(nfparam, dtype=torch.long), + ) + # per-(type, edge, point) intensity statistics for inference denormalisation + self.register_buffer( + "xas_intensity_ref", + torch.zeros(ntypes, nfparam, n_pts, dtype=torch.float64), + ) + self.register_buffer( + "xas_intensity_std", + torch.ones(ntypes, nfparam, n_pts, dtype=torch.float64), + ) + else: + self.xas_e_ref: torch.Tensor | None = None + self.xas_edge_to_seltype: torch.Tensor | None = None + self.xas_intensity_ref: torch.Tensor | None = None + self.xas_intensity_std: torch.Tensor | None = None + + def compute_or_load_out_stat( + self, + merged: Callable[[], list[dict]] | list[dict], + stat_file_path: Any = None, + ) -> None: + """Compute per-(absorbing_type, edge) statistics from training data. + + Populates ``xas_e_ref``, ``xas_intensity_ref``, ``xas_intensity_std``, + and sets ``out_bias``/``out_std`` so the NN trains in a normalised space. + Falls back to the parent implementation when ``nfparam == 0``. 
+ """ + if self.xas_e_ref is None: + super().compute_or_load_out_stat(merged, stat_file_path) + return + + sampled = merged() if callable(merged) else merged + + nfparam: int = self.xas_e_ref.shape[1] + ntypes: int = self.xas_e_ref.shape[0] + n_pts: int = self.xas_intensity_ref.shape[2] + task_dim: int = 2 + n_pts + var_name: str = self.bias_keys[0] + + accum: dict[tuple[int, int], list] = defaultdict(list) + for frame in sampled: + if ( + var_name not in frame + or "sel_type" not in frame + or "fparam" not in frame + ): + continue + xas = frame[var_name].reshape(-1, task_dim) + sel_type = frame["sel_type"].reshape(-1).long() + edge_idx = frame["fparam"].reshape(-1, nfparam).argmax(dim=-1) + for i in range(xas.shape[0]): + t = int(sel_type[i].item()) + e = int(edge_idx[i].item()) + if 0 <= t < ntypes and 0 <= e < nfparam: + accum[(t, e)].append(xas[i].detach().cpu().numpy()) + + if not accum: + log.warning( + "DPXASAtomicModel.compute_or_load_out_stat: no XAS frames found; " + "stats remain at defaults. Training may be unstable." 
+ ) + return + + e_ref = torch.zeros(ntypes, nfparam, 2, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + e_std = torch.ones(ntypes, nfparam, 2, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + intensity_ref = torch.zeros( + ntypes, nfparam, n_pts, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + intensity_std = torch.ones( + ntypes, nfparam, n_pts, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + + for (t, e), vals in accum.items(): + arr = np.array(vals) # [n, task_dim] + e_ref[t, e] = torch.tensor( + np.mean(arr[:, :2], axis=0), dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + e_std[t, e] = torch.tensor( + np.std(arr[:, :2], axis=0).clip(min=1.0), + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) + if n_pts > 0: + intensity_ref[t, e] = torch.tensor( + np.mean(arr[:, 2:], axis=0), dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + intensity_std[t, e] = torch.tensor( + np.std(arr[:, 2:], axis=0).clip(min=1e-6), + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) + log.info( + f"DPXASAtomicModel stats: type={t}, edge={e} | " + f"E_ref=[{float(e_ref[t,e,0]):.2f}, {float(e_ref[t,e,1]):.2f}] eV | " + f"n={len(vals)}" + ) + + self.xas_e_ref.copy_(e_ref.to(self.xas_e_ref.dtype)) + self.xas_intensity_ref.copy_(intensity_ref.to(self.xas_intensity_ref.dtype)) + self.xas_intensity_std.copy_(intensity_std.to(self.xas_intensity_std.dtype)) + + # Legacy fallback mapping used by XASModel.forward when sel_type is not provided. 
+ if self.xas_edge_to_seltype is not None: + mapping = torch.zeros( + nfparam, dtype=torch.long, device=self.xas_edge_to_seltype.device + ) + for t, e in accum.keys(): + mapping[e] = t + self.xas_edge_to_seltype.copy_(mapping) + + key_idx = self.bias_keys.index(var_name) + populated = e_std.abs().gt(1.0) + e_std_global = ( + e_std[populated].mean(dim=0) + if populated.any() + else torch.ones(2, dtype=e_std.dtype) + ) + with torch.no_grad(): + self.out_bias[key_idx, :, :2] = 0.0 + self.out_std[key_idx, :, :2] = e_std_global.to(self.out_std.dtype) + if n_pts > 0: + self.out_bias[key_idx, :, 2:] = 0.0 + self.out_std[key_idx, :, 2:] = 1.0 + + log.info( + f"DPXASAtomicModel: stats computed for {len(accum)} (type, edge) groups. " + f"out_std[:2]={e_std_global.tolist()} eV." + ) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 24075412db..f577e4f0cf 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -73,6 +73,9 @@ SpinEnergyModel, SpinModel, ) +from .xas_model import ( + XASModel, +) def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple: @@ -269,19 +272,23 @@ def get_standard_model(model_params: dict) -> BaseModel: elif fitting_net_type in ["ener", "direct_force_ener"]: modelcls = EnergyModel elif fitting_net_type == "property": - modelcls = PropertyModel + property_name = model_params.get("fitting_net", {}).get( + "property_name", model_params.get("fitting_net", {}).get("var_name", "") + ) + modelcls = XASModel if property_name == "xas" else PropertyModel else: raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") - model = modelcls( - descriptor=descriptor, - fitting=fitting, - type_map=model_params["type_map"], - atom_exclude_types=atom_exclude_types, - pair_exclude_types=pair_exclude_types, - preset_out_bias=preset_out_bias, - data_stat_protect=data_stat_protect, - ) + model_kwargs: dict[str, Any] = { + "descriptor": descriptor, + "fitting": fitting, + 
"type_map": model_params["type_map"], + "atom_exclude_types": atom_exclude_types, + "pair_exclude_types": pair_exclude_types, + "preset_out_bias": preset_out_bias, + "data_stat_protect": data_stat_protect, + } + model = modelcls(**model_kwargs) if model_params.get("hessian_mode"): model.enable_hessian() model.model_def_script = json.dumps(model_params_old) @@ -315,6 +322,7 @@ def get_model(model_params: dict) -> Any: "PolarModel", "SpinEnergyModel", "SpinModel", + "XASModel", "get_model", "make_hessian_model", "make_model", diff --git a/deepmd/pt/model/model/xas_model.py b/deepmd/pt/model/model/xas_model.py new file mode 100644 index 0000000000..2577ed0555 --- /dev/null +++ b/deepmd/pt/model/model/xas_model.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.pt.model.atomic_model import ( + DPXASAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + +from .property_model import ( + PropertyModel, +) + + +@BaseModel.register("xas") +class XASModel(PropertyModel): + """Model for XAS spectrum fitting. + + Identical to :class:`PropertyModel` but uses :class:`DPXASAtomicModel` + as the underlying atomic model, which carries the per-(absorbing_type, + edge) energy reference buffer ``xas_e_ref`` in the checkpoint. This + buffers are populated by :meth:`DPXASAtomicModel.compute_or_load_out_stat` + (via the standard stat pipeline) before training starts and restored at + inference time so that absolute edge energies are available without any + external reference files. + + Two corrections are applied in ``forward`` that are absent in the generic + :class:`PropertyModel`: + + 1. **sel_type reduction** — only atoms of the absorbing type contribute to + the reduced spectrum. + + 2. **e_ref restoration** — during training the energy dimensions (E_min, + E_max at indices 0–1) are trained against chemical shifts + ``label − e_ref``. 
At inference we add ``e_ref`` back so the output + is in absolute edge-energy units (eV). + """ + + model_type = "xas" + + def __init__( + self, + descriptor: Any, + fitting: Any, + type_map: Any, + **kwargs: Any, + ) -> None: + xas_atomic = DPXASAtomicModel(descriptor, fitting, type_map, **kwargs) + super().__init__( + descriptor, fitting, type_map, atomic_model_=xas_atomic, **kwargs + ) + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + """Forward pass with XAS-specific reductions. + + For inference with multi-type edges, use :meth:`forward_xas` instead + which accepts an explicit ``sel_type`` argument. + """ + return self.forward_xas( + coord, atype, box, fparam, aparam, do_atomic_virial, sel_type=None + ) + + def forward_xas( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + sel_type: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Forward pass with XAS-specific reductions. + + Parameters + ---------- + coord : torch.Tensor + Coordinates, shape [nf, nloc, 3]. + atype : torch.Tensor + Atom types, shape [nf, nloc]. + box : torch.Tensor | None + Box vectors, shape [nf, 9]. + fparam : torch.Tensor | None + Frame parameters (one-hot edge encoding), shape [nf, nfparam]. + aparam : torch.Tensor | None + Atom parameters, shape [nf, nloc, naparam]. + do_atomic_virial : bool + Whether to compute atomic virial. + sel_type : torch.Tensor | None + Absorbing atom type per frame, shape [nf]. Required when multiple + element types share the same edge (e.g., K-edge for H/Li/Be/...). + If None, falls back to legacy ``xas_edge_to_seltype`` mapping which + only works when each edge has exactly one absorbing element type. 
+ + Returns + ------- + dict[str, torch.Tensor] + Model predictions including reduced XAS spectrum. + """ + # Call forward_common directly (same as PropertyModel.forward) + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + var_name = self.get_var_name() + model_predict: dict[str, torch.Tensor] = {} + model_predict[f"atom_{var_name}"] = model_ret[var_name] + + if fparam is None or fparam.numel() == 0: + model_predict[var_name] = model_ret[f"{var_name}_redu"] + return model_predict + + am = self.atomic_model + atom_xas = model_ret[var_name] # [nf, nloc, task_dim] + nf = atype.shape[0] + + # Derive edge_idx from one-hot fparam + nfparam = fparam.reshape(nf, -1).shape[-1] + edge_idx = fparam.reshape(nf, -1).argmax(dim=-1).clamp(0, nfparam - 1) # [nf] + + # Determine absorbing atom type per frame + if sel_type is not None: + # Explicit sel_type provided — use directly + sel_type_per_frame = sel_type.to(atype.device).long() + else: + # Legacy fallback: use xas_edge_to_seltype mapping + # WARNING: only correct when each edge has exactly one absorbing type + edge_to_sel = getattr(am, "xas_edge_to_seltype", None) + if edge_to_sel is None: + return model_predict + sel_type_per_frame = edge_to_sel[edge_idx.to(edge_to_sel.device)].to( + atype.device + ) + + # Sum only sel_type atoms per frame + mask_3d = atype.unsqueeze(-1) == sel_type_per_frame.view(nf, 1, 1) # [nf, nloc, 1] + xas_redu = (atom_xas * mask_3d.to(atom_xas.dtype)).sum(dim=1) # [nf, task_dim] + + xas_redu = xas_redu.clone() + + # Restore energy dims to absolute eV: pred_abs = pred + e_ref + xas_e_ref = getattr(am, "xas_e_ref", None) + if xas_e_ref is not None: + e_ref_frame = xas_e_ref[ + sel_type_per_frame.to(xas_e_ref.device), + edge_idx.to(xas_e_ref.device), + ].to(dtype=xas_redu.dtype, device=xas_redu.device) # [nf, 2] + xas_redu[:, :2] = xas_redu[:, :2] + e_ref_frame + + # Restore intensity dims to absolute scale: + # 
pred_abs = pred_standardised * intensity_std + intensity_ref + xas_intensity_ref = getattr(am, "xas_intensity_ref", None) + if xas_intensity_ref is not None: + xas_intensity_std = am.xas_intensity_std + _st = sel_type_per_frame.to(xas_intensity_ref.device) + _ei = edge_idx.to(xas_intensity_ref.device) + int_ref = xas_intensity_ref[_st, _ei].to( + dtype=xas_redu.dtype, device=xas_redu.device + ) # [nf, n_pts] + int_std = xas_intensity_std[_st, _ei].to( + dtype=xas_redu.dtype, device=xas_redu.device + ) # [nf, n_pts] + xas_redu[:, 2:] = xas_redu[:, 2:] * int_std + int_ref + + model_predict[var_name] = xas_redu + return model_predict diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 7aac7b9a29..7c8919862b 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -431,6 +431,7 @@ def __init__( type_map: list[str] | None = None, use_aparam_as_mask: bool = False, default_fparam: list[float] | None = None, + normalize_fparam: bool = True, **kwargs: Any, ) -> None: super().__init__() @@ -451,6 +452,7 @@ def __init__( self.seed = seed self.type_map = type_map self.use_aparam_as_mask = use_aparam_as_mask + self.normalize_fparam = normalize_fparam # order matters, should be place after the assignment of ntypes self.reinit_exclude(exclude_types) self.trainable = trainable @@ -622,6 +624,7 @@ def serialize(self) -> dict: "trainable": [self.trainable] * (len(self.neuron) + 1), "layer_name": None, "use_aparam_as_mask": self.use_aparam_as_mask, + "normalize_fparam": self.normalize_fparam, "spin": None, } @@ -786,9 +789,10 @@ def _forward_common( ) fparam = fparam.view([nf, self.numb_fparam]) nb, _ = fparam.shape - t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) - t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb) - fparam = (fparam - t_fparam_avg) * t_fparam_inv_std + if self.normalize_fparam: + t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) + t_fparam_inv_std = 
self._extend_f_avg_std(self.fparam_inv_std, nb) + fparam = (fparam - t_fparam_avg) * t_fparam_inv_std fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1]) xx = torch.cat( [xx, fparam], diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 8d16e1c7ea..a5eeb29e89 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -43,6 +43,7 @@ PropertyLoss, TaskLoss, TensorLoss, + XASLoss, ) from deepmd.pt.model.model import ( get_model, @@ -94,9 +95,13 @@ get_optimizer_state_dict, set_optimizer_state_dict, ) -from torch.distributed.fsdp import ( - fully_shard, -) + +try: + from torch.distributed.fsdp import ( + fully_shard, + ) +except ImportError: + fully_shard = None from torch.distributed.optim import ( ZeroRedundancyOptimizer, ) @@ -1752,6 +1757,11 @@ def get_loss( loss_params["var_name"] = var_name loss_params["intensive"] = intensive return PropertyLoss(**loss_params) + elif loss_type == "xas": + loss_params["task_dim"] = _model.get_task_dim() + loss_params["var_name"] = _model.get_var_name() + loss_params["nfparam"] = _model.get_fitting_net().numb_fparam + return XASLoss(**loss_params) else: loss_params["starter_learning_rate"] = start_lr return TaskLoss.get_class_by_type(loss_type).get_loss(loss_params) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index b12bc7ef6f..095c71413c 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1929,6 +1929,12 @@ def fitting_property() -> list[Argument]: doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\ - bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\ - list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1." 
+ doc_normalize_fparam = ( + "Whether to normalize fparam by subtracting its mean and dividing by its std " + "computed from the training data. Set to False when fparam is a one-hot " + "encoding (e.g. edge-type in XAS), where normalization would distort the " + "discrete identity of each category." + ) return [ Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), @@ -1979,6 +1985,13 @@ def fitting_property() -> list[Argument]: default=True, doc=doc_trainable, ), + Argument( + "normalize_fparam", + bool, + optional=True, + default=True, + doc=doc_normalize_fparam, + ), ] @@ -3465,6 +3478,42 @@ def loss_dos() -> list[Argument]: ] +@loss_args_plugin.register("xas", doc=doc_only_pt_supported) +def loss_xas() -> list[Argument]: + doc_loss_func = ( + "The loss function to minimize: 'smooth_mae' (default), 'mae', 'mse', 'rmse'." + ) + doc_metric = "Metrics to display during training. Supported: 'mae', 'rmse'." + doc_beta = "Beta parameter for smooth_l1 loss." + doc_pref_energy = ( + "Weight multiplier for the two energy dimensions (E_min, E_max at indices 0-1). " + "Default 1.0. Decrease this if energy-shift terms dominate the loss and " + "spectral shape is undertrained." + ) + doc_pref_spectrum = ( + "Weight multiplier for the intensity dimensions (index 2 onward). " + "Default 1.0. Increase this to focus training on spectral shape." + ) + doc_smooth_reg = ( + "Coefficient of the second-order smoothness regulariser applied to the " + "predicted intensity dimensions in standardised space. Penalises curvature " + "(pred[i+1] - 2*pred[i] + pred[i-1])^2 to suppress high-frequency wiggles. " + "0.0 disables it (default). Typical range: 1e-4 to 1e-2." 
+ ) + return [ + Argument( + "loss_func", str, optional=True, default="smooth_mae", doc=doc_loss_func + ), + Argument("metric", list, optional=True, default=["mae"], doc=doc_metric), + Argument("beta", float, optional=True, default=1.0, doc=doc_beta), + Argument("pref_energy", float, optional=True, default=1.0, doc=doc_pref_energy), + Argument( + "pref_spectrum", float, optional=True, default=1.0, doc=doc_pref_spectrum + ), + Argument("smooth_reg", float, optional=True, default=0.0, doc=doc_smooth_reg), + ] + + @loss_args_plugin.register("property") def loss_property() -> list[Argument]: doc_loss_func = "The loss function to minimize, such as 'mae','smooth_mae'." diff --git a/examples/xas/train/README.md b/examples/xas/train/README.md new file mode 100644 index 0000000000..a4a96c757a --- /dev/null +++ b/examples/xas/train/README.md @@ -0,0 +1,108 @@ +# XAS Spectrum Fitting with DeePMD-kit + +This example shows how to train a model to predict X-ray absorption spectra (XAS) +from atomic structure using DeePMD-kit's `property` fitting net. + +## Concept + +- The model predicts a 102-dimensional output per atom: `[E_min, E_max, I_0, …, I_99]` +- During training, per-atom outputs are averaged over atoms of the **absorbing element** + (identified by `sel_type.npy` in each training system) +- The edge type (K, L1, L2, …) is provided as a frame-level parameter `fparam` +- One training system per `(element, edge)` pair + +## Quick Start + +**1. Generate example training data** + +```bash +python gen_data.py +``` + +This creates `data/Fe_K/` and `data/O_K/` with 50 frames each. + +**2. Train the model (the `xas` loss is only available in the PyTorch backend)** + +```bash +dp --pt train input.json +``` + +**3. Freeze the model** + +```bash +dp --pt freeze -o model.pth +``` + +**4. Test the model** + +```bash +dp --pt test -m model.pth -s data/Fe_K -n 10 +dp --pt test -m model.pth -s data/O_K -n 10 +``` + +When the model's property name is `xas`, `dp test` requires `sel_type.npy` and reduces the +per-atom outputs over atoms of the absorbing element before computing the error metrics.
+ +## Data Format + +Each system directory must contain: + +``` +data/Fe_K/ +├── type.raw # atom type indices, one per line (int) +├── type_map.raw # element symbols, one per line +└── set.000/ + ├── coord.npy # [nframes, natoms*3] Cartesian coordinates (Å) + ├── box.npy # [nframes, 9] cell vectors (Å), row-major + ├── fparam.npy # [nframes, nfparam] edge one-hot encoding + ├── sel_type.npy # [nframes, 1] absorbing element type index (float64) + └── xas.npy # [nframes, 102] XAS label: [E_min, E_max, I_0..I_99] +``` + +### `sel_type.npy` + +The type index of the absorbing element, stored as float64, constant per system. + +``` +Fe is type 0 → sel_type.npy filled with 0.0 +O is type 1 → sel_type.npy filled with 1.0 +``` + +### `xas.npy` label layout (`task_dim = 102`) + +| Column | Meaning | +| ----------- | ------------------------------------------------------------------ | +| `xas[i,0]` | `E_min` (eV) — lower bound of energy grid | +| `xas[i,1]` | `E_max` (eV) — upper bound of energy grid | +| `xas[i,2:]` | `I[0..99]` — 100 intensity values on `linspace(E_min, E_max, 100)` | + +### `fparam.npy` edge encoding (`nfparam = 3`) + +| Edge | Encoding | +| ---- | --------- | +| K | `[1,0,0]` | +| L1 | `[0,1,0]` | +| L2 | `[0,0,1]` | + +Extend with more entries for additional edges and set `numb_fparam` accordingly. 
+ +## Input Parameters + +Key fields in `input.json`: + +| Parameter | Description | +| ------------------------- | ------------------------------------------------------------- | +| `fitting_net.type` | Must be `"property"` (with `property_name` set to `"xas"`) | +| `fitting_net.task_dim` | `102` (2 energy bounds + 100 intensities) | +| `fitting_net.intensive` | `true` — per-atom outputs are **averaged**, not summed | +| `fitting_net.numb_fparam` | Number of edge-type features (3 for K/L1/L2) | +| `loss.type` | `"xas"` — uses `sel_type.npy` for element-selective averaging | +| `loss.loss_func` | `"smooth_mae"` (default, recommended), `"mae"`, `"mse"` or `"rmse"` | + +## Extending to More Elements / Edges + +- Add a new system directory per `(element, edge)` pair +- Set `sel_type.npy` to the type index of the absorbing element in that system +- Set `fparam.npy` to the one-hot vector for the corresponding edge +- List all system paths under `training.training_data.systems` +- Increase `numb_fparam` if adding new edge types diff --git a/examples/xas/train/gen_data.py b/examples/xas/train/gen_data.py new file mode 100644 index 0000000000..82c81319b8 --- /dev/null +++ b/examples/xas/train/gen_data.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Generate example XAS training data for a Fe-O system. + +This script shows the required data format for XAS spectrum fitting.
+ +Data layout +----------- +One training system per (element, edge) pair: + + data/Fe_K/ — Fe K-edge XAS + data/O_K/ — O K-edge XAS + +Each system directory contains: + + type.raw — atom type indices (int, one per line) + type_map.raw — element symbols, one per line + set.000/ + coord.npy — [nframes, natoms*3] Cartesian coordinates (Å) + box.npy — [nframes, 9] cell vectors (Å), row-major + fparam.npy — [nframes, nfparam] edge encoding (one-hot or continuous) + sel_type.npy — [nframes, 1] type index of absorbing element (float) + xas.npy — [nframes, task_dim] XAS label: [E_min, E_max, I_0..I_99] + +Label format (task_dim = 102) +------------------------------ + xas[i, 0] = E_min (eV) — lower bound of the energy grid for frame i + xas[i, 1] = E_max (eV) — upper bound of the energy grid for frame i + xas[i, 2:] = I (arb. units) — 100 equally-spaced intensity values + on the grid linspace(E_min, E_max, 100) + +fparam encoding (nfparam = 3 for K/L1/L2 edges) +------------------------------------------------- + K-edge → [1, 0, 0] + L1-edge → [0, 1, 0] + L2-edge → [0, 0, 1] + (extend as needed; use numb_fparam in input.json accordingly) + +sel_type.npy +------------ + Integer type index of the absorbing element, stored as float64. + All frames in a system must share the same value (it is constant per system). 
+ Example: Fe is type 0 → sel_type.npy filled with 0.0 + O is type 1 → sel_type.npy filled with 1.0 +""" + +import os + +import numpy as np + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- +nframes = 50 # number of frames per system +numb_pts = 100 # energy grid points +task_dim = numb_pts + 2 # E_min + E_max + 100 intensities +nfparam = 3 # K / L1 / L2 one-hot +natoms = 8 # 4 Fe (type 0) + 4 O (type 1) +box_size = 4.0 # Å + +rng = np.random.default_rng(42) + +# Equilibrium positions: simple rock-salt-like arrangement +base_pos = np.array( + [ + [0.0, 0.0, 0.0], + [2.0, 2.0, 0.0], + [2.0, 0.0, 2.0], + [0.0, 2.0, 2.0], # Fe + [1.0, 1.0, 1.0], + [3.0, 3.0, 1.0], + [3.0, 1.0, 3.0], + [1.0, 3.0, 3.0], # O + ] +) + +coords = base_pos[None] + rng.normal(0, 0.1, (nframes, natoms, 3)) +box = np.tile(np.diag([box_size] * 3).reshape(9), (nframes, 1)) + +type_arr = np.array([0, 0, 0, 0, 1, 1, 1, 1], dtype=int) # Fe Fe Fe Fe O O O O + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def gaussian_spectrum(peak_eV, e_min, e_max, npts=100, width_frac=0.10): + grid = np.linspace(e_min, e_max, npts) + width = (e_max - e_min) * width_frac + return np.exp(-0.5 * ((grid - peak_eV) / width) ** 2) + + +def write_system( + path: str, + sel_type_idx: int, + atom_slice, # slice object selecting absorbing atoms + e_min: float, + e_max: float, + peak_center: float, + peak_shift_scale: float, + fparam_vec, # 1-D array of length nfparam (one-hot edge encoding) +): + os.makedirs(f"{path}/set.000", exist_ok=True) + + # --- structure --- + np.savetxt(f"{path}/type.raw", type_arr, fmt="%d") + with open(f"{path}/type_map.raw", "w") as f: + f.write("Fe\nO\n") + np.save(f"{path}/set.000/box.npy", box.astype(np.float64)) + np.save( + 
f"{path}/set.000/coord.npy", + coords.reshape(nframes, natoms * 3).astype(np.float64), + ) + + # --- fparam: same edge for all frames --- + fparam = np.tile(fparam_vec, (nframes, 1)).astype(np.float64) + np.save(f"{path}/set.000/fparam.npy", fparam) + + # --- sel_type: constant per system --- + sel = np.full((nframes, 1), float(sel_type_idx), dtype=np.float64) + np.save(f"{path}/set.000/sel_type.npy", sel) + + # --- xas labels --- + labels = np.zeros((nframes, task_dim), dtype=np.float64) + for i in range(nframes): + # peak position shifts slightly with mean x-coordinate of absorbing atoms + mean_x = coords[i, atom_slice, 0].mean() + peak = peak_center + mean_x * peak_shift_scale + spectrum = gaussian_spectrum(peak, e_min, e_max) + labels[i, 0] = e_min + labels[i, 1] = e_max + labels[i, 2:] = spectrum + np.save(f"{path}/set.000/xas.npy", labels) + + print(f" {path}:") + print(f" sel_type = {sel_type_idx} fparam = {fparam_vec.tolist()}") + print(f" xas.npy shape = {labels.shape}") + + +# --------------------------------------------------------------------------- +# Generate Fe K-edge and O K-edge systems +# --------------------------------------------------------------------------- +print("Generating example XAS training data...") + +write_system( + path="data/Fe_K", + sel_type_idx=0, # Fe is type 0 + atom_slice=slice(0, 4), # first 4 atoms are Fe + e_min=7100.0, # Fe K-edge region (eV) + e_max=7250.0, + peak_center=7112.0, # Fe K-edge energy + peak_shift_scale=2.0, # chemical shift ∝ local environment + fparam_vec=np.array([1.0, 0.0, 0.0]), # K-edge one-hot +) + +write_system( + path="data/O_K", + sel_type_idx=1, # O is type 1 + atom_slice=slice(4, 8), # last 4 atoms are O + e_min=525.0, # O K-edge region (eV) + e_max=560.0, + peak_center=535.0, # O K-edge energy + peak_shift_scale=0.5, + fparam_vec=np.array([1.0, 0.0, 0.0]), # also K-edge +) + +print(f"\nDone. 
{nframes} frames per system, task_dim={task_dim}, nfparam={nfparam}") +print("Data written to ./data/Fe_K/ and ./data/O_K/") diff --git a/examples/xas/train/input.json b/examples/xas/train/input.json new file mode 100644 index 0000000000..f58417b478 --- /dev/null +++ b/examples/xas/train/input.json @@ -0,0 +1,77 @@ +{ + "_comment": "XAS spectrum fitting example — Fe-O system, Fe K-edge + O K-edge", + + "model": { + "type_map": [ + "Fe", + "O" + ], + "descriptor": { + "type": "se_e2_a", + "rcut": 6.0, + "rcut_smth": 0.5, + "sel": [ + 40, + 40 + ], + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1 + }, + "fitting_net": { + "_comment": "property fitting with task_dim=102: [E_min, E_max, I_0, ..., I_99]", + "type": "property", + "property_name": "xas", + "task_dim": 102, + "_comment_intensive": "intensive=true: per-atom outputs are averaged (not summed)", + "intensive": true, + "_comment_fparam": "fparam encodes edge type: 1-hot vector, e.g. [1,0,0]=K, [0,1,0]=L1, [0,0,1]=L2; for one-hot fparam consider setting normalize_fparam to false (it defaults to true)", + "numb_fparam": 3, + "neuron": [ + 128, + 128, + 128 + ], + "resnet_dt": true, + "seed": 1 + } + }, + "loss": { + "_comment": "xas loss: reads sel_type.npy to select which element to reduce over", + "type": "xas", + "loss_func": "smooth_mae", + "metric": [ + "mae", + "rmse" + ] + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 1e-8, + "decay_rate": 0.95 + }, + "training": { + "training_data": { + "_comment": "one system per (element, edge) pair", + "systems": [ + "./data/Fe_K/", + "./data/O_K/" + ], + "batch_size": "auto" + }, + "numb_steps": 200000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1000, + "save_freq": 10000, + "save_ckpt": "model.ckpt", + "stat_file": "stat_files" + } +}