diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 5f5aea50e5..52b2ca9c2c 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -170,6 +170,13 @@ class RepFlowArgs: In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor` or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, accommodating larger selection numbers. + sequential_update : bool, optional + Whether to use sequential update mode within each repflow layer. + When True, updates are applied sequentially: edge self → angle self (using updated edge) + → edge angle (using updated angle) → node (using final edge), + instead of the default parallel mode where all updates use original embeddings. + Currently only supports ``update_style='res_residual'`` and requires ``update_angle=True``; + otherwise, a ``ValueError`` will be raised during initialization. """ def __init__( @@ -201,6 +208,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, ) -> None: self.n_dim = n_dim self.e_dim = e_dim @@ -231,6 +239,15 @@ def __init__( self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update + if self.sequential_update: + if self.update_style != "res_residual": + raise ValueError( + "sequential_update only supports update_style='res_residual', " + f"got '{self.update_style}'!" + ) + if not self.update_angle: + raise ValueError("sequential_update requires update_angle=True!") def __getitem__(self, key: str) -> Any: if hasattr(self, key): @@ -266,6 +283,7 @@ def serialize(self) -> dict: "use_exp_switch": self.use_exp_switch, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, } @classmethod @@ -404,6 +422,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, + sequential_update=self.repflow_args.sequential_update, use_loc_mapping=use_loc_mapping, exclude_types=exclude_types, env_protection=env_protection, diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 30637dc75a..1af1d15bd9 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -230,6 +230,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, use_loc_mapping: bool = True, seed: int | list[int] | None = None, trainable: bool = True, @@ -268,6 +269,7 @@ def __init__( self.use_dynamic_sel = use_dynamic_sel self.use_loc_mapping = use_loc_mapping self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update if self.use_dynamic_sel and not self.smooth_edge_update: raise NotImplementedError( "smooth_edge_update must be True when use_dynamic_sel is True!" @@ -339,6 +341,7 @@ def __init__( optim_update=self.optim_update, use_dynamic_sel=self.use_dynamic_sel, sel_reduce_factor=self.sel_reduce_factor, + sequential_update=self.sequential_update, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), trainable=trainable, @@ -757,6 +760,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "use_loc_mapping": self.use_loc_mapping, # variables "edge_embd": self.edge_embd.serialize(), @@ -905,6 +909,7 @@ def __init__( optim_update: bool = True, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, smooth_edge_update: bool = False, activation_function: str = "silu", update_style: str = "res_residual", @@ -954,8 +959,15 @@ def __init__( self.smooth_edge_update = smooth_edge_update self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update self.dynamic_e_sel = self.nnei / self.sel_reduce_factor self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor + if self.sequential_update and self.update_style != "res_residual": + raise NotImplementedError( + "sequential_update only supports update_style='res_residual'!" + ) + if self.sequential_update and not self.update_angle: + raise NotImplementedError("sequential_update requires update_angle=True!") assert update_residual_init in [ "norm", @@ -1342,168 +1354,22 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update - def call( + def _compute_edge_self_update( self, - node_ebd_ext: Array, # nf x nall x n_dim - edge_ebd: Array, # nf x nloc x nnei x e_dim - h2: Array, # nf x nloc x nnei x 3 - angle_ebd: Array, # nf x nloc x a_nnei x a_nnei x a_dim - nlist: Array, # nf x nloc x nnei - nlist_mask: Array, # nf x nloc x nnei - sw: Array, # switch func, nf x nloc x nnei - a_nlist: Array, # nf x nloc x a_nnei - a_nlist_mask: Array, # nf x nloc x a_nnei - a_sw: Array, # switch func, nf x nloc x a_nnei - edge_index: Array, # 2 x n_edge - angle_index: Array, # 3 x n_angle - ) -> tuple[Array, Array, Array]: - """ - Parameters - ---------- - node_ebd_ext : nf x nall x n_dim - Extended node embedding. - edge_ebd : nf x nloc x nnei x e_dim - Edge embedding. - h2 : nf x nloc x nnei x 3 - Pair-atom channel, equivariant. - angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim - Angle embedding. - nlist : nf x nloc x nnei - Neighbor list. (padded neis are set to 0) - nlist_mask : nf x nloc x nnei - Masks of the neighbor list. real nei 1 otherwise 0 - sw : nf x nloc x nnei - Switch function. - a_nlist : nf x nloc x a_nnei - Neighbor list for angle. (padded neis are set to 0) - a_nlist_mask : nf x nloc x a_nnei - Masks of the neighbor list for angle. real nei 1 otherwise 0 - a_sw : nf x nloc x a_nnei - Switch function for angle. - edge_index : Optional for dynamic sel, 2 x n_edge - n2e_index : n_edge - Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). - n_ext2e_index : n_edge - Broadcast indices from extended node(j) to edge(ij). - angle_index : Optional for dynamic sel, 3 x n_angle - n2a_index : n_angle - Broadcast indices from extended node(j) to angle(ijk). - eij2a_index : n_angle - Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). - eik2a_index : n_angle - Broadcast indices from extended edge(ik) to angle(ijk). - - Returns - ------- - n_updated: nf x nloc x n_dim - Updated node embedding. - e_updated: nf x nloc x nnei x e_dim - Updated edge embedding. - a_updated : nf x nloc x a_nnei x a_nnei x a_dim - Updated angle embedding. - """ - xp = array_api_compat.array_namespace( - node_ebd_ext, - edge_ebd, - h2, - angle_ebd, - nlist, - nlist_mask, - sw, - a_nlist, - a_nlist_mask, - a_sw, - edge_index, - angle_index, - ) - nb, nloc, nnei = nlist.shape - nall = node_ebd_ext.shape[1] - # int cannot jit; do not run it when self.use_dynamic_sel == False - n_edge = ( - int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0 - ) - node_ebd = xp_take_first_n(node_ebd_ext, 1, nloc) - assert (nb, nloc) == node_ebd.shape[:2] - if not self.use_dynamic_sel: - assert (nb, nloc, nnei) == h2.shape[:3] - else: - assert (n_edge, 3) == h2.shape - del a_nlist # may be used in the future - - n2e_index, n_ext2e_index = edge_index[0, :], edge_index[1, :] - n2a_index, eij2a_index, eik2a_index = ( - angle_index[0, :], - angle_index[1, :], - angle_index[2, :], - ) - - # nb x nloc x nnei x n_dim [OR] n_edge x n_dim - nei_node_ebd = ( - _make_nei_g1(node_ebd_ext, nlist) - if not self.use_dynamic_sel - else xp.take( - xp.reshape(node_ebd_ext, (-1, self.n_dim)), n_ext2e_index, axis=0 - ) - ) - - n_update_list: list[Array] = [node_ebd] - e_update_list: list[Array] = [edge_ebd] - a_update_list: list[Array] = [angle_ebd] - - # node self mlp - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) - - # node sym (grrg + drrd) - node_sym_list: list[Array] = [] - node_sym_list.append( - symmetrization_op( - edge_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else symmetrization_op_dynamic( - edge_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - node_sym_list.append( - symmetrization_op( - nei_node_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else symmetrization_op_dynamic( - nei_node_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - node_sym = self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) - n_update_list.append(node_sym) - + xp: object, + node_ebd: Array, + node_ebd_ext: Array, + edge_ebd: Array, + nei_node_ebd: Array, + nlist: Array, + n2e_index: Array, + n_ext2e_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute edge self update.""" if not self.optim_update: if not self.use_dynamic_sel: - # nb x nloc x nnei x (n_dim * 2 + e_dim) edge_info = xp.concat( [ xp.tile( @@ -1516,7 +1382,6 @@ def call( axis=-1, ) else: - # n_edge x (n_dim * 2 + e_dim) edge_info = xp.concat( [ xp.take( @@ -1529,24 +1394,15 @@ def call( ], axis=-1, ) + return self.act(self.edge_self_linear(edge_info)) else: - edge_info = None - - # node edge message - # nb x nloc x nnei x (h * n_dim) - if not self.optim_update: - assert edge_info is not None - node_edge_update = self.act( - self.node_edge_linear(edge_info) - ) * xp.expand_dims(sw, axis=-1) - else: - node_edge_update = self.act( + return self.act( self.optim_edge_update( node_ebd, node_ebd_ext, edge_ebd, nlist, - "node", + "edge", ) if not self.use_dynamic_sel else self.optim_edge_update_dynamic( @@ -1555,166 +1411,145 @@ def call( edge_ebd, n2e_index, n_ext2e_index, - "node", - ) - ) * xp.expand_dims(sw, axis=-1) - - node_edge_update = ( - (xp.sum(node_edge_update, axis=-2) / self.nnei) - if not self.use_dynamic_sel - else ( - xp.reshape( - aggregate( - node_edge_update, - n2e_index, - average=False, - num_owner=nb * nloc, - ), - (nb, nloc, node_edge_update.shape[-1]), + "edge", ) - / self.dynamic_e_sel - ) - ) - if self.n_multi_edge_message > 1: - # nb x nloc x h x n_dim - node_edge_update_mul_head = xp.reshape( - node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) ) - for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + + def _prepare_angle_embeddings( + self, + xp: object, + node_ebd: Array, + edge_ebd: Array, + a_nlist_mask: Array, + ) -> tuple[Array, Array]: + """Prepare compressed node/edge embeddings for angle computation.""" + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] else: - n_update_list.append(node_edge_update) - # update node_ebd - n_updated = self.list_update(n_update_list, "node") + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd - # edge self message + if not self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = xp.where( + xp.expand_dims(a_nlist_mask, axis=-1), + edge_ebd_for_angle, + xp.zeros_like(edge_ebd_for_angle), + ) + return node_ebd_for_angle, edge_ebd_for_angle + + def _compute_angle_update( + self, + xp: object, + angle_ebd: Array, + node_ebd_for_angle: Array, + edge_ebd_for_angle: Array, + feat: str, + n2a_index: Array, + eij2a_index: Array, + eik2a_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute angle-based update (for edge_angle or angle_self). + + Parameters + ---------- + feat : str + "edge" for edge_angle_linear1, "angle" for angle_self_linear. + """ if not self.optim_update: - assert edge_info is not None - edge_self_update = self.act(self.edge_self_linear(edge_info)) - else: - edge_self_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "edge", + node_for_angle_info = ( + xp.tile( + xp.reshape( + node_ebd_for_angle, + (nb, nloc, 1, 1, self.n_a_compress_dim), + ), + (1, 1, self.a_sel, self.a_sel, 1), ) if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "edge", + else xp.take( + xp.reshape(node_ebd_for_angle, (-1, self.n_a_compress_dim)), + n2a_index, + axis=0, ) ) - e_update_list.append(edge_self_update) - - if self.update_angle: - assert self.angle_self_linear is not None - assert self.edge_angle_linear1 is not None - assert self.edge_angle_linear2 is not None - # get angle info - if self.a_compress_rate != 0: - if not self.a_compress_use_split: - assert self.a_compress_n_linear is not None - assert self.a_compress_e_linear is not None - node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) - else: - # use the first a_compress_dim dim for node and edge - node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] - else: - node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd - - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = xp.where( - xp.expand_dims(a_nlist_mask, axis=-1), + edge_for_angle_k = ( + xp.tile( + xp.reshape( + edge_ebd_for_angle, + (nb, nloc, 1, self.a_sel, self.e_a_compress_dim), + ), + (1, 1, self.a_sel, 1, 1), + ) + if not self.use_dynamic_sel + else xp.take( edge_ebd_for_angle, - xp.zeros_like(edge_ebd_for_angle), + eik2a_index, + axis=0, ) - if not self.optim_update: - # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim - node_for_angle_info = ( - xp.tile( - xp.reshape( - node_ebd_for_angle, (nb, nloc, 1, 1, self.n_a_compress_dim) - ), - (1, 1, self.a_sel, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else xp.take( - xp.reshape(node_ebd_for_angle, (-1, self.n_a_compress_dim)), - n2a_index, - axis=0, - ) - ) - - # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim - edge_for_angle_k = ( - xp.tile( - xp.reshape( - edge_ebd_for_angle, - (nb, nloc, 1, self.a_sel, self.e_a_compress_dim), - ), - (1, 1, self.a_sel, 1, 1), - ) - if not self.use_dynamic_sel - else xp.take( + ) + edge_for_angle_j = ( + xp.tile( + xp.reshape( edge_ebd_for_angle, - eik2a_index, - axis=0, - ) + (nb, nloc, self.a_sel, 1, self.e_a_compress_dim), + ), + (1, 1, 1, self.a_sel, 1), ) - # nb x nloc x a_nnei x (a_nnei) x e_dim - edge_for_angle_j = ( - xp.tile( - xp.reshape( - edge_ebd_for_angle, - (nb, nloc, self.a_sel, 1, self.e_a_compress_dim), - ), - (1, 1, 1, self.a_sel, 1), + if not self.use_dynamic_sel + else xp.take( + edge_ebd_for_angle, + eij2a_index, + axis=0, + ) + ) + edge_for_angle_info = xp.concat( + [edge_for_angle_k, edge_for_angle_j], axis=-1 + ) + angle_info = xp.concat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], axis=-1 + ) + if feat == "edge": + assert self.edge_angle_linear1 is not None + return self.act(self.edge_angle_linear1(angle_info)) + else: + assert self.angle_self_linear is not None + return self.act(self.angle_self_linear(angle_info)) + else: + if feat == "edge": + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", ) if not self.use_dynamic_sel - else xp.take( + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, edge_ebd_for_angle, + n2a_index, eij2a_index, - axis=0, + eik2a_index, + "edge", ) ) - # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) - edge_for_angle_info = xp.concat( - [edge_for_angle_k, edge_for_angle_j], axis=-1 - ) - angle_info_list = [angle_ebd] - angle_info_list.append(node_for_angle_info) - angle_info_list.append(edge_for_angle_info) - # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) - # [OR] - # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) - angle_info = xp.concat(angle_info_list, axis=-1) - else: - angle_info = None - - # edge angle message - # nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim - if not self.optim_update: - assert angle_info is not None - edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) else: - edge_angle_update = self.act( + return self.act( self.optim_angle_update( angle_ebd, node_ebd_for_angle, edge_ebd_for_angle, - "edge", + "angle", ) if not self.use_dynamic_sel else self.optim_angle_update_dynamic( @@ -1724,98 +1559,535 @@ def call( n2a_index, eij2a_index, eik2a_index, - "edge", + "angle", ) ) - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x a_nnei x e_dim - weighted_edge_angle_update = ( - a_sw[:, :, :, xp.newaxis, xp.newaxis] - * a_sw[:, :, xp.newaxis, :, xp.newaxis] - * edge_angle_update + + def _compute_edge_angle_reduction( + self, + xp: object, + edge_angle_update: Array, + edge_ebd_fallback: Array, + a_sw: Array, + a_nlist_mask: Array, + nb: int, + nloc: int, + n_edge: int, + eij2a_index: Array, + ) -> Array: + """Reduce edge angle update over angle dimension, pad, and apply linear2. + + Parameters + ---------- + edge_ebd_fallback : Array + Edge embedding used for non-smooth padding fallback. + """ + assert self.edge_angle_linear2 is not None + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw[:, :, :, xp.newaxis, xp.newaxis] + * a_sw[:, :, xp.newaxis, :, xp.newaxis] + * edge_angle_update + ) + reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / ( + self.a_sel**0.5 + ) + padding_edge_angle_update = xp.concat( + [ + reduced_edge_angle_update, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel, self.e_dim), + dtype=edge_ebd_fallback.dtype, + device=array_api_compat.device(edge_ebd_fallback), + ), + ], + axis=2, + ) + else: + weighted_edge_angle_update = edge_angle_update * xp.expand_dims( + a_sw, axis=-1 + ) + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" ) - # nb x nloc x a_nnei x e_dim - reduced_edge_angle_update = xp.sum( - weighted_edge_angle_update, axis=-2 - ) / (self.a_sel**0.5) - # nb x nloc x nnei x e_dim - padding_edge_angle_update = xp.concat( + full_mask = xp.concat( + [ + a_nlist_mask, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel), + dtype=a_nlist_mask.dtype, + device=array_api_compat.device(a_nlist_mask), + ), + ], + axis=-1, + ) + padding_edge_angle_update = xp.where( + xp.expand_dims(full_mask, axis=-1), + padding_edge_angle_update, + edge_ebd_fallback, + ) + + return self.act(self.edge_angle_linear2(padding_edge_angle_update)) + + def _compute_node_edge_message( + self, + xp: object, + node_ebd: Array, + node_ebd_ext: Array, + edge_ebd: Array, + nei_node_ebd: Array, + sw: Array, + nlist: Array, + n2e_index: Array, + n_ext2e_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute node edge message and reduce over neighbor dimension.""" + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = xp.concat( [ - reduced_edge_angle_update, - xp.zeros( - (nb, nloc, self.nnei - self.a_sel, self.e_dim), - dtype=edge_ebd.dtype, - device=array_api_compat.device(edge_ebd), + xp.tile( + xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), + (1, 1, self.nnei, 1), ), + nei_node_ebd, + edge_ebd, ], - axis=2, + axis=-1, ) else: - # n_angle x e_dim - weighted_edge_angle_update = edge_angle_update * xp.expand_dims( - a_sw, axis=-1 - ) - # n_edge x e_dim - padding_edge_angle_update = aggregate( - weighted_edge_angle_update, - eij2a_index, - average=False, - num_owner=n_edge, - ) / (self.dynamic_a_sel**0.5) - - if not self.smooth_edge_update: - # will be deprecated in the future - # not support dynamic index, will pass anyway - if self.use_dynamic_sel: - raise NotImplementedError( - "smooth_edge_update must be True when use_dynamic_sel is True!" - ) - full_mask = xp.concat( + edge_info = xp.concat( [ - a_nlist_mask, - xp.zeros( - (nb, nloc, self.nnei - self.a_sel), - dtype=a_nlist_mask.dtype, - device=array_api_compat.device(a_nlist_mask), + xp.take( + xp.reshape(node_ebd, (-1, self.n_dim)), + n2e_index, + axis=0, ), + nei_node_ebd, + edge_ebd, ], axis=-1, ) - padding_edge_angle_update = xp.where( - xp.expand_dims(full_mask, axis=-1), - padding_edge_angle_update, + node_edge_update = self.act( + self.node_edge_linear(edge_info) + ) * xp.expand_dims(sw, axis=-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, edge_ebd, + n2e_index, + n_ext2e_index, + "node", + ) + ) * xp.expand_dims(sw, axis=-1) + + node_edge_update = ( + (xp.sum(node_edge_update, axis=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + xp.reshape( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ), + (nb, nloc, node_edge_update.shape[-1]), + ) + / self.dynamic_e_sel + ) + ) + return node_edge_update + + def _compute_node_sym( + self, + xp: object, + edge_ebd: Array, + nei_node_ebd: Array, + h2: Array, + nlist_mask: Array, + sw: Array, + n2e_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute node symmetrization update (grrg + drrd).""" + node_sym_list: list[Array] = [] + node_sym_list.append( + symmetrization_op( + edge_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + edge_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym_list.append( + symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + return self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) + + def call( + self, + node_ebd_ext: Array, # nf x nall x n_dim + edge_ebd: Array, # nf x nloc x nnei x e_dim + h2: Array, # nf x nloc x nnei x 3 + angle_ebd: Array, # nf x nloc x a_nnei x a_nnei x a_dim + nlist: Array, # nf x nloc x nnei + nlist_mask: Array, # nf x nloc x nnei + sw: Array, # switch func, nf x nloc x nnei + a_nlist: Array, # nf x nloc x a_nnei + a_nlist_mask: Array, # nf x nloc x a_nnei + a_sw: Array, # switch func, nf x nloc x a_nnei + edge_index: Array, # 2 x n_edge + angle_index: Array, # 3 x n_angle + ) -> tuple[Array, Array, Array]: + """ + Parameters + ---------- + node_ebd_ext : nf x nall x n_dim + Extended node embedding. + edge_ebd : nf x nloc x nnei x e_dim + Edge embedding. + h2 : nf x nloc x nnei x 3 + Pair-atom channel, equivariant. + angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim + Angle embedding. + nlist : nf x nloc x nnei + Neighbor list. (padded neis are set to 0) + nlist_mask : nf x nloc x nnei + Masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei + Switch function. + a_nlist : nf x nloc x a_nnei + Neighbor list for angle. (padded neis are set to 0) + a_nlist_mask : nf x nloc x a_nnei + Masks of the neighbor list for angle. real nei 1 otherwise 0 + a_sw : nf x nloc x a_nnei + Switch function for angle. + edge_index : Optional for dynamic sel, 2 x n_edge + n2e_index : n_edge + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). + n_ext2e_index : n_edge + Broadcast indices from extended node(j) to edge(ij). + angle_index : Optional for dynamic sel, 3 x n_angle + n2a_index : n_angle + Broadcast indices from extended node(j) to angle(ijk). + eij2a_index : n_angle + Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). + eik2a_index : n_angle + Broadcast indices from extended edge(ik) to angle(ijk). + + Returns + ------- + n_updated: nf x nloc x n_dim + Updated node embedding. + e_updated: nf x nloc x nnei x e_dim + Updated edge embedding. + a_updated : nf x nloc x a_nnei x a_nnei x a_dim + Updated angle embedding. + """ + xp = array_api_compat.array_namespace( + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist, + a_nlist_mask, + a_sw, + edge_index, + angle_index, + ) + nb, nloc, nnei = nlist.shape + nall = node_ebd_ext.shape[1] + # int cannot jit; do not run it when self.use_dynamic_sel == False + n_edge = ( + int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0 + ) + node_ebd = xp_take_first_n(node_ebd_ext, 1, nloc) + assert (nb, nloc) == node_ebd.shape[:2] + if not self.use_dynamic_sel: + assert (nb, nloc, nnei) == h2.shape[:3] + else: + assert (n_edge, 3) == h2.shape + del a_nlist # may be used in the future + + n2e_index, n_ext2e_index = edge_index[0, :], edge_index[1, :] + n2a_index, eij2a_index, eik2a_index = ( + angle_index[0, :], + angle_index[1, :], + angle_index[2, :], + ) + + # nb x nloc x nnei x n_dim [OR] n_edge x n_dim + nei_node_ebd = ( + _make_nei_g1(node_ebd_ext, nlist) + if not self.use_dynamic_sel + else xp.take( + xp.reshape(node_ebd_ext, (-1, self.n_dim)), n_ext2e_index, axis=0 + ) + ) + + # Edge self update (always from original embeddings) + edge_self_update = self._compute_edge_self_update( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + + if self.sequential_update and self.update_angle: + # === Sequential update path === + # Phase 1: Apply edge self residual + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # Phase 2: Angle self (uses updated edge_ebd_s1) + node_for_a, edge_for_a = self._prepare_angle_embeddings( + xp, node_ebd, edge_ebd_s1, a_nlist_mask + ) + angle_self_update = self._compute_angle_update( + xp, + angle_ebd, + node_for_a, + edge_for_a, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + ) + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1) + edge_angle_update = self._compute_angle_update( + xp, + a_updated, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + ) + edge_angle_processed = self._compute_edge_angle_reduction( + xp, + edge_angle_update, + edge_ebd_s1, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # Phase 4+5: Node updates (uses e_updated) + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + node_sym = self._compute_node_sym( + xp, + e_updated, + nei_node_ebd, + h2, + nlist_mask, + sw, + n2e_index, + nb, + nloc, + ) + node_edge_update = self._compute_node_edge_message( + xp, + node_ebd, + node_ebd_ext, + e_updated, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + + n_update_list: list[Array] = [node_ebd, node_self_mlp, node_sym] + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = xp.reshape( + node_edge_update, + (nb, nloc, self.n_multi_edge_message, self.n_dim), ) - e_update_list.append( - self.act(self.edge_angle_linear2(padding_edge_angle_update)) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + else: + n_update_list.append(node_edge_update) + n_updated = self.list_update(n_update_list, "node") + + return n_updated, e_updated, a_updated + + # === Parallel update path === + n_update_list: list[Array] = [node_ebd] + e_update_list: list[Array] = [edge_ebd] + a_update_list: list[Array] = [angle_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym (grrg + drrd) + node_sym = self._compute_node_sym( + xp, + edge_ebd, + nei_node_ebd, + h2, + nlist_mask, + sw, + n2e_index, + nb, + nloc, + ) + n_update_list.append(node_sym) + + # node edge message + node_edge_update = self._compute_node_edge_message( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + if self.n_multi_edge_message > 1: + # nb x nloc x h x n_dim + node_edge_update_mul_head = xp.reshape( + node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + else: + n_update_list.append(node_edge_update) + # update node_ebd + n_updated = self.list_update(n_update_list, "node") + + # edge self message + e_update_list.append(edge_self_update) + + if self.update_angle: + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + + node_for_a, edge_for_a = self._prepare_angle_embeddings( + xp, node_ebd, edge_ebd, a_nlist_mask + ) + + # edge angle message + edge_angle_update = self._compute_angle_update( + xp, + angle_ebd, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, ) + edge_angle_processed = self._compute_edge_angle_reduction( + xp, + edge_angle_update, + edge_ebd, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) + e_update_list.append(edge_angle_processed) # update edge_ebd e_updated = self.list_update(e_update_list, "edge") # angle self message - # nb x nloc x a_nnei x a_nnei x dim_a - if not self.optim_update: - assert angle_info is not None - angle_self_update = self.act(self.angle_self_linear(angle_info)) - else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", - ) - ) + angle_self_update = self._compute_angle_update( + xp, + angle_ebd, + node_for_a, + edge_for_a, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + ) a_update_list.append(angle_self_update) else: # update edge_ebd @@ -1907,6 +2179,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "node_self_mlp": self.node_self_mlp.serialize(), "node_sym_linear": self.node_sym_linear.serialize(), "node_edge_linear": self.node_edge_linear.serialize(), diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 0c6982afe5..a5f79280fa 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -167,6 +167,7 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any: use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, + sequential_update=self.repflow_args.sequential_update, use_loc_mapping=use_loc_mapping, exclude_types=exclude_types, env_protection=env_protection, diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 338f48b060..a094105487 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -53,6 +53,7 @@ def __init__( optim_update: bool = True, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, smooth_edge_update: bool = False, activation_function: str = "silu", update_style: str = "res_residual", @@ -102,8 +103,15 @@ def __init__( self.smooth_edge_update = smooth_edge_update self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update self.dynamic_e_sel = self.nnei / self.sel_reduce_factor self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor + if self.sequential_update and self.update_style != "res_residual": + raise NotImplementedError( + "sequential_update only supports update_style='res_residual'!" + ) + if self.sequential_update and not self.update_angle: + raise NotImplementedError("sequential_update requires update_angle=True!") assert update_residual_init in [ "norm", @@ -694,6 +702,381 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update + def _compute_edge_self_update( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + nlist: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, + ) -> torch.Tensor: + """Compute edge self update.""" + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + else: + edge_info = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + return self.act(self.edge_self_linear(edge_info)) + else: + return self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) + ) + + def _prepare_angle_embeddings( + self, + node_ebd: torch.Tensor, + edge_ebd: torch.Tensor, + a_nlist_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare compressed node/edge embeddings for angle computation.""" + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd + + if not self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = torch.where( + a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 + ) + return node_ebd_for_angle, edge_ebd_for_angle + + def _compute_angle_update( + self, + angle_ebd: torch.Tensor, + node_ebd_for_angle: torch.Tensor, + edge_ebd_for_angle: torch.Tensor, + feat: str, + n2a_index: torch.Tensor, + eij2a_index: torch.Tensor, + eik2a_index: torch.Tensor, + ) -> torch.Tensor: + """Compute angle-based update (for edge_angle or angle_self). + + Parameters + ---------- + feat : str + "edge" for edge_angle_linear1, "angle" for angle_self_linear. + """ + if not self.optim_update: + node_for_angle_info = ( + torch.tile( + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), + 0, + n2a_index, + ) + ) + edge_for_angle_k = ( + torch.tile(edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) + ) + edge_for_angle_j = ( + torch.tile(edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eij2a_index) + ) + edge_for_angle_info = torch.cat( + [edge_for_angle_k, edge_for_angle_j], dim=-1 + ) + angle_info = torch.cat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], dim=-1 + ) + if feat == "edge": + assert self.edge_angle_linear1 is not None + return self.act(self.edge_angle_linear1(angle_info)) + else: + assert self.angle_self_linear is not None + return self.act(self.angle_self_linear(angle_info)) + else: + if feat == "edge": + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) + ) + else: + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) + ) + + def _compute_edge_angle_reduction( + self, + edge_angle_update: torch.Tensor, + edge_ebd_fallback: torch.Tensor, + a_sw: torch.Tensor, + a_nlist_mask: torch.Tensor, + nb: int, + nloc: int, + n_edge: int | None, + eij2a_index: torch.Tensor, + ) -> torch.Tensor: + """Reduce edge angle update over angle dimension, pad, and apply linear2. + + Parameters + ---------- + edge_ebd_fallback : torch.Tensor + Edge embedding used for non-smooth padding fallback. + """ + assert self.edge_angle_linear2 is not None + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw.unsqueeze(-1).unsqueeze(-1) + * a_sw.unsqueeze(-2).unsqueeze(-1) + * edge_angle_update + ) + reduced_edge_angle_update = torch.sum( + weighted_edge_angle_update, dim=-2 + ) / (self.a_sel**0.5) + padding_edge_angle_update = torch.concat( + [ + reduced_edge_angle_update, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=edge_ebd_fallback.dtype, + device=edge_ebd_fallback.device, + ), + ], + dim=2, + ) + else: + assert n_edge is not None + weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" + ) + full_mask = torch.concat( + [ + a_nlist_mask, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel], + dtype=a_nlist_mask.dtype, + device=a_nlist_mask.device, + ), + ], + dim=-1, + ) + padding_edge_angle_update = torch.where( + full_mask.unsqueeze(-1), + padding_edge_angle_update, + edge_ebd_fallback, + ) + + return self.act(self.edge_angle_linear2(padding_edge_angle_update)) + + def _compute_node_edge_message( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + sw: torch.Tensor, + nlist: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, + nb: int, + nloc: int, + ) -> torch.Tensor: + """Compute node edge message and reduce over neighbor dimension.""" + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + else: + edge_info = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + node_edge_update = self.act( + self.node_edge_linear(edge_info) + ) * sw.unsqueeze(-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "node", + ) + ) * sw.unsqueeze(-1) + + node_edge_update = ( + (torch.sum(node_edge_update, dim=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ).reshape(nb, nloc, node_edge_update.shape[-1]) + / self.dynamic_e_sel + ) + ) + return node_edge_update + + def _compute_node_sym( + self, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + n2e_index: torch.Tensor, + nb: int, + nloc: int, + ) -> torch.Tensor: + """Compute node symmetrization update (grrg + drrd).""" + node_sym_list: list[torch.Tensor] = [] + node_sym_list.append( + self.symmetrization_op( + edge_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + edge_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym_list.append( + self.symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + return self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) + def forward( self, node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parallel_mode @@ -783,127 +1166,122 @@ def forward( ) ) - n_update_list: list[torch.Tensor] = [node_ebd] - e_update_list: list[torch.Tensor] = [edge_ebd] - a_update_list: list[torch.Tensor] = [angle_ebd] + # Edge self update (always from original embeddings) + edge_self_update = self._compute_edge_self_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + ) - # node self mlp - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) + if self.sequential_update and self.update_angle: + # === Sequential update path === + # Phase 1: Apply edge self residual + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update - # node sym (grrg + drrd) - node_sym_list: list[torch.Tensor] = [] - node_sym_list.append( - self.symmetrization_op( - edge_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, + # Phase 2: Angle self (uses updated edge_ebd_s1) + node_for_a, edge_for_a = self._prepare_angle_embeddings( + node_ebd, edge_ebd_s1, a_nlist_mask ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( - edge_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, + angle_self_update = self._compute_angle_update( + angle_ebd, + node_for_a, + edge_for_a, + "angle", + n2a_index, + eij2a_index, + eik2a_index, ) - ) - node_sym_list.append( - self.symmetrization_op( - nei_node_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1) + edge_angle_update = self._compute_angle_update( + a_updated, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( + edge_angle_processed = self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd_s1, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # Phase 4+5: Node updates (uses e_updated) + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + node_sym = self._compute_node_sym( + e_updated, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc + ) + node_edge_update = self._compute_node_edge_message( + node_ebd, + node_ebd_ext, + e_updated, nei_node_ebd, - h2, sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, ) - ) - node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) - n_update_list.append(node_sym) - if not self.optim_update: - if not self.use_dynamic_sel: - # nb x nloc x nnei x (n_dim * 2 + e_dim) - edge_info = torch.cat( - [ - torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), - nei_node_ebd, - edge_ebd, - ], - dim=-1, + n_update_list: list[torch.Tensor] = [ + node_ebd, + node_self_mlp, + node_sym, + ] + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) else: - # n_edge x (n_dim * 2 + e_dim) - edge_info = torch.cat( - [ - torch.index_select( - node_ebd.reshape(-1, self.n_dim), 0, n2e_index - ), - nei_node_ebd, - edge_ebd, - ], - dim=-1, - ) - else: - edge_info = None + n_update_list.append(node_edge_update) + n_updated = self.list_update(n_update_list, "node") - # node edge message - # nb x nloc x nnei x (h * n_dim) - if not self.optim_update: - assert edge_info is not None - node_edge_update = self.act( - self.node_edge_linear(edge_info) - ) * sw.unsqueeze(-1) - else: - node_edge_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "node", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "node", - ) - ) * sw.unsqueeze(-1) - node_edge_update = ( - (torch.sum(node_edge_update, dim=-2) / self.nnei) - if not self.use_dynamic_sel - else ( - aggregate( - node_edge_update, - n2e_index, - average=False, - num_owner=nb * nloc, - ).reshape(nb, nloc, node_edge_update.shape[-1]) - / self.dynamic_e_sel - ) + return n_updated, e_updated, a_updated + + # === Parallel update path === + n_update_list: list[torch.Tensor] = [node_ebd] + e_update_list: list[torch.Tensor] = [edge_ebd] + a_update_list: list[torch.Tensor] = [angle_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym (grrg + drrd) + node_sym = self._compute_node_sym( + edge_ebd, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc ) + n_update_list.append(node_sym) + # node edge message + node_edge_update = self._compute_node_edge_message( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) if self.n_multi_edge_message > 1: # nb x nloc x h x n_dim node_edge_update_mul_head = node_edge_update.view( @@ -917,211 +1295,51 @@ def forward( n_updated = self.list_update(n_update_list, "node") # edge self message - if not self.optim_update: - assert edge_info is not None - edge_self_update = self.act(self.edge_self_linear(edge_info)) - else: - edge_self_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "edge", - ) - ) e_update_list.append(edge_self_update) if self.update_angle: assert self.angle_self_linear is not None assert self.edge_angle_linear1 is not None assert self.edge_angle_linear2 is not None - # get angle info - if self.a_compress_rate != 0: - if not self.a_compress_use_split: - assert self.a_compress_n_linear is not None - assert self.a_compress_e_linear is not None - node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) - else: - # use the first a_compress_dim dim for node and edge - node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] - else: - node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = torch.where( - a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 - ) - if not self.optim_update: - # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim - node_for_angle_info = ( - torch.tile( - node_ebd_for_angle.unsqueeze(2).unsqueeze(2), - (1, 1, self.a_sel, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else torch.index_select( - node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), - 0, - n2a_index, - ) - ) - - # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim - edge_for_angle_k = ( - torch.tile( - edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) - ) - if not self.use_dynamic_sel - else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) - ) - # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim - edge_for_angle_j = ( - torch.tile( - edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) - ) - if not self.use_dynamic_sel - else torch.index_select(edge_ebd_for_angle, 0, eij2a_index) - ) - # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim) - edge_for_angle_info = torch.cat( - [edge_for_angle_k, edge_for_angle_j], dim=-1 - ) - angle_info_list = [angle_ebd] - angle_info_list.append(node_for_angle_info) - angle_info_list.append(edge_for_angle_info) - # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) - # [OR] - # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) - angle_info = torch.cat(angle_info_list, dim=-1) - else: - angle_info = None + node_for_a, edge_for_a = self._prepare_angle_embeddings( + node_ebd, edge_ebd, a_nlist_mask + ) # edge angle message - # nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim - if not self.optim_update: - assert angle_info is not None - edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) - else: - edge_angle_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) - - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x a_nnei x e_dim - weighted_edge_angle_update = ( - a_sw.unsqueeze(-1).unsqueeze(-1) - * a_sw.unsqueeze(-2).unsqueeze(-1) - * edge_angle_update - ) - # nb x nloc x a_nnei x e_dim - reduced_edge_angle_update = torch.sum( - weighted_edge_angle_update, dim=-2 - ) / (self.a_sel**0.5) - # nb x nloc x nnei x e_dim - padding_edge_angle_update = torch.concat( - [ - reduced_edge_angle_update, - torch.zeros( - [nb, nloc, self.nnei - self.a_sel, self.e_dim], - dtype=edge_ebd.dtype, - device=edge_ebd.device, - ), - ], - dim=2, - ) - else: - # n_angle x e_dim - weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) - # n_edge x e_dim - padding_edge_angle_update = aggregate( - weighted_edge_angle_update, - eij2a_index, - average=False, - num_owner=n_edge, - ) / (self.dynamic_a_sel**0.5) - - if not self.smooth_edge_update: - # will be deprecated in the future - # not support dynamic index, will pass anyway - if self.use_dynamic_sel: - raise NotImplementedError( - "smooth_edge_update must be True when use_dynamic_sel is True!" - ) - full_mask = torch.concat( - [ - a_nlist_mask, - torch.zeros( - [nb, nloc, self.nnei - self.a_sel], - dtype=a_nlist_mask.dtype, - device=a_nlist_mask.device, - ), - ], - dim=-1, - ) - padding_edge_angle_update = torch.where( - full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd - ) - e_update_list.append( - self.act(self.edge_angle_linear2(padding_edge_angle_update)) + edge_angle_update = self._compute_angle_update( + angle_ebd, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, ) + edge_angle_processed = self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) + e_update_list.append(edge_angle_processed) # update edge_ebd e_updated = self.list_update(e_update_list, "edge") # angle self message - # nb x nloc x a_nnei x a_nnei x dim_a - if not self.optim_update: - assert angle_info is not None - angle_self_update = self.act(self.angle_self_linear(angle_info)) - else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", - ) - ) + angle_self_update = self._compute_angle_update( + angle_ebd, + node_for_a, + edge_for_a, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + ) a_update_list.append(angle_self_update) else: # update edge_ebd @@ -1220,6 +1438,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "node_self_mlp": self.node_self_mlp.serialize(), "node_sym_linear": self.node_sym_linear.serialize(), "node_edge_linear": self.node_edge_linear.serialize(), diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 031287a797..e29fe01ac6 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -222,6 +222,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, use_loc_mapping: bool = True, optim_update: bool = True, seed: int | list[int] | None = None, @@ -261,6 +262,7 @@ def __init__( self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update if self.use_dynamic_sel and not self.smooth_edge_update: raise NotImplementedError( "smooth_edge_update must be True when use_dynamic_sel is True!" @@ -332,6 +334,7 @@ def __init__( optim_update=self.optim_update, use_dynamic_sel=self.use_dynamic_sel, sel_reduce_factor=self.sel_reduce_factor, + sequential_update=self.sequential_update, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), trainable=trainable, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0213c01f9c..73b73af353 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1556,6 +1556,13 @@ def dpa3_repflow_args() -> list[Argument]: "or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, " "accommodating larger selection numbers." ) + doc_sequential_update = ( + "Whether to use sequential update mode within each repflow layer. " + "When True, updates are applied sequentially: edge self → angle self (using updated edge) " + "→ edge angle (using updated angle) → node (using final edge), " + "instead of the default parallel mode where all updates use original embeddings. " + "Currently only supports update_style='res_residual' and requires update_angle=True." + ) return [ # repflow args @@ -1686,6 +1693,13 @@ def dpa3_repflow_args() -> list[Argument]: default=10.0, doc=doc_sel_reduce_factor, ), + Argument( + "sequential_update", + bool, + optional=True, + default=False, + doc=doc_sequential_update, + ), ] diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index b067ca94dc..2aa0fd931b 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -81,6 +81,7 @@ "n_multi_edge_message", "precision", "add_chg_spin_ebd", + "sequential_update", ) @@ -100,6 +101,7 @@ "n_multi_edge_message": 1, "precision": "float64", "add_chg_spin_ebd": False, + "sequential_update": False, } @@ -131,6 +133,7 @@ def dpa3_case(**overrides: Any) -> tuple: dpa3_case(edge_init_use_dist=False), dpa3_case(use_exp_switch=False), dpa3_case(use_dynamic_sel=False), + dpa3_case(sequential_update=True), # One mixed high-risk path to keep interactions covered. dpa3_case( exclude_types=[[0, 1]], @@ -142,6 +145,7 @@ def dpa3_case(**overrides: Any) -> tuple: use_dynamic_sel=False, use_loc_mapping=False, add_chg_spin_ebd=True, + sequential_update=True, ), ) @@ -169,6 +173,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple: dpa3_descriptor_api_case(edge_init_use_dist=False), dpa3_descriptor_api_case(use_exp_switch=False), dpa3_descriptor_api_case(use_dynamic_sel=False), + dpa3_descriptor_api_case(sequential_update=True), # One mixed high-risk path to keep interactions covered. dpa3_descriptor_api_case( exclude_types=[[0, 1]], @@ -181,6 +186,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple: use_loc_mapping=False, fix_stat_std=0.0, add_chg_spin_ebd=True, + sequential_update=True, ), ) @@ -205,6 +211,7 @@ def data(self) -> dict: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return { "ntypes": self.ntypes, @@ -236,6 +243,7 @@ def data(self) -> dict: "update_style": "res_residual", "update_residual": 0.1, "update_residual_init": update_residual_init, + "sequential_update": sequential_update, } ), # kwargs for descriptor @@ -266,6 +274,7 @@ def skip_pt(self) -> bool: _n_multi_edge_message, _precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param return CommonTest.skip_pt @@ -287,6 +296,7 @@ def skip_pd(self) -> bool: _n_multi_edge_message, _precision, add_chg_spin_ebd, + _sequential_update, ) = self.param return True if add_chg_spin_ebd else CommonTest.skip_pd @@ -308,6 +318,7 @@ def skip_dp(self) -> bool: _n_multi_edge_message, _precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param return CommonTest.skip_dp @@ -329,6 +340,7 @@ def skip_tf(self) -> bool: _n_multi_edge_message, _precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param return True @@ -394,6 +406,7 @@ def setUp(self) -> None: _n_multi_edge_message, _precision, add_chg_spin_ebd, + _sequential_update, ) = self.param # fparam for charge=5, spin=1 when add_chg_spin_ebd is True self.fparam = ( @@ -500,6 +513,7 @@ def rtol(self) -> float: _n_multi_edge_message, precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param if precision == "float64": return 1e-10 @@ -527,6 +541,7 @@ def atol(self) -> float: _n_multi_edge_message, precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param if precision == "float64": return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786 @@ -563,6 +578,7 @@ def data(self) -> dict: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return { "ntypes": self.ntypes, @@ -594,6 +610,7 @@ def data(self) -> dict: "update_style": "res_residual", "update_residual": 0.1, "update_residual_init": update_residual_init, + "sequential_update": sequential_update, } ), # kwargs for descriptor diff --git a/source/tests/pt/model/test_dpa3.py b/source/tests/pt/model/test_dpa3.py index 12b0be4532..d66eab9dea 100644 --- a/source/tests/pt/model/test_dpa3.py +++ b/source/tests/pt/model/test_dpa3.py @@ -56,6 +56,7 @@ def test_consistency( prec, ect, add_chg_spin, + seq_upd, ) in itertools.product( [True, False], # update_angle ["res_residual"], # update_style @@ -67,7 +68,11 @@ def test_consistency( ["float64"], # precision [False], # use_econf_tebd [False, True], # add_chg_spin_ebd + [False, True], # sequential_update ): + # sequential_update only works with update_angle=True + if seq_upd and not ua: + continue dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) if prec == "float64": @@ -93,6 +98,7 @@ def test_consistency( update_style=rus, update_residual_init=ruri, smooth_edge_update=True, + sequential_update=seq_upd, ) # dpa3 new impl @@ -177,6 +183,7 @@ def test_jit( nme, prec, ect, + seq_upd, ) in itertools.product( [True], # update_angle ["res_residual"], # update_style @@ -187,6 +194,7 @@ def test_jit( [1, 2], # n_multi_edge_message ["float64"], # precision [False], # use_econf_tebd + [False, True], # sequential_update ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -211,6 +219,7 @@ def test_jit( update_style=rus, update_residual_init=ruri, smooth_edge_update=True, + sequential_update=seq_upd, ) # dpa3 new impl