From c4ed1f906d2c758aa77f5e2e518b613dcd8d429c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 30 Mar 2026 17:45:05 +0800 Subject: [PATCH 1/5] feat(pt/dpmodel): add sequential_update for dpa3 --- deepmd/dpmodel/descriptor/dpa3.py | 18 + deepmd/dpmodel/descriptor/repflows.py | 451 ++++++++++++++++++ deepmd/pt/model/descriptor/dpa3.py | 1 + deepmd/pt/model/descriptor/repflow_layer.py | 411 ++++++++++++++++ deepmd/pt/model/descriptor/repflows.py | 3 + deepmd/utils/argcheck.py | 14 + .../tests/consistent/descriptor/test_dpa3.py | 10 + source/tests/pt/model/test_dpa3.py | 9 + 8 files changed, 917 insertions(+) diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 5f5aea50e5..88f56213ee 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -170,6 +170,12 @@ class RepFlowArgs: In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor` or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, accommodating larger selection numbers. + sequential_update : bool, optional + Whether to use sequential update mode within each repflow layer. + When True, updates are applied sequentially: edge self → angle self (using updated edge) + → edge angle (using updated angle) → node (using final edge), + instead of the default parallel mode where all updates use original embeddings. + Currently only supports ``update_style='res_residual'``. """ def __init__( @@ -201,6 +207,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, ) -> None: self.n_dim = n_dim self.e_dim = e_dim @@ -231,6 +238,15 @@ def __init__( self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update + if self.sequential_update: + if self.update_style != "res_residual": + raise ValueError( + "sequential_update only supports update_style='res_residual', " + f"got '{self.update_style}'!" + ) + if not self.update_angle: + raise ValueError("sequential_update requires update_angle=True!") def __getitem__(self, key: str) -> Any: if hasattr(self, key): @@ -266,6 +282,7 @@ def serialize(self) -> dict: "use_exp_switch": self.use_exp_switch, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, } @classmethod @@ -404,6 +421,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, + sequential_update=self.repflow_args.sequential_update, use_loc_mapping=use_loc_mapping, exclude_types=exclude_types, env_protection=env_protection, diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 30637dc75a..5788840279 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -230,6 +230,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, use_loc_mapping: bool = True, seed: int | list[int] | None = None, trainable: bool = True, @@ -268,6 +269,7 @@ def __init__( self.use_dynamic_sel = use_dynamic_sel self.use_loc_mapping = use_loc_mapping self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update if self.use_dynamic_sel and not self.smooth_edge_update: raise NotImplementedError( "smooth_edge_update must be True when use_dynamic_sel is True!" @@ -339,6 +341,7 @@ def __init__( optim_update=self.optim_update, use_dynamic_sel=self.use_dynamic_sel, sel_reduce_factor=self.sel_reduce_factor, + sequential_update=self.sequential_update, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), trainable=trainable, @@ -757,6 +760,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "use_loc_mapping": self.use_loc_mapping, # variables "edge_embd": self.edge_embd.serialize(), @@ -905,6 +909,7 @@ def __init__( optim_update: bool = True, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, smooth_edge_update: bool = False, activation_function: str = "silu", update_style: str = "res_residual", @@ -954,8 +959,15 @@ def __init__( self.smooth_edge_update = smooth_edge_update self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update self.dynamic_e_sel = self.nnei / self.sel_reduce_factor self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor + if self.sequential_update and self.update_style != "res_residual": + raise NotImplementedError( + "sequential_update only supports update_style='res_residual'!" + ) + if self.sequential_update and not self.update_angle: + raise NotImplementedError("sequential_update requires update_angle=True!") assert update_residual_init in [ "norm", @@ -1342,6 +1354,418 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update + def _call_sequential( + self, + xp: object, + node_ebd: Array, + node_ebd_ext: Array, + edge_ebd: Array, + h2: Array, + angle_ebd: Array, + nlist: Array, + nlist_mask: Array, + sw: Array, + a_nlist_mask: Array, + a_sw: Array, + nei_node_ebd: Array, + n2e_index: Array, + n_ext2e_index: Array, + n2a_index: Array, + eij2a_index: Array, + eik2a_index: Array, + nb: int, + nloc: int, + nnei: int, + nall: int, + n_edge: int, + ) -> tuple[Array, Array, Array]: + """Sequential update path: edge_self → angle_self → edge_angle → node. + + Only supports update_style='res_residual'. + """ + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + + # ==================================================================== + # Phase 1: Edge self update (uses original node_ebd, edge_ebd) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = xp.concat( + [ + xp.tile( + xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), + (1, 1, self.nnei, 1), + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + else: + edge_info = xp.concat( + [ + xp.take( + xp.reshape(node_ebd, (-1, self.n_dim)), + n2e_index, + axis=0, + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + edge_self_update = self.act(self.edge_self_linear(edge_info)) + else: + edge_self_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) + ) + + # Apply edge self residual + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # ==================================================================== + # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) + # ==================================================================== + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd_s1 + + if not self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = xp.where( + xp.expand_dims(a_nlist_mask, axis=-1), + edge_ebd_for_angle, + xp.zeros_like(edge_ebd_for_angle), + ) + + if not self.optim_update: + node_for_angle_info = ( + xp.tile( + xp.reshape( + node_ebd_for_angle, (nb, nloc, 1, 1, self.n_a_compress_dim) + ), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else xp.take( + xp.reshape(node_ebd_for_angle, (-1, self.n_a_compress_dim)), + n2a_index, + axis=0, + ) + ) + edge_for_angle_k = ( + xp.tile( + xp.reshape( + edge_ebd_for_angle, + (nb, nloc, 1, self.a_sel, self.e_a_compress_dim), + ), + (1, 1, self.a_sel, 1, 1), + ) + if not self.use_dynamic_sel + else xp.take( + edge_ebd_for_angle, + eik2a_index, + axis=0, + ) + ) + edge_for_angle_j = ( + xp.tile( + xp.reshape( + edge_ebd_for_angle, + (nb, nloc, self.a_sel, 1, self.e_a_compress_dim), + ), + (1, 1, 1, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else xp.take( + edge_ebd_for_angle, + eij2a_index, + axis=0, + ) + ) + edge_for_angle_info = xp.concat( + [edge_for_angle_k, edge_for_angle_j], axis=-1 + ) + angle_info = xp.concat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], axis=-1 + ) + angle_self_update = self.act(self.angle_self_linear(angle_info)) + else: + angle_self_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) + ) + + # Apply angle self residual + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # ==================================================================== + # Phase 3: Edge angle update (uses updated angle a_updated, updated edge_ebd_s1) + # ==================================================================== + if not self.optim_update: + angle_info_s2 = xp.concat( + [a_updated, node_for_angle_info, edge_for_angle_info], axis=-1 + ) + edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) + else: + edge_angle_update = self.act( + self.optim_angle_update( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) + ) + + # Reduce edge angle update over angle dimension + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw[:, :, :, xp.newaxis, xp.newaxis] + * a_sw[:, :, xp.newaxis, :, xp.newaxis] + * edge_angle_update + ) + reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / ( + self.a_sel**0.5 + ) + padding_edge_angle_update = xp.concat( + [ + reduced_edge_angle_update, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel, self.e_dim), + dtype=edge_ebd.dtype, + device=array_api_compat.device(edge_ebd), + ), + ], + axis=2, + ) + else: + weighted_edge_angle_update = edge_angle_update * xp.expand_dims( + a_sw, axis=-1 + ) + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" + ) + full_mask = xp.concat( + [ + a_nlist_mask, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel), + dtype=a_nlist_mask.dtype, + device=array_api_compat.device(a_nlist_mask), + ), + ], + axis=-1, + ) + padding_edge_angle_update = xp.where( + xp.expand_dims(full_mask, axis=-1), + padding_edge_angle_update, + edge_ebd, + ) + + edge_angle_processed = self.act( + self.edge_angle_linear2(padding_edge_angle_update) + ) + + # Apply edge angle residual on top of edge_ebd_s1 + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # ==================================================================== + # Phase 4: Node edge message (uses e_updated) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info_updated = xp.concat( + [ + xp.tile( + xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), + (1, 1, self.nnei, 1), + ), + nei_node_ebd, + e_updated, + ], + axis=-1, + ) + else: + edge_info_updated = xp.concat( + [ + xp.take( + xp.reshape(node_ebd, (-1, self.n_dim)), + n2e_index, + axis=0, + ), + nei_node_ebd, + e_updated, + ], + axis=-1, + ) + node_edge_update = self.act( + self.node_edge_linear(edge_info_updated) + ) * xp.expand_dims(sw, axis=-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + e_updated, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + e_updated, + n2e_index, + n_ext2e_index, + "node", + ) + ) * xp.expand_dims(sw, axis=-1) + + node_edge_update = ( + (xp.sum(node_edge_update, axis=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + xp.reshape( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ), + (nb, nloc, node_edge_update.shape[-1]), + ) + / self.dynamic_e_sel + ) + ) + + # ==================================================================== + # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) + # ==================================================================== + n_update_list: list[Array] = [node_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym using e_updated + node_sym_list: list[Array] = [] + node_sym_list.append( + symmetrization_op( + e_updated, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + e_updated, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym_list.append( + symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym = self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) + n_update_list.append(node_sym) + + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = xp.reshape( + node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + else: + n_update_list.append(node_edge_update) + + n_updated = self.list_update(n_update_list, "node") + + return n_updated, e_updated, a_updated + def call( self, node_ebd_ext: Array, # nf x nall x n_dim @@ -1446,6 +1870,32 @@ def call( ) ) + if self.sequential_update and self.update_angle: + return self._call_sequential( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist_mask, + a_sw, + nei_node_ebd, + n2e_index, + n_ext2e_index, + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + nnei, + nall, + n_edge, + ) + n_update_list: list[Array] = [node_ebd] e_update_list: list[Array] = [edge_ebd] a_update_list: list[Array] = [angle_ebd] @@ -1907,6 +2357,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "node_self_mlp": self.node_self_mlp.serialize(), "node_sym_linear": self.node_sym_linear.serialize(), "node_edge_linear": self.node_edge_linear.serialize(), diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 0c6982afe5..a5f79280fa 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -167,6 +167,7 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any: use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, + sequential_update=self.repflow_args.sequential_update, use_loc_mapping=use_loc_mapping, exclude_types=exclude_types, env_protection=env_protection, diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 338f48b060..57c9368839 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -53,6 +53,7 @@ def __init__( optim_update: bool = True, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, smooth_edge_update: bool = False, activation_function: str = "silu", update_style: str = "res_residual", @@ -102,8 +103,15 @@ def __init__( self.smooth_edge_update = smooth_edge_update self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update self.dynamic_e_sel = self.nnei / self.sel_reduce_factor self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor + if self.sequential_update and self.update_style != "res_residual": + raise NotImplementedError( + "sequential_update only supports update_style='res_residual'!" + ) + if self.sequential_update and not self.update_angle: + raise NotImplementedError("sequential_update requires update_angle=True!") assert update_residual_init in [ "norm", @@ -694,6 +702,383 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update + def _forward_sequential( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + h2: torch.Tensor, + angle_ebd: torch.Tensor, + nlist: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + a_nlist_mask: torch.Tensor, + a_sw: torch.Tensor, + nei_node_ebd: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, + n2a_index: torch.Tensor, + eij2a_index: torch.Tensor, + eik2a_index: torch.Tensor, + nb: int, + nloc: int, + nnei: int, + nall: int, + n_edge: int | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sequential update path: edge_self → angle_self → edge_angle → node. + + Only supports update_style='res_residual'. + """ + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + assert self.angle_self_linear is not None + + # ==================================================================== + # Phase 1: Edge self update (uses original node_ebd, edge_ebd) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + else: + edge_info = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + edge_self_update = self.act(self.edge_self_linear(edge_info)) + else: + edge_self_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) + ) + + # Apply edge self residual: edge_ebd_s1 = edge_ebd + e_residual[0] * edge_self_update + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # ==================================================================== + # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) + # ==================================================================== + # Prepare edge for angle from edge_ebd_s1 (updated edge) + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd_s1 + + if not self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = torch.where( + a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 + ) + + # Initialize for JIT: these are only used in non-optim_update path + node_for_angle_info = angle_ebd # placeholder, overwritten below + edge_for_angle_info = angle_ebd # placeholder, overwritten below + + if not self.optim_update: + node_for_angle_info = ( + torch.tile( + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), + 0, + n2a_index, + ) + ) + edge_for_angle_k = ( + torch.tile(edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) + ) + edge_for_angle_j = ( + torch.tile(edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eij2a_index) + ) + edge_for_angle_info = torch.cat( + [edge_for_angle_k, edge_for_angle_j], dim=-1 + ) + angle_info = torch.cat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], dim=-1 + ) + angle_self_update = self.act(self.angle_self_linear(angle_info)) + else: + angle_self_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) + ) + + # Apply angle self residual: angle_ebd_s2 = angle_ebd + a_residual[0] * angle_self_update + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # ==================================================================== + # Phase 3: Edge angle update (uses updated angle_ebd_s2, updated edge_ebd_s1) + # ==================================================================== + if not self.optim_update: + # Rebuild angle_info with updated angle (a_updated) + angle_info_s2 = torch.cat( + [a_updated, node_for_angle_info, edge_for_angle_info], dim=-1 + ) + edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) + else: + edge_angle_update = self.act( + self.optim_angle_update( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) + ) + + # Reduce edge angle update over angle dimension + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw.unsqueeze(-1).unsqueeze(-1) + * a_sw.unsqueeze(-2).unsqueeze(-1) + * edge_angle_update + ) + reduced_edge_angle_update = torch.sum( + weighted_edge_angle_update, dim=-2 + ) / (self.a_sel**0.5) + padding_edge_angle_update = torch.concat( + [ + reduced_edge_angle_update, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=edge_ebd.dtype, + device=edge_ebd.device, + ), + ], + dim=2, + ) + else: + assert n_edge is not None + weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" + ) + full_mask = torch.concat( + [ + a_nlist_mask, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel], + dtype=a_nlist_mask.dtype, + device=a_nlist_mask.device, + ), + ], + dim=-1, + ) + padding_edge_angle_update = torch.where( + full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd + ) + + edge_angle_processed = self.act( + self.edge_angle_linear2(padding_edge_angle_update) + ) + + # Apply edge angle residual on top of edge_ebd_s1 (no recomputation) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # ==================================================================== + # Phase 4: Node edge message (uses e_updated) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info_updated = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + e_updated, + ], + dim=-1, + ) + else: + edge_info_updated = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + e_updated, + ], + dim=-1, + ) + node_edge_update = self.act( + self.node_edge_linear(edge_info_updated) + ) * sw.unsqueeze(-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + e_updated, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + e_updated, + n2e_index, + n_ext2e_index, + "node", + ) + ) * sw.unsqueeze(-1) + + node_edge_update = ( + (torch.sum(node_edge_update, dim=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ).reshape(nb, nloc, node_edge_update.shape[-1]) + / self.dynamic_e_sel + ) + ) + + # ==================================================================== + # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) + # ==================================================================== + n_update_list: list[torch.Tensor] = [node_ebd] + + # node self mlp (uses original node_ebd) + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym using e_updated + node_sym_list: list[torch.Tensor] = [] + node_sym_list.append( + self.symmetrization_op( + e_updated, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + e_updated, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym_list.append( + self.symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) + n_update_list.append(node_sym) + + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) + else: + n_update_list.append(node_edge_update) + + n_updated = self.list_update(n_update_list, "node") + + return n_updated, e_updated, a_updated + def forward( self, node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parallel_mode @@ -783,6 +1168,31 @@ def forward( ) ) + if self.sequential_update and self.update_angle: + return self._forward_sequential( + node_ebd, + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist_mask, + a_sw, + nei_node_ebd, + n2e_index, + n_ext2e_index, + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + nnei, + nall, + n_edge, + ) + n_update_list: list[torch.Tensor] = [node_ebd] e_update_list: list[torch.Tensor] = [edge_ebd] a_update_list: list[torch.Tensor] = [angle_ebd] @@ -1220,6 +1630,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "node_self_mlp": self.node_self_mlp.serialize(), "node_sym_linear": self.node_sym_linear.serialize(), "node_edge_linear": self.node_edge_linear.serialize(), diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 433897860f..7c16ab3c7a 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -219,6 +219,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, use_loc_mapping: bool = True, optim_update: bool = True, seed: int | list[int] | None = None, @@ -258,6 +259,7 @@ def __init__( self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update if self.use_dynamic_sel and not self.smooth_edge_update: raise NotImplementedError( "smooth_edge_update must be True when use_dynamic_sel is True!" @@ -329,6 +331,7 @@ def __init__( optim_update=self.optim_update, use_dynamic_sel=self.use_dynamic_sel, sel_reduce_factor=self.sel_reduce_factor, + sequential_update=self.sequential_update, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), trainable=trainable, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index b12bc7ef6f..70a7985702 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1550,6 +1550,13 @@ def dpa3_repflow_args() -> list[Argument]: "or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, " "accommodating larger selection numbers." ) + doc_sequential_update = ( + "Whether to use sequential update mode within each repflow layer. " + "When True, updates are applied sequentially: edge self → angle self (using updated edge) " + "→ edge angle (using updated angle) → node (using final edge), " + "instead of the default parallel mode where all updates use original embeddings. " + "Currently only supports update_style='res_residual'." + ) return [ # repflow args @@ -1680,6 +1687,13 @@ def dpa3_repflow_args() -> list[Argument]: default=10.0, doc=doc_sel_reduce_factor, ), + Argument( + "sequential_update", + bool, + optional=True, + default=False, + doc=doc_sequential_update, + ), ] diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index bca0759f5c..b980c584a1 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -79,6 +79,7 @@ (1,), # n_multi_edge_message ("float64",), # precision (False, True), # add_chg_spin_ebd + (False, True), # sequential_update ) class TestDPA3(CommonTest, DescriptorTest, unittest.TestCase): @property @@ -99,6 +100,7 @@ def data(self) -> dict: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return { "ntypes": self.ntypes, @@ -130,6 +132,7 @@ def data(self) -> dict: "update_style": "res_residual", "update_residual": 0.1, "update_residual_init": update_residual_init, + "sequential_update": sequential_update, } ), # kwargs for descriptor @@ -160,6 +163,7 @@ def skip_pt(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return CommonTest.skip_pt @@ -181,6 +185,7 @@ def skip_pd(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return True if add_chg_spin_ebd else CommonTest.skip_pd @@ -202,6 +207,7 @@ def skip_dp(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return CommonTest.skip_dp @@ -223,6 +229,7 @@ def skip_tf(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return True @@ -288,6 +295,7 @@ def setUp(self) -> None: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param # fparam for charge=5, spin=1 when add_chg_spin_ebd is True self.fparam = ( @@ -394,6 +402,7 @@ def rtol(self) -> float: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param if precision == "float64": return 1e-10 @@ -421,6 +430,7 @@ def atol(self) -> float: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param if precision == "float64": return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786 diff --git a/source/tests/pt/model/test_dpa3.py b/source/tests/pt/model/test_dpa3.py index 12b0be4532..d66eab9dea 100644 --- a/source/tests/pt/model/test_dpa3.py +++ b/source/tests/pt/model/test_dpa3.py @@ -56,6 +56,7 @@ def test_consistency( prec, ect, add_chg_spin, + seq_upd, ) in itertools.product( [True, False], # update_angle ["res_residual"], # update_style @@ -67,7 +68,11 @@ def test_consistency( ["float64"], # precision [False], # use_econf_tebd [False, True], # add_chg_spin_ebd + [False, True], # sequential_update ): + # sequential_update only works with update_angle=True + if seq_upd and not ua: + continue dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) if prec == "float64": @@ -93,6 +98,7 @@ def test_consistency( update_style=rus, update_residual_init=ruri, smooth_edge_update=True, + sequential_update=seq_upd, ) # dpa3 new impl @@ -177,6 +183,7 @@ def test_jit( nme, prec, ect, + seq_upd, ) in itertools.product( [True], # update_angle ["res_residual"], # update_style @@ -187,6 +194,7 @@ def test_jit( [1, 2], # n_multi_edge_message ["float64"], # precision [False], # use_econf_tebd + [False, True], # sequential_update ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -211,6 +219,7 @@ def test_jit( update_style=rus, update_residual_init=ruri, smooth_edge_update=True, + sequential_update=seq_upd, ) # dpa3 new impl From 8523d88e4b9cc6c202dbefeeea490df6bf64421f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 8 May 2026 20:59:35 +0800 Subject: [PATCH 2/5] fix comments --- deepmd/dpmodel/descriptor/dpa3.py | 3 +- deepmd/dpmodel/descriptor/repflows.py | 772 +++++++----------- deepmd/pt/model/descriptor/repflow_layer.py | 710 ++++++---------- deepmd/utils/argcheck.py | 2 +- .../tests/consistent/descriptor/test_dpa3.py | 8 +- 5 files changed, 563 insertions(+), 932 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 88f56213ee..52b2ca9c2c 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -175,7 +175,8 @@ class RepFlowArgs: When True, updates are applied sequentially: edge self → angle self (using updated edge) → edge angle (using updated angle) → node (using final edge), instead of the default parallel mode where all updates use original embeddings. - Currently only supports ``update_style='res_residual'``. + Currently only supports ``update_style='res_residual'`` and requires ``update_angle=True``; + otherwise, a ``ValueError`` will be raised during initialization. """ def __init__( diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 5788840279..1af1d15bd9 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -1354,42 +1354,20 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update - def _call_sequential( + def _compute_edge_self_update( self, xp: object, node_ebd: Array, node_ebd_ext: Array, edge_ebd: Array, - h2: Array, - angle_ebd: Array, - nlist: Array, - nlist_mask: Array, - sw: Array, - a_nlist_mask: Array, - a_sw: Array, nei_node_ebd: Array, + nlist: Array, n2e_index: Array, n_ext2e_index: Array, - n2a_index: Array, - eij2a_index: Array, - eik2a_index: Array, nb: int, nloc: int, - nnei: int, - nall: int, - n_edge: int, - ) -> tuple[Array, Array, Array]: - """Sequential update path: edge_self → angle_self → edge_angle → node. - - Only supports update_style='res_residual'. - """ - assert self.angle_self_linear is not None - assert self.edge_angle_linear1 is not None - assert self.edge_angle_linear2 is not None - - # ==================================================================== - # Phase 1: Edge self update (uses original node_ebd, edge_ebd) - # ==================================================================== + ) -> Array: + """Compute edge self update.""" if not self.optim_update: if not self.use_dynamic_sel: edge_info = xp.concat( @@ -1416,9 +1394,9 @@ def _call_sequential( ], axis=-1, ) - edge_self_update = self.act(self.edge_self_linear(edge_info)) + return self.act(self.edge_self_linear(edge_info)) else: - edge_self_update = self.act( + return self.act( self.optim_edge_update( node_ebd, node_ebd_ext, @@ -1437,24 +1415,26 @@ def _call_sequential( ) ) - # Apply edge self residual - edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update - - # ==================================================================== - # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) - # ==================================================================== + def _prepare_angle_embeddings( + self, + xp: object, + node_ebd: Array, + edge_ebd: Array, + a_nlist_mask: Array, + ) -> tuple[Array, Array]: + """Prepare compressed node/edge embeddings for angle computation.""" if self.a_compress_rate != 0: if not self.a_compress_use_split: assert self.a_compress_n_linear is not None assert self.a_compress_e_linear is not None node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) else: node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] else: node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd_s1 + edge_ebd_for_angle = edge_ebd if not self.use_dynamic_sel: edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] @@ -1463,12 +1443,34 @@ def _call_sequential( edge_ebd_for_angle, xp.zeros_like(edge_ebd_for_angle), ) + return node_ebd_for_angle, edge_ebd_for_angle + + def _compute_angle_update( + self, + xp: object, + angle_ebd: Array, + node_ebd_for_angle: Array, + edge_ebd_for_angle: Array, + feat: str, + n2a_index: Array, + eij2a_index: Array, + eik2a_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute angle-based update (for edge_angle or angle_self). + Parameters + ---------- + feat : str + "edge" for edge_angle_linear1, "angle" for angle_self_linear. + """ if not self.optim_update: node_for_angle_info = ( xp.tile( xp.reshape( - node_ebd_for_angle, (nb, nloc, 1, 1, self.n_a_compress_dim) + node_ebd_for_angle, + (nb, nloc, 1, 1, self.n_a_compress_dim), ), (1, 1, self.a_sel, self.a_sel, 1), ) @@ -1515,59 +1517,72 @@ def _call_sequential( angle_info = xp.concat( [angle_ebd, node_for_angle_info, edge_for_angle_info], axis=-1 ) - angle_self_update = self.act(self.angle_self_linear(angle_info)) + if feat == "edge": + assert self.edge_angle_linear1 is not None + return self.act(self.edge_angle_linear1(angle_info)) + else: + assert self.angle_self_linear is not None + return self.act(self.angle_self_linear(angle_info)) else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", + if feat == "edge": + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", + else: + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) ) - ) - # Apply angle self residual - a_updated = angle_ebd + self.a_residual[0] * angle_self_update - - # ==================================================================== - # Phase 3: Edge angle update (uses updated angle a_updated, updated edge_ebd_s1) - # ==================================================================== - if not self.optim_update: - angle_info_s2 = xp.concat( - [a_updated, node_for_angle_info, edge_for_angle_info], axis=-1 - ) - edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) - else: - edge_angle_update = self.act( - self.optim_angle_update( - a_updated, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - a_updated, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) + def _compute_edge_angle_reduction( + self, + xp: object, + edge_angle_update: Array, + edge_ebd_fallback: Array, + a_sw: Array, + a_nlist_mask: Array, + nb: int, + nloc: int, + n_edge: int, + eij2a_index: Array, + ) -> Array: + """Reduce edge angle update over angle dimension, pad, and apply linear2. - # Reduce edge angle update over angle dimension + Parameters + ---------- + edge_ebd_fallback : Array + Edge embedding used for non-smooth padding fallback. + """ + assert self.edge_angle_linear2 is not None if not self.use_dynamic_sel: weighted_edge_angle_update = ( a_sw[:, :, :, xp.newaxis, xp.newaxis] @@ -1582,8 +1597,8 @@ def _call_sequential( reduced_edge_angle_update, xp.zeros( (nb, nloc, self.nnei - self.a_sel, self.e_dim), - dtype=edge_ebd.dtype, - device=array_api_compat.device(edge_ebd), + dtype=edge_ebd_fallback.dtype, + device=array_api_compat.device(edge_ebd_fallback), ), ], axis=2, @@ -1618,34 +1633,41 @@ def _call_sequential( padding_edge_angle_update = xp.where( xp.expand_dims(full_mask, axis=-1), padding_edge_angle_update, - edge_ebd, + edge_ebd_fallback, ) - edge_angle_processed = self.act( - self.edge_angle_linear2(padding_edge_angle_update) - ) + return self.act(self.edge_angle_linear2(padding_edge_angle_update)) - # Apply edge angle residual on top of edge_ebd_s1 - e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed - - # ==================================================================== - # Phase 4: Node edge message (uses e_updated) - # ==================================================================== + def _compute_node_edge_message( + self, + xp: object, + node_ebd: Array, + node_ebd_ext: Array, + edge_ebd: Array, + nei_node_ebd: Array, + sw: Array, + nlist: Array, + n2e_index: Array, + n_ext2e_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute node edge message and reduce over neighbor dimension.""" if not self.optim_update: if not self.use_dynamic_sel: - edge_info_updated = xp.concat( + edge_info = xp.concat( [ xp.tile( xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), (1, 1, self.nnei, 1), ), nei_node_ebd, - e_updated, + edge_ebd, ], axis=-1, ) else: - edge_info_updated = xp.concat( + edge_info = xp.concat( [ xp.take( xp.reshape(node_ebd, (-1, self.n_dim)), @@ -1653,19 +1675,19 @@ def _call_sequential( axis=0, ), nei_node_ebd, - e_updated, + edge_ebd, ], axis=-1, ) node_edge_update = self.act( - self.node_edge_linear(edge_info_updated) + self.node_edge_linear(edge_info) ) * xp.expand_dims(sw, axis=-1) else: node_edge_update = self.act( self.optim_edge_update( node_ebd, node_ebd_ext, - e_updated, + edge_ebd, nlist, "node", ) @@ -1673,7 +1695,7 @@ def _call_sequential( else self.optim_edge_update_dynamic( node_ebd, node_ebd_ext, - e_updated, + edge_ebd, n2e_index, n_ext2e_index, "node", @@ -1696,21 +1718,25 @@ def _call_sequential( / self.dynamic_e_sel ) ) + return node_edge_update - # ==================================================================== - # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) - # ==================================================================== - n_update_list: list[Array] = [node_ebd] - - # node self mlp - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) - - # node sym using e_updated + def _compute_node_sym( + self, + xp: object, + edge_ebd: Array, + nei_node_ebd: Array, + h2: Array, + nlist_mask: Array, + sw: Array, + n2e_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute node symmetrization update (grrg + drrd).""" node_sym_list: list[Array] = [] node_sym_list.append( symmetrization_op( - e_updated, + edge_ebd, h2, nlist_mask, sw, @@ -1718,7 +1744,7 @@ def _call_sequential( ) if not self.use_dynamic_sel else symmetrization_op_dynamic( - e_updated, + edge_ebd, h2, sw, owner=n2e_index, @@ -1750,21 +1776,7 @@ def _call_sequential( axis_neuron=self.axis_neuron, ) ) - node_sym = self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) - n_update_list.append(node_sym) - - if self.n_multi_edge_message > 1: - node_edge_update_mul_head = xp.reshape( - node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) - ) - for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) - else: - n_update_list.append(node_edge_update) - - n_updated = self.list_update(n_update_list, "node") - - return n_updated, e_updated, a_updated + return self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) def call( self, @@ -1870,160 +1882,146 @@ def call( ) ) + # Edge self update (always from original embeddings) + edge_self_update = self._compute_edge_self_update( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + if self.sequential_update and self.update_angle: - return self._call_sequential( + # === Sequential update path === + # Phase 1: Apply edge self residual + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # Phase 2: Angle self (uses updated edge_ebd_s1) + node_for_a, edge_for_a = self._prepare_angle_embeddings( + xp, node_ebd, edge_ebd_s1, a_nlist_mask + ) + angle_self_update = self._compute_angle_update( xp, - node_ebd, - node_ebd_ext, - edge_ebd, - h2, angle_ebd, - nlist, - nlist_mask, - sw, - a_nlist_mask, - a_sw, - nei_node_ebd, - n2e_index, - n_ext2e_index, + node_for_a, + edge_for_a, + "angle", n2a_index, eij2a_index, eik2a_index, nb, nloc, - nnei, - nall, - n_edge, ) + a_updated = angle_ebd + self.a_residual[0] * angle_self_update - n_update_list: list[Array] = [node_ebd] - e_update_list: list[Array] = [edge_ebd] - a_update_list: list[Array] = [angle_ebd] - - # node self mlp - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) - - # node sym (grrg + drrd) - node_sym_list: list[Array] = [] - node_sym_list.append( - symmetrization_op( - edge_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, + # Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1) + edge_angle_update = self._compute_angle_update( + xp, + a_updated, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, ) - if not self.use_dynamic_sel - else symmetrization_op_dynamic( - edge_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, + edge_angle_processed = self._compute_edge_angle_reduction( + xp, + edge_angle_update, + edge_ebd_s1, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, ) - ) - node_sym_list.append( - symmetrization_op( + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # Phase 4+5: Node updates (uses e_updated) + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + node_sym = self._compute_node_sym( + xp, + e_updated, nei_node_ebd, h2, nlist_mask, sw, - self.axis_neuron, + n2e_index, + nb, + nloc, ) - if not self.use_dynamic_sel - else symmetrization_op_dynamic( + node_edge_update = self._compute_node_edge_message( + xp, + node_ebd, + node_ebd_ext, + e_updated, nei_node_ebd, - h2, sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, ) - ) - node_sym = self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) - n_update_list.append(node_sym) - if not self.optim_update: - if not self.use_dynamic_sel: - # nb x nloc x nnei x (n_dim * 2 + e_dim) - edge_info = xp.concat( - [ - xp.tile( - xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), - (1, 1, self.nnei, 1), - ), - nei_node_ebd, - edge_ebd, - ], - axis=-1, + n_update_list: list[Array] = [node_ebd, node_self_mlp, node_sym] + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = xp.reshape( + node_edge_update, + (nb, nloc, self.n_multi_edge_message, self.n_dim), ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) else: - # n_edge x (n_dim * 2 + e_dim) - edge_info = xp.concat( - [ - xp.take( - xp.reshape(node_ebd, (-1, self.n_dim)), - n2e_index, - axis=0, - ), - nei_node_ebd, - edge_ebd, - ], - axis=-1, - ) - else: - edge_info = None + n_update_list.append(node_edge_update) + n_updated = self.list_update(n_update_list, "node") - # node edge message - # nb x nloc x nnei x (h * n_dim) - if not self.optim_update: - assert edge_info is not None - node_edge_update = self.act( - self.node_edge_linear(edge_info) - ) * xp.expand_dims(sw, axis=-1) - else: - node_edge_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "node", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "node", - ) - ) * xp.expand_dims(sw, axis=-1) + return n_updated, e_updated, a_updated - node_edge_update = ( - (xp.sum(node_edge_update, axis=-2) / self.nnei) - if not self.use_dynamic_sel - else ( - xp.reshape( - aggregate( - node_edge_update, - n2e_index, - average=False, - num_owner=nb * nloc, - ), - (nb, nloc, node_edge_update.shape[-1]), - ) - / self.dynamic_e_sel - ) + # === Parallel update path === + n_update_list: list[Array] = [node_ebd] + e_update_list: list[Array] = [edge_ebd] + a_update_list: list[Array] = [angle_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym (grrg + drrd) + node_sym = self._compute_node_sym( + xp, + edge_ebd, + nei_node_ebd, + h2, + nlist_mask, + sw, + n2e_index, + nb, + nloc, + ) + n_update_list.append(node_sym) + + # node edge message + node_edge_update = self._compute_node_edge_message( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, ) if self.n_multi_edge_message > 1: # nb x nloc x h x n_dim @@ -2038,234 +2036,58 @@ def call( n_updated = self.list_update(n_update_list, "node") # edge self message - if not self.optim_update: - assert edge_info is not None - edge_self_update = self.act(self.edge_self_linear(edge_info)) - else: - edge_self_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "edge", - ) - ) e_update_list.append(edge_self_update) if self.update_angle: assert self.angle_self_linear is not None assert self.edge_angle_linear1 is not None assert self.edge_angle_linear2 is not None - # get angle info - if self.a_compress_rate != 0: - if not self.a_compress_use_split: - assert self.a_compress_n_linear is not None - assert self.a_compress_e_linear is not None - node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) - else: - # use the first a_compress_dim dim for node and edge - node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] - else: - node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd - - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = xp.where( - xp.expand_dims(a_nlist_mask, axis=-1), - edge_ebd_for_angle, - xp.zeros_like(edge_ebd_for_angle), - ) - if not self.optim_update: - # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim - node_for_angle_info = ( - xp.tile( - xp.reshape( - node_ebd_for_angle, (nb, nloc, 1, 1, self.n_a_compress_dim) - ), - (1, 1, self.a_sel, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else xp.take( - xp.reshape(node_ebd_for_angle, (-1, self.n_a_compress_dim)), - n2a_index, - axis=0, - ) - ) - # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim - edge_for_angle_k = ( - xp.tile( - xp.reshape( - edge_ebd_for_angle, - (nb, nloc, 1, self.a_sel, self.e_a_compress_dim), - ), - (1, 1, self.a_sel, 1, 1), - ) - if not self.use_dynamic_sel - else xp.take( - edge_ebd_for_angle, - eik2a_index, - axis=0, - ) - ) - # nb x nloc x a_nnei x (a_nnei) x e_dim - edge_for_angle_j = ( - xp.tile( - xp.reshape( - edge_ebd_for_angle, - (nb, nloc, self.a_sel, 1, self.e_a_compress_dim), - ), - (1, 1, 1, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else xp.take( - edge_ebd_for_angle, - eij2a_index, - axis=0, - ) - ) - # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) - edge_for_angle_info = xp.concat( - [edge_for_angle_k, edge_for_angle_j], axis=-1 - ) - angle_info_list = [angle_ebd] - angle_info_list.append(node_for_angle_info) - angle_info_list.append(edge_for_angle_info) - # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) - # [OR] - # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) - angle_info = xp.concat(angle_info_list, axis=-1) - else: - angle_info = None + node_for_a, edge_for_a = self._prepare_angle_embeddings( + xp, node_ebd, edge_ebd, a_nlist_mask + ) # edge angle message - # nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim - if not self.optim_update: - assert angle_info is not None - edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) - else: - edge_angle_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x a_nnei x e_dim - weighted_edge_angle_update = ( - a_sw[:, :, :, xp.newaxis, xp.newaxis] - * a_sw[:, :, xp.newaxis, :, xp.newaxis] - * edge_angle_update - ) - # nb x nloc x a_nnei x e_dim - reduced_edge_angle_update = xp.sum( - weighted_edge_angle_update, axis=-2 - ) / (self.a_sel**0.5) - # nb x nloc x nnei x e_dim - padding_edge_angle_update = xp.concat( - [ - reduced_edge_angle_update, - xp.zeros( - (nb, nloc, self.nnei - self.a_sel, self.e_dim), - dtype=edge_ebd.dtype, - device=array_api_compat.device(edge_ebd), - ), - ], - axis=2, - ) - else: - # n_angle x e_dim - weighted_edge_angle_update = edge_angle_update * xp.expand_dims( - a_sw, axis=-1 - ) - # n_edge x e_dim - padding_edge_angle_update = aggregate( - weighted_edge_angle_update, - eij2a_index, - average=False, - num_owner=n_edge, - ) / (self.dynamic_a_sel**0.5) - - if not self.smooth_edge_update: - # will be deprecated in the future - # not support dynamic index, will pass anyway - if self.use_dynamic_sel: - raise NotImplementedError( - "smooth_edge_update must be True when use_dynamic_sel is True!" - ) - full_mask = xp.concat( - [ - a_nlist_mask, - xp.zeros( - (nb, nloc, self.nnei - self.a_sel), - dtype=a_nlist_mask.dtype, - device=array_api_compat.device(a_nlist_mask), - ), - ], - axis=-1, - ) - padding_edge_angle_update = xp.where( - xp.expand_dims(full_mask, axis=-1), - padding_edge_angle_update, - edge_ebd, - ) - e_update_list.append( - self.act(self.edge_angle_linear2(padding_edge_angle_update)) + edge_angle_update = self._compute_angle_update( + xp, + angle_ebd, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, ) + edge_angle_processed = self._compute_edge_angle_reduction( + xp, + edge_angle_update, + edge_ebd, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) + e_update_list.append(edge_angle_processed) # update edge_ebd e_updated = self.list_update(e_update_list, "edge") # angle self message - # nb x nloc x a_nnei x a_nnei x dim_a - if not self.optim_update: - assert angle_info is not None - angle_self_update = self.act(self.angle_self_linear(angle_info)) - else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", - ) - ) + angle_self_update = self._compute_angle_update( + xp, + angle_ebd, + node_for_a, + edge_for_a, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + ) a_update_list.append(angle_self_update) else: # update edge_ebd diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 57c9368839..a094105487 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -702,41 +702,17 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update - def _forward_sequential( + def _compute_edge_self_update( self, node_ebd: torch.Tensor, node_ebd_ext: torch.Tensor, edge_ebd: torch.Tensor, - h2: torch.Tensor, - angle_ebd: torch.Tensor, - nlist: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - a_nlist_mask: torch.Tensor, - a_sw: torch.Tensor, nei_node_ebd: torch.Tensor, + nlist: torch.Tensor, n2e_index: torch.Tensor, n_ext2e_index: torch.Tensor, - n2a_index: torch.Tensor, - eij2a_index: torch.Tensor, - eik2a_index: torch.Tensor, - nb: int, - nloc: int, - nnei: int, - nall: int, - n_edge: int | None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sequential update path: edge_self → angle_self → edge_angle → node. - - Only supports update_style='res_residual'. - """ - assert self.edge_angle_linear1 is not None - assert self.edge_angle_linear2 is not None - assert self.angle_self_linear is not None - - # ==================================================================== - # Phase 1: Edge self update (uses original node_ebd, edge_ebd) - # ==================================================================== + ) -> torch.Tensor: + """Compute edge self update.""" if not self.optim_update: if not self.use_dynamic_sel: edge_info = torch.cat( @@ -758,9 +734,9 @@ def _forward_sequential( ], dim=-1, ) - edge_self_update = self.act(self.edge_self_linear(edge_info)) + return self.act(self.edge_self_linear(edge_info)) else: - edge_self_update = self.act( + return self.act( self.optim_edge_update( node_ebd, node_ebd_ext, @@ -779,36 +755,50 @@ def _forward_sequential( ) ) - # Apply edge self residual: edge_ebd_s1 = edge_ebd + e_residual[0] * edge_self_update - edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update - - # ==================================================================== - # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) - # ==================================================================== - # Prepare edge for angle from edge_ebd_s1 (updated edge) + def _prepare_angle_embeddings( + self, + node_ebd: torch.Tensor, + edge_ebd: torch.Tensor, + a_nlist_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare compressed node/edge embeddings for angle computation.""" if self.a_compress_rate != 0: if not self.a_compress_use_split: assert self.a_compress_n_linear is not None assert self.a_compress_e_linear is not None node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) else: node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] else: node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd_s1 + edge_ebd_for_angle = edge_ebd if not self.use_dynamic_sel: edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] edge_ebd_for_angle = torch.where( a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 ) + return node_ebd_for_angle, edge_ebd_for_angle - # Initialize for JIT: these are only used in non-optim_update path - node_for_angle_info = angle_ebd # placeholder, overwritten below - edge_for_angle_info = angle_ebd # placeholder, overwritten below + def _compute_angle_update( + self, + angle_ebd: torch.Tensor, + node_ebd_for_angle: torch.Tensor, + edge_ebd_for_angle: torch.Tensor, + feat: str, + n2a_index: torch.Tensor, + eij2a_index: torch.Tensor, + eik2a_index: torch.Tensor, + ) -> torch.Tensor: + """Compute angle-based update (for edge_angle or angle_self). + Parameters + ---------- + feat : str + "edge" for edge_angle_linear1, "angle" for angle_self_linear. + """ if not self.optim_update: node_for_angle_info = ( torch.tile( @@ -838,60 +828,71 @@ def _forward_sequential( angle_info = torch.cat( [angle_ebd, node_for_angle_info, edge_for_angle_info], dim=-1 ) - angle_self_update = self.act(self.angle_self_linear(angle_info)) + if feat == "edge": + assert self.edge_angle_linear1 is not None + return self.act(self.edge_angle_linear1(angle_info)) + else: + assert self.angle_self_linear is not None + return self.act(self.angle_self_linear(angle_info)) else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", + if feat == "edge": + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", + else: + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) ) - ) - # Apply angle self residual: angle_ebd_s2 = angle_ebd + a_residual[0] * angle_self_update - a_updated = angle_ebd + self.a_residual[0] * angle_self_update - - # ==================================================================== - # Phase 3: Edge angle update (uses updated angle_ebd_s2, updated edge_ebd_s1) - # ==================================================================== - if not self.optim_update: - # Rebuild angle_info with updated angle (a_updated) - angle_info_s2 = torch.cat( - [a_updated, node_for_angle_info, edge_for_angle_info], dim=-1 - ) - edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) - else: - edge_angle_update = self.act( - self.optim_angle_update( - a_updated, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - a_updated, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) + def _compute_edge_angle_reduction( + self, + edge_angle_update: torch.Tensor, + edge_ebd_fallback: torch.Tensor, + a_sw: torch.Tensor, + a_nlist_mask: torch.Tensor, + nb: int, + nloc: int, + n_edge: int | None, + eij2a_index: torch.Tensor, + ) -> torch.Tensor: + """Reduce edge angle update over angle dimension, pad, and apply linear2. - # Reduce edge angle update over angle dimension + Parameters + ---------- + edge_ebd_fallback : torch.Tensor + Edge embedding used for non-smooth padding fallback. + """ + assert self.edge_angle_linear2 is not None if not self.use_dynamic_sel: weighted_edge_angle_update = ( a_sw.unsqueeze(-1).unsqueeze(-1) @@ -906,8 +907,8 @@ def _forward_sequential( reduced_edge_angle_update, torch.zeros( [nb, nloc, self.nnei - self.a_sel, self.e_dim], - dtype=edge_ebd.dtype, - device=edge_ebd.device, + dtype=edge_ebd_fallback.dtype, + device=edge_ebd_fallback.device, ), ], dim=2, @@ -939,49 +940,57 @@ def _forward_sequential( dim=-1, ) padding_edge_angle_update = torch.where( - full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd + full_mask.unsqueeze(-1), + padding_edge_angle_update, + edge_ebd_fallback, ) - edge_angle_processed = self.act( - self.edge_angle_linear2(padding_edge_angle_update) - ) - - # Apply edge angle residual on top of edge_ebd_s1 (no recomputation) - e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + return self.act(self.edge_angle_linear2(padding_edge_angle_update)) - # ==================================================================== - # Phase 4: Node edge message (uses e_updated) - # ==================================================================== + def _compute_node_edge_message( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + sw: torch.Tensor, + nlist: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, + nb: int, + nloc: int, + ) -> torch.Tensor: + """Compute node edge message and reduce over neighbor dimension.""" if not self.optim_update: if not self.use_dynamic_sel: - edge_info_updated = torch.cat( + edge_info = torch.cat( [ torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), nei_node_ebd, - e_updated, + edge_ebd, ], dim=-1, ) else: - edge_info_updated = torch.cat( + edge_info = torch.cat( [ torch.index_select( node_ebd.reshape(-1, self.n_dim), 0, n2e_index ), nei_node_ebd, - e_updated, + edge_ebd, ], dim=-1, ) node_edge_update = self.act( - self.node_edge_linear(edge_info_updated) + self.node_edge_linear(edge_info) ) * sw.unsqueeze(-1) else: node_edge_update = self.act( self.optim_edge_update( node_ebd, node_ebd_ext, - e_updated, + edge_ebd, nlist, "node", ) @@ -989,7 +998,7 @@ def _forward_sequential( else self.optim_edge_update_dynamic( node_ebd, node_ebd_ext, - e_updated, + edge_ebd, n2e_index, n_ext2e_index, "node", @@ -1009,21 +1018,24 @@ def _forward_sequential( / self.dynamic_e_sel ) ) + return node_edge_update - # ==================================================================== - # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) - # ==================================================================== - n_update_list: list[torch.Tensor] = [node_ebd] - - # node self mlp (uses original node_ebd) - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) - - # node sym using e_updated + def _compute_node_sym( + self, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + n2e_index: torch.Tensor, + nb: int, + nloc: int, + ) -> torch.Tensor: + """Compute node symmetrization update (grrg + drrd).""" node_sym_list: list[torch.Tensor] = [] node_sym_list.append( self.symmetrization_op( - e_updated, + edge_ebd, h2, nlist_mask, sw, @@ -1031,7 +1043,7 @@ def _forward_sequential( ) if not self.use_dynamic_sel else self.symmetrization_op_dynamic( - e_updated, + edge_ebd, h2, sw, owner=n2e_index, @@ -1063,21 +1075,7 @@ def _forward_sequential( axis_neuron=self.axis_neuron, ) ) - node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) - n_update_list.append(node_sym) - - if self.n_multi_edge_message > 1: - node_edge_update_mul_head = node_edge_update.view( - nb, nloc, self.n_multi_edge_message, self.n_dim - ) - for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[..., head_index, :]) - else: - n_update_list.append(node_edge_update) - - n_updated = self.list_update(n_update_list, "node") - - return n_updated, e_updated, a_updated + return self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) def forward( self, @@ -1168,31 +1166,95 @@ def forward( ) ) + # Edge self update (always from original embeddings) + edge_self_update = self._compute_edge_self_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + ) + if self.sequential_update and self.update_angle: - return self._forward_sequential( - node_ebd, - node_ebd_ext, - edge_ebd, - h2, + # === Sequential update path === + # Phase 1: Apply edge self residual + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # Phase 2: Angle self (uses updated edge_ebd_s1) + node_for_a, edge_for_a = self._prepare_angle_embeddings( + node_ebd, edge_ebd_s1, a_nlist_mask + ) + angle_self_update = self._compute_angle_update( angle_ebd, - nlist, - nlist_mask, - sw, - a_nlist_mask, - a_sw, - nei_node_ebd, - n2e_index, - n_ext2e_index, + node_for_a, + edge_for_a, + "angle", n2a_index, eij2a_index, eik2a_index, + ) + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1) + edge_angle_update = self._compute_angle_update( + a_updated, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + ) + edge_angle_processed = self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd_s1, + a_sw, + a_nlist_mask, nb, nloc, - nnei, - nall, n_edge, + eij2a_index, + ) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # Phase 4+5: Node updates (uses e_updated) + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + node_sym = self._compute_node_sym( + e_updated, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc + ) + node_edge_update = self._compute_node_edge_message( + node_ebd, + node_ebd_ext, + e_updated, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, ) + n_update_list: list[torch.Tensor] = [ + node_ebd, + node_self_mlp, + node_sym, + ] + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) + else: + n_update_list.append(node_edge_update) + n_updated = self.list_update(n_update_list, "node") + + return n_updated, e_updated, a_updated + + # === Parallel update path === n_update_list: list[torch.Tensor] = [node_ebd] e_update_list: list[torch.Tensor] = [edge_ebd] a_update_list: list[torch.Tensor] = [angle_ebd] @@ -1202,118 +1264,24 @@ def forward( n_update_list.append(node_self_mlp) # node sym (grrg + drrd) - node_sym_list: list[torch.Tensor] = [] - node_sym_list.append( - self.symmetrization_op( - edge_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( - edge_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) + node_sym = self._compute_node_sym( + edge_ebd, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc ) - node_sym_list.append( - self.symmetrization_op( - nei_node_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( - nei_node_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) n_update_list.append(node_sym) - if not self.optim_update: - if not self.use_dynamic_sel: - # nb x nloc x nnei x (n_dim * 2 + e_dim) - edge_info = torch.cat( - [ - torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), - nei_node_ebd, - edge_ebd, - ], - dim=-1, - ) - else: - # n_edge x (n_dim * 2 + e_dim) - edge_info = torch.cat( - [ - torch.index_select( - node_ebd.reshape(-1, self.n_dim), 0, n2e_index - ), - nei_node_ebd, - edge_ebd, - ], - dim=-1, - ) - else: - edge_info = None - # node edge message - # nb x nloc x nnei x (h * n_dim) - if not self.optim_update: - assert edge_info is not None - node_edge_update = self.act( - self.node_edge_linear(edge_info) - ) * sw.unsqueeze(-1) - else: - node_edge_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "node", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "node", - ) - ) * sw.unsqueeze(-1) - node_edge_update = ( - (torch.sum(node_edge_update, dim=-2) / self.nnei) - if not self.use_dynamic_sel - else ( - aggregate( - node_edge_update, - n2e_index, - average=False, - num_owner=nb * nloc, - ).reshape(nb, nloc, node_edge_update.shape[-1]) - / self.dynamic_e_sel - ) + node_edge_update = self._compute_node_edge_message( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, ) - if self.n_multi_edge_message > 1: # nb x nloc x h x n_dim node_edge_update_mul_head = node_edge_update.view( @@ -1327,211 +1295,51 @@ def forward( n_updated = self.list_update(n_update_list, "node") # edge self message - if not self.optim_update: - assert edge_info is not None - edge_self_update = self.act(self.edge_self_linear(edge_info)) - else: - edge_self_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "edge", - ) - ) e_update_list.append(edge_self_update) if self.update_angle: assert self.angle_self_linear is not None assert self.edge_angle_linear1 is not None assert self.edge_angle_linear2 is not None - # get angle info - if self.a_compress_rate != 0: - if not self.a_compress_use_split: - assert self.a_compress_n_linear is not None - assert self.a_compress_e_linear is not None - node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) - else: - # use the first a_compress_dim dim for node and edge - node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] - else: - node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = torch.where( - a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 - ) - if not self.optim_update: - # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim - node_for_angle_info = ( - torch.tile( - node_ebd_for_angle.unsqueeze(2).unsqueeze(2), - (1, 1, self.a_sel, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else torch.index_select( - node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), - 0, - n2a_index, - ) - ) - - # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim - edge_for_angle_k = ( - torch.tile( - edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) - ) - if not self.use_dynamic_sel - else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) - ) - # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim - edge_for_angle_j = ( - torch.tile( - edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) - ) - if not self.use_dynamic_sel - else torch.index_select(edge_ebd_for_angle, 0, eij2a_index) - ) - # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim) - edge_for_angle_info = torch.cat( - [edge_for_angle_k, edge_for_angle_j], dim=-1 - ) - angle_info_list = [angle_ebd] - angle_info_list.append(node_for_angle_info) - angle_info_list.append(edge_for_angle_info) - # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) - # [OR] - # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) - angle_info = torch.cat(angle_info_list, dim=-1) - else: - angle_info = None + node_for_a, edge_for_a = self._prepare_angle_embeddings( + node_ebd, edge_ebd, a_nlist_mask + ) # edge angle message - # nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim - if not self.optim_update: - assert angle_info is not None - edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) - else: - edge_angle_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) - - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x a_nnei x e_dim - weighted_edge_angle_update = ( - a_sw.unsqueeze(-1).unsqueeze(-1) - * a_sw.unsqueeze(-2).unsqueeze(-1) - * edge_angle_update - ) - # nb x nloc x a_nnei x e_dim - reduced_edge_angle_update = torch.sum( - weighted_edge_angle_update, dim=-2 - ) / (self.a_sel**0.5) - # nb x nloc x nnei x e_dim - padding_edge_angle_update = torch.concat( - [ - reduced_edge_angle_update, - torch.zeros( - [nb, nloc, self.nnei - self.a_sel, self.e_dim], - dtype=edge_ebd.dtype, - device=edge_ebd.device, - ), - ], - dim=2, - ) - else: - # n_angle x e_dim - weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) - # n_edge x e_dim - padding_edge_angle_update = aggregate( - weighted_edge_angle_update, - eij2a_index, - average=False, - num_owner=n_edge, - ) / (self.dynamic_a_sel**0.5) - - if not self.smooth_edge_update: - # will be deprecated in the future - # not support dynamic index, will pass anyway - if self.use_dynamic_sel: - raise NotImplementedError( - "smooth_edge_update must be True when use_dynamic_sel is True!" - ) - full_mask = torch.concat( - [ - a_nlist_mask, - torch.zeros( - [nb, nloc, self.nnei - self.a_sel], - dtype=a_nlist_mask.dtype, - device=a_nlist_mask.device, - ), - ], - dim=-1, - ) - padding_edge_angle_update = torch.where( - full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd - ) - e_update_list.append( - self.act(self.edge_angle_linear2(padding_edge_angle_update)) + edge_angle_update = self._compute_angle_update( + angle_ebd, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + ) + edge_angle_processed = self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, ) + e_update_list.append(edge_angle_processed) # update edge_ebd e_updated = self.list_update(e_update_list, "edge") # angle self message - # nb x nloc x a_nnei x a_nnei x dim_a - if not self.optim_update: - assert angle_info is not None - angle_self_update = self.act(self.angle_self_linear(angle_info)) - else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", - ) - ) + angle_self_update = self._compute_angle_update( + angle_ebd, + node_for_a, + edge_for_a, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + ) a_update_list.append(angle_self_update) else: # update edge_ebd diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 70a7985702..74e8aad13e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1555,7 +1555,7 @@ def dpa3_repflow_args() -> list[Argument]: "When True, updates are applied sequentially: edge self → angle self (using updated edge) " "→ edge angle (using updated angle) → node (using final edge), " "instead of the default parallel mode where all updates use original embeddings. " - "Currently only supports update_style='res_residual'." + "Currently only supports update_style='res_residual' and requires update_angle=True." ) return [ diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index b980c584a1..c158d81a93 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -67,8 +67,8 @@ ("const",), # update_residual_init ([], [[0, 1]]), # exclude_types (True,), # update_angle - (0, 1), # a_compress_rate - (1, 2), # a_compress_e_rate + (1,), # a_compress_rate + (2,), # a_compress_e_rate (True,), # a_compress_use_split (True, False), # optim_update (True, False), # edge_init_use_dist @@ -444,8 +444,8 @@ def atol(self) -> float: ("const",), # update_residual_init ([], [[0, 1]]), # exclude_types (True,), # update_angle - (0, 1), # a_compress_rate - (1, 2), # a_compress_e_rate + (1,), # a_compress_rate + (2,), # a_compress_e_rate (True,), # a_compress_use_split (True, False), # optim_update (True, False), # edge_init_use_dist From 5fe472e7c7a83e8066c4e6099d562386fb78c3a2 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 8 May 2026 21:11:00 +0800 Subject: [PATCH 3/5] Update test_dpa3.py --- source/tests/consistent/descriptor/test_dpa3.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index 6b93b849c0..c3254f68fe 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -81,6 +81,7 @@ "n_multi_edge_message", "precision", "add_chg_spin_ebd", + "sequential_update", ) @@ -100,6 +101,7 @@ "n_multi_edge_message": 1, "precision": "float64", "add_chg_spin_ebd": False, + "sequential_update": False, } @@ -131,6 +133,7 @@ def dpa3_case(**overrides: Any) -> tuple: dpa3_case(edge_init_use_dist=False), dpa3_case(use_exp_switch=False), dpa3_case(use_dynamic_sel=False), + dpa3_case(sequential_update=True), # One mixed high-risk path to keep interactions covered. dpa3_case( exclude_types=[[0, 1]], @@ -142,6 +145,7 @@ def dpa3_case(**overrides: Any) -> tuple: use_dynamic_sel=False, use_loc_mapping=False, add_chg_spin_ebd=True, + sequential_update=True, ), ) @@ -169,6 +173,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple: dpa3_descriptor_api_case(edge_init_use_dist=False), dpa3_descriptor_api_case(use_exp_switch=False), dpa3_descriptor_api_case(use_dynamic_sel=False), + dpa3_descriptor_api_case(sequential_update=False), # One mixed high-risk path to keep interactions covered. dpa3_descriptor_api_case( exclude_types=[[0, 1]], @@ -181,6 +186,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple: use_loc_mapping=False, fix_stat_std=0.0, add_chg_spin_ebd=True, + sequential_update=True, ), ) @@ -268,6 +274,7 @@ def skip_pt(self) -> bool: _n_multi_edge_message, _precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param return CommonTest.skip_pt @@ -290,6 +297,7 @@ def skip_pd(self) -> bool: _precision, add_chg_spin_ebd, sequential_update, + _sequential_update, ) = self.param return True if add_chg_spin_ebd else CommonTest.skip_pd @@ -311,6 +319,7 @@ def skip_dp(self) -> bool: _n_multi_edge_message, _precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param return CommonTest.skip_dp @@ -332,6 +341,7 @@ def skip_tf(self) -> bool: _n_multi_edge_message, _precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param return True @@ -398,6 +408,7 @@ def setUp(self) -> None: _precision, add_chg_spin_ebd, sequential_update, + _sequential_update, ) = self.param # fparam for charge=5, spin=1 when add_chg_spin_ebd is True self.fparam = ( @@ -504,6 +515,7 @@ def rtol(self) -> float: _n_multi_edge_message, precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param if precision == "float64": return 1e-10 @@ -531,6 +543,7 @@ def atol(self) -> float: _n_multi_edge_message, precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param if precision == "float64": return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786 @@ -567,6 +580,7 @@ def data(self) -> dict: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return { "ntypes": self.ntypes, @@ -598,6 +612,7 @@ def data(self) -> dict: "update_style": "res_residual", "update_residual": 0.1, "update_residual_init": update_residual_init, + "sequential_update": sequential_update, } ), # kwargs for descriptor From b4d43ee0e4e70edaf31213d5787066e13778eb66 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 8 May 2026 21:11:29 +0800 Subject: [PATCH 4/5] Update test_dpa3.py --- source/tests/consistent/descriptor/test_dpa3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index c3254f68fe..2a4e2e2b20 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -173,7 +173,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple: dpa3_descriptor_api_case(edge_init_use_dist=False), dpa3_descriptor_api_case(use_exp_switch=False), dpa3_descriptor_api_case(use_dynamic_sel=False), - dpa3_descriptor_api_case(sequential_update=False), + dpa3_descriptor_api_case(sequential_update=True), # One mixed high-risk path to keep interactions covered. dpa3_descriptor_api_case( exclude_types=[[0, 1]], From daf68ba739817b8729437200c65afdbed22e2eb9 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sat, 9 May 2026 16:03:30 +0800 Subject: [PATCH 5/5] Update test_dpa3.py --- source/tests/consistent/descriptor/test_dpa3.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index 2a4e2e2b20..2aa0fd931b 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -296,7 +296,6 @@ def skip_pd(self) -> bool: _n_multi_edge_message, _precision, add_chg_spin_ebd, - sequential_update, _sequential_update, ) = self.param return True if add_chg_spin_ebd else CommonTest.skip_pd @@ -407,7 +406,6 @@ def setUp(self) -> None: _n_multi_edge_message, _precision, add_chg_spin_ebd, - sequential_update, _sequential_update, ) = self.param # fparam for charge=5, spin=1 when add_chg_spin_ebd is True