diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 13f8f3e351..b498d3d296 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -428,15 +428,14 @@ def get_numb_attn_layer(self) -> int: def uses_graph_lower(self) -> bool: """Returns whether this descriptor supports the graph-native lower. - The graph-native energy lower (``call_graph``) currently covers only the - non-attention (``attn_layer == 0``) factorizable path with concat - type-embedding and no type exclusion. Any other config (attention, - ``tebd_input_mode == "strip"``, ``exclude_types``) falls back to the - legacy dense path, so those models keep working unchanged. + The graph-native lower (``call_graph``) covers the factorizable path + AND transformer attention (``attn_layer >= 0``, NeighborGraph PR-D) + with concat type-embedding and no type exclusion. Remaining ineligible + configs (``tebd_input_mode == "strip"``, ``exclude_types``) fall back + to the legacy dense path, so those models keep working unchanged. """ return ( - self.se_atten.attn_layer == 0 - and self.se_atten.tebd_input_mode == "concat" + self.se_atten.tebd_input_mode == "concat" and not self.se_atten.exclude_types ) @@ -643,6 +642,9 @@ def _call_graph_adapter( graph, atype_local, type_embedding=self.type_embedding.call(), + # the adapter graph is shape-static center-major (compact=False): + # keep the attention pair enumeration nonzero-free (traceable) + static_nnei=nnei, ) # call_graph returns flat (N, ...) node axis; reshape to (nf, nloc, ...) # for the dense 5-tuple ABI -- this reshape is LOCAL to the adapter shim. @@ -727,8 +729,9 @@ def call_graph( graph: Any, atype: Array, type_embedding: Array | None = None, + static_nnei: int | None = None, ) -> tuple[Array, Array]: - """Descriptor-level graph-native forward (``attn_layer == 0``). + """Descriptor-level graph-native forward. 
Wraps the block kernel :meth:`DescrptBlockSeAtten.call_graph`, adds the descriptor-level @@ -760,7 +763,7 @@ def call_graph( xp = array_api_compat.array_namespace(graph.edge_vec) dev = array_api_compat.device(graph.edge_vec) grrg, rot_mat = self.se_atten.call_graph( - graph, atype, type_embedding=type_embedding + graph, atype, type_embedding=type_embedding, static_nnei=static_nnei ) # FLAT node axis (N, ...): no (nf, nloc) reshape -- ragged-native, spec. if self.concat_output_tebd: @@ -1670,12 +1673,15 @@ def call_graph( graph: Any, atype: Array, type_embedding: Array | None = None, + static_nnei: int | None = None, ) -> tuple[Array, Array]: - """Graph-native forward (``attn_layer=0`` only). + """Graph-native forward. Bit-exact analogue of :meth:`call` on the SAME neighbor list, with the neighbor-axis reduction replaced by a ``segment_sum`` over edge centers - (``dst``). Geometry enters only through ``graph.edge_vec``. + (``dst``) and the dense ``(nnei, nnei)`` transformer attention replaced + by pairs of edges sharing a center (``center_edge_pairs`` + + ``segment_softmax``). Geometry enters only through ``graph.edge_vec``. Parameters ---------- @@ -1687,6 +1693,12 @@ def call_graph( (N,) flat node atom types (``N = sum(graph.n_node)``). type_embedding (ntypes_with_padding, tebd_dim) type-embedding table. + static_nnei + When the graph uses the shape-static center-major layout + (``from_dense_quartet(compact=False)``, ``E = n_center * nnei``), + pass ``nnei`` so the attention edge-pair enumeration stays + jit/export-traceable (no ``nonzero``). ``None`` (carry-all / + compact graphs) selects the dynamic eager form. Returns ------- @@ -1699,8 +1711,7 @@ def call_graph( Notes ----- - Known limitations (NeighborGraph PR-A): - - ``attn_layer == 0`` only (attention lands in PR-D); + Known limitations: - ``tebd_input_mode == "concat"`` only (strip mode lands later); - ``exclude_types`` is not yet supported and raises (lands in a later PR). 
def _graph_attention(
    self,
    gg: Array,
    rr: Array,
    dst: Array,
    n_total: int,
    edge_mask: Array,
    sw_e: Array,
    static_nnei: int | None,
) -> Array:
    """Run every attention layer over the edges that share a center.

    The edge pairs are enumerated once with
    ``center_edge_pairs(include_self=True, ordered=True)`` and reused by
    all layers in ``self.dpa1_attention.attention_layers``.

    Parameters
    ----------
    gg
        (E, ng) per-edge embedding; passed in unmasked by the caller.
    rr
        (E, 4) per-edge env-mat vector; columns ``1:4`` are used as the
        direction.
    dst
        (E,) center of each edge.
    n_total
        Number of centers.
    edge_mask
        (E,) real-vs-padding edge mask.
    sw_e
        (E, 1) smooth switch.
    static_nnei
        Forwarded to ``center_edge_pairs``; ``None`` selects its compact
        eager form.

    Returns
    -------
    Array
        (E, ng) embedding after all attention layers.
    """
    from deepmd.dpmodel.utils.neighbor_graph import (
        center_edge_pairs,
    )

    xp = array_api_compat.array_namespace(gg)
    # unit direction per edge; the norm is floored at 1e-12 before dividing
    direction = rr[:, 1:4]
    length = safe_for_vector_norm(direction, axis=-1, keepdims=True)
    floor = xp.full_like(length, 1e-12)
    unit_dir = direction / xp.maximum(length, floor)  # (E, 3)
    # full ordered square per center, diagonal included
    query_edge, key_edge, pair_mask = center_edge_pairs(
        dst,
        edge_mask,
        n_total,
        include_self=True,
        ordered=True,
        static_nnei=static_nnei,
    )
    hidden = gg
    for attn_block in self.dpa1_attention.attention_layers:
        hidden = self._graph_attention_one_layer(
            attn_block, hidden, unit_dir, sw_e, query_edge, key_edge, pair_mask
        )
    return hidden


def _graph_attention_one_layer(
    self,
    layer: "NeighborGatedAttentionLayer",
    gg: Array,
    input_r: Array,
    sw_e: Array,
    q_e: Array,
    k_e: Array,
    pair_mask: Array,
) -> Array:
    """Apply one residual gated-attention layer on the edge-pair axis.

    Per-pair logits are ``q[q_e] . k[k_e]``; the softmax runs over all
    pairs that share the same query edge (``segment_softmax`` grouped by
    ``q_e``); the weighted values are summed back onto the query edge.

    In the smooth branch the softmax is taken WITHOUT ``pair_mask``: every
    enumerated pair stays in the denominator, and a pair touching an edge
    whose switch is zero carries the logit ``-attnw_shift``. In the
    non-smooth branch ``pair_mask`` removes pairs from the softmax.

    Parameters
    ----------
    layer
        Layer object providing ``attention_layer`` and ``attn_layer_norm``.
    gg
        (E, ng) per-edge embedding.
    input_r
        (E, 3) per-edge normalized direction.
    sw_e
        (E, 1) smooth switch.
    q_e, k_e
        (P,) query / key edge index of each pair.
    pair_mask
        (P,) pair validity, used by the non-smooth branch only.

    Returns
    -------
    Array
        (E, ng) layer output (residual add followed by the layer norm).

    Raises
    ------
    NotImplementedError
        If the attention layer has ``num_heads != 1``.
    """
    from deepmd.dpmodel.utils.neighbor_graph import (
        segment_softmax,
        segment_sum,
    )

    xp = array_api_compat.array_namespace(gg)
    n_edge = gg.shape[0]
    attn = layer.attention_layer
    if attn.num_heads != 1:
        raise NotImplementedError(
            "graph attention assumes num_heads == 1 (dpa1 never exposes "
            "num_heads; the dense head_dim QKV slicing relies on it)"
        )
    width = attn.head_dim
    skip = gg

    def at_query(per_edge: Array) -> Array:
        # gather a per-edge quantity onto the pair axis via the query edge
        return xp.take(per_edge, q_e, axis=0)

    def at_key(per_edge: Array) -> Array:
        # gather a per-edge quantity onto the pair axis via the key edge
        return xp.take(per_edge, k_e, axis=0)

    # joint projection, then three consecutive head_dim-wide column slices
    projected = attn.in_proj.call(gg)  # (E, 3 * width)
    query = projected[:, 0:width]
    key = projected[:, width : width * 2]
    value = projected[:, width * 2 : width * 3]
    if attn.normalize:
        query = np_normalize(query, axis=-1)
        key = np_normalize(key, axis=-1)
        value = np_normalize(value, axis=-1)
    query = query * attn.scaling

    score = xp.sum(at_query(query) * at_key(key), axis=-1)  # (P,)
    if not attn.smooth:
        # masked pairs are excluded from the softmax altogether
        weight = segment_softmax(score, q_e, n_edge, mask=pair_mask)
    else:
        attnw_shift = 20.0
        switch = sw_e[:, 0]  # (E,)
        switch_q = at_query(switch)
        switch_k = at_key(switch)
        score = (score + attnw_shift) * switch_q * switch_k - attnw_shift
        weight = segment_softmax(score, q_e, n_edge)  # (P,)
        weight = weight * switch_q * switch_k
    if attn.dotr:
        # scale each weight by the cosine between the two edge directions
        cosine = xp.sum(at_query(input_r) * at_key(input_r), axis=-1)  # (P,)
        weight = weight * cosine
    # o[m] = sum over pairs with query m of weight * value[key]
    weighted = weight[:, None] * at_key(value)  # (P, width)
    gathered = segment_sum(weighted, q_e, n_edge)  # (E, width)
    projected_out = attn.out_proj.call(gathered)  # (E, ng)
    return layer.attn_layer_norm.call(skip + projected_out)
""" +from .angles import ( + angle_padding_fraction, + angle_to_edge_sum, + angle_to_node_sum, + attach_angles, + build_angle_index, + graph_angle_cos, +) from .ase_builder import ( build_neighbor_graph_ase, ) @@ -30,25 +38,41 @@ NeighborGraph, frame_id_from_n_node, node_validity_mask, + pad_and_guard_angles, pad_and_guard_edges, ) +from .pairs import ( + center_edge_pairs, +) from .segment import ( + segment_max, segment_mean, + segment_softmax, segment_sum, ) __all__ = [ "GraphLayout", "NeighborGraph", + "angle_padding_fraction", + "angle_to_edge_sum", + "angle_to_node_sum", + "attach_angles", + "build_angle_index", "build_neighbor_graph", "build_neighbor_graph_ase", + "center_edge_pairs", "edge_env_mat", "edge_force_virial", "frame_id_from_n_node", "from_dense_quartet", + "graph_angle_cos", "neighbor_graph_from_ijs", "node_validity_mask", + "pad_and_guard_angles", "pad_and_guard_edges", + "segment_max", "segment_mean", + "segment_softmax", "segment_sum", ] diff --git a/deepmd/dpmodel/utils/neighbor_graph/angles.py b/deepmd/dpmodel/utils/neighbor_graph/angles.py new file mode 100644 index 0000000000..be208986ed --- /dev/null +++ b/deepmd/dpmodel/utils/neighbor_graph/angles.py @@ -0,0 +1,254 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""3-body angle graph: pairs of edges sharing a center within a_rcut. + +Angles reference EDGES (angle_index into [0,E)); edge_vec stays the only +geometry leaf. a_sel is normalization-only (not a truncation). Reuses PR-D's +center_edge_pairs; a_rcut filters the participating edges. 
def build_angle_index(
    edge_index: Array,
    edge_vec: Array,
    edge_mask: Array,
    n_total: int,
    a_rcut: float,
    *,
    ordered: bool = False,
    include_self: bool = False,
    layout: GraphLayout | None = None,
) -> tuple[Array, Array]:
    """Enumerate angles: pairs of edges with a common center inside ``a_rcut``.

    Parameters
    ----------
    edge_index : Array
        (2, E) edge indices; row 1 is used as the center (``dst``).
    edge_vec : Array
        (E, 3) edge vectors.
    edge_mask : Array
        (E,) edge validity mask.
    n_total : int
        Number of nodes, forwarded to ``center_edge_pairs``.
    a_rcut : float
        Only edges with ``norm(edge_vec) < a_rcut`` take part in angles.
    ordered : bool, optional
        Forwarded to ``center_edge_pairs``.
    include_self : bool, optional
        Forwarded to ``center_edge_pairs``.
    layout : GraphLayout or None, optional
        When given, ``layout.angle_capacity`` is the padding capacity.

    Returns
    -------
    angle_index : Array
        (2, A) pairs of edge indices, padded by ``pad_and_guard_angles``.
    angle_mask : Array
        (A,) bool; the padding mask ANDed with the pair mask.
    """
    xp = array_api_compat.array_namespace(edge_index)
    # gate the edges: real AND strictly shorter than a_rcut
    edge_len = xp.linalg.vector_norm(edge_vec, axis=-1)  # (E,)
    eligible = xp.astype(edge_mask, xp.bool) & (edge_len < a_rcut)
    centers = edge_index[1, :]
    # static_nnei is not passed, so the default form of the helper is used
    edge_a, edge_b, valid_pair = center_edge_pairs(
        centers,
        eligible,
        n_total,
        include_self=include_self,
        ordered=ordered,
    )
    real_angles = xp.stack([edge_a, edge_b], axis=0)  # (2, A_real)
    capacity = None if layout is None else layout.angle_capacity
    padded_index, padded_mask = pad_and_guard_angles(
        real_angles, capacity, min_angles=2
    )
    # extend the pair mask with False over the padding suffix, then combine
    n_guard = padded_mask.shape[0] - valid_pair.shape[0]
    guard = xp.zeros(
        (n_guard,),
        dtype=xp.bool,
        device=array_api_compat.device(valid_pair),
    )
    full_pair_mask = xp.concat([valid_pair, guard], axis=0)
    return padded_index, padded_mask & full_pair_mask


def attach_angles(
    graph: NeighborGraph,
    a_rcut: float,
    *,
    ordered: bool = False,
    include_self: bool = False,
    layout: GraphLayout | None = None,
) -> NeighborGraph:
    """Return a copy of ``graph`` with ``angle_index``/``angle_mask`` filled.

    Parameters
    ----------
    graph : NeighborGraph
        Source graph; its ``edge_index``, ``edge_vec``, ``edge_mask`` and
        ``n_node`` fields are read.
    a_rcut : float
        Angle cutoff passed to :func:`build_angle_index`.
    ordered : bool, optional
        Passed to :func:`build_angle_index`.
    include_self : bool, optional
        Passed to :func:`build_angle_index`.
    layout : GraphLayout or None, optional
        When ``layout.node_capacity`` is set it is used as the node count;
        otherwise the count is ``sum(graph.n_node)``.

    Returns
    -------
    NeighborGraph
        ``dataclasses.replace`` of ``graph`` with the two angle fields set.
    """
    xp = array_api_compat.array_namespace(graph.edge_index)
    has_capacity = layout is not None and layout.node_capacity is not None
    n_total = layout.node_capacity if has_capacity else int(xp.sum(graph.n_node))
    angle_index, angle_mask = build_angle_index(
        graph.edge_index,
        graph.edge_vec,
        graph.edge_mask,
        n_total,
        a_rcut,
        ordered=ordered,
        include_self=include_self,
        layout=layout,
    )
    return dataclasses.replace(
        graph, angle_index=angle_index, angle_mask=angle_mask
    )


def graph_angle_cos(angle_index: Array, edge_vec: Array, eps: float = 1e-6) -> Array:
    """Cosine between the two edges of every angle slot.

    Parameters
    ----------
    angle_index : Array
        (2, A) edge indices of the two legs of each angle.
    edge_vec : Array
        (E, 3) edge vectors.
    eps : float, optional
        Added to each norm before dividing; the dot product is also
        multiplied by ``1 - eps``.

    Returns
    -------
    Array
        (A,) cosine per slot. No mask is applied here.
    """
    xp = array_api_compat.array_namespace(edge_vec)

    def unit_leg(row: int) -> Array:
        # gather one leg of every angle and scale it by 1 / (norm + eps)
        leg = xp.take(edge_vec, angle_index[row, :], axis=0)  # (A, 3)
        return leg / (xp.linalg.vector_norm(leg, axis=-1, keepdims=True) + eps)

    unit_a = unit_leg(0)
    unit_b = unit_leg(1)
    return xp.sum(unit_a * unit_b, axis=-1) * (1.0 - eps)


def angle_to_edge_sum(data: Array, angle_index: Array, num_edges: int) -> Array:
    """Sum per-angle data onto the first edge of each angle.

    Parameters
    ----------
    data : Array
        Per-angle data with the angle axis first.
    angle_index : Array
        (2, A) edge indices; row 0 selects the target edge.
    num_edges : int
        Number of output edge slots.

    Returns
    -------
    Array
        ``segment_sum`` of ``data`` over ``angle_index[0]``.
    """
    target_edge = angle_index[0, :]
    return segment_sum(data, target_edge, num_edges)


def angle_to_node_sum(
    data: Array, angle_index: Array, edge_index: Array, num_nodes: int
) -> Array:
    """Sum per-angle data onto the center node of each angle.

    The center is ``edge_index[1]`` looked up at the angle's first edge.

    Parameters
    ----------
    data : Array
        Per-angle data with the angle axis first.
    angle_index : Array
        (2, A) edge indices of each angle.
    edge_index : Array
        (2, E) edge indices; row 1 is the center.
    num_nodes : int
        Number of output node slots.

    Returns
    -------
    Array
        ``segment_sum`` of ``data`` over the per-angle center.
    """
    xp = array_api_compat.array_namespace(data)
    first_edge = angle_index[0, :]
    target_node = xp.take(edge_index[1, :], first_edge, axis=0)
    return segment_sum(data, target_node, num_nodes)


def angle_padding_fraction(graph: NeighborGraph) -> float:
    """Fraction of angle slots whose ``angle_mask`` entry is False.

    Parameters
    ----------
    graph : NeighborGraph
        Graph whose ``angle_mask`` is inspected.

    Returns
    -------
    float
        ``1 - n_true / n_slots``; ``0.0`` when ``angle_mask`` is ``None`` or
        has zero length.
    """
    mask = graph.angle_mask
    if mask is None:
        return 0.0
    xp = array_api_compat.array_namespace(mask)
    n_slots = mask.shape[0]
    if n_slots == 0:
        return 0.0
    n_true = int(xp.sum(xp.astype(mask, xp.int64)))
    return 1.0 - n_true / n_slots
def pad_and_guard_angles(
    angle_index: Array,
    angle_capacity: int | None = None,
    min_angles: int = 2,
    pad_value: int = 0,
) -> tuple[Array, Array]:
    """Pad the angle axis with masked dummy angles placed after the real ones.

    Parameters
    ----------
    angle_index
        (2, A_real) edge indices of the real angles.
    angle_capacity
        ``None``: the output length is ``A_real + min_angles``.
        An int: the output length is exactly ``angle_capacity``.
    min_angles
        Number of dummy angles appended when ``angle_capacity`` is ``None``;
        not used otherwise.
    pad_value
        Edge index written into both rows of every dummy angle.

    Returns
    -------
    angle_index
        (2, target) real angles followed by the dummy angles.
    angle_mask
        (target,) bool, ``True`` on the first ``A_real`` slots.

    Raises
    ------
    ValueError
        If ``angle_capacity`` is given and ``A_real`` exceeds it.
    """
    xp = array_api_compat.array_namespace(angle_index)
    dev = array_api_compat.device(angle_index)
    a_real = angle_index.shape[1]
    if angle_capacity is not None and a_real > angle_capacity:
        raise ValueError(
            f"angle overflow: {a_real} real angles > angle_capacity {angle_capacity}"
        )
    target = angle_capacity if angle_capacity is not None else a_real + min_angles
    filler = xp.full(
        (2, target - a_real), pad_value, dtype=angle_index.dtype, device=dev
    )
    padded = xp.concat([angle_index, filler], axis=1)
    # slot position compared with the real count marks the real prefix
    slot = xp.arange(target, dtype=angle_index.dtype, device=dev)
    return padded, slot < a_real
def center_edge_pairs(
    dst: Array,
    edge_mask: Array,
    n_total: int,
    *,
    include_self: bool = True,
    ordered: bool = True,
    static_nnei: int | None = None,
) -> tuple[Array, Array, Array]:
    """List the (query, key) edge pairs whose two edges share a center.

    Parameters
    ----------
    dst : Array
        (E,) center of each edge.
    edge_mask : Array
        (E,) bool, True for real edges.
    n_total : int
        Number of centers; used by the compact form only.
    include_self : bool
        Keep pairs whose query and key are the same edge.
    ordered : bool
        Keep both orientations of a pair. When False only pairs with
        ``key_edge >= query_edge`` are kept.
    static_nnei : int | None
        ``None`` selects the compact form (``_pairs_compact``); an int
        selects the shape-static form (``_pairs_shape_static``), which
        treats edge ``e`` as belonging to block ``e // static_nnei``.

    Returns
    -------
    query_edge : Array
        (P,) int64 query edge index.
    key_edge : Array
        (P,) int64 key edge index.
    pair_mask : Array
        (P,) bool. The shape-static form marks invalid or filtered pairs
        False; the compact form drops them and returns all True.
    """
    xp = array_api_compat.array_namespace(dst)
    dev = array_api_compat.device(dst)
    if static_nnei is None:
        return _pairs_compact(xp, dev, dst, edge_mask, n_total, include_self, ordered)
    return _pairs_shape_static(
        xp, dev, dst, edge_mask, static_nnei, include_self, ordered
    )


def _pairs_shape_static(
    xp: Any,
    dev: Any,
    dst: Array,
    edge_mask: Array,
    nn: int,
    include_self: bool,
    ordered: bool,
) -> tuple[Array, Array, Array]:
    """Shape-static enumeration: each edge is paired with its ``nn``-wide block.

    Only the length of ``dst`` is read. Edge ``e`` is paired with the edges
    ``(e // nn) * nn + [0, nn)``, giving ``E * nn`` pairs; validity is
    reported through the returned mask and no pair is dropped.
    """
    n_edge = dst.shape[0]
    edge_id = xp.arange(n_edge, dtype=xp.int64, device=dev)
    slot = xp.arange(nn, dtype=xp.int64, device=dev)
    block_start = (edge_id // nn) * nn  # first edge of each edge's block
    query_grid = xp.broadcast_to(edge_id[:, None], (n_edge, nn))
    key_grid = block_start[:, None] + slot[None, :]
    query_edge = xp.reshape(query_grid, (-1,))
    key_edge = xp.reshape(key_grid, (-1,))
    # a pair is valid only when both of its edges are real
    query_real = xp.take(edge_mask, query_edge, axis=0)
    key_real = xp.take(edge_mask, key_edge, axis=0)
    pair_mask = query_real & key_real
    if not include_self:
        pair_mask = pair_mask & (query_edge != key_edge)
    if not ordered:
        pair_mask = pair_mask & (key_edge >= query_edge)
    return query_edge, key_edge, pair_mask


def _pairs_compact(
    xp: Any,
    dev: Any,
    dst: Array,
    edge_mask: Array,
    n_total: int,
    include_self: bool,
    ordered: bool,
) -> tuple[Array, Array, Array]:
    """Compact enumeration over the real edges only.

    Real edges are sorted by center; every sorted edge is then paired with
    each edge of its own center group. Output lengths depend on the data
    (``nonzero`` is used), and the returned mask is all True.
    """
    no_pairs = (
        xp.zeros((0,), dtype=xp.int64, device=dev),
        xp.zeros((0,), dtype=xp.int64, device=dev),
        xp.zeros((0,), dtype=xp.bool, device=dev),
    )
    if dst.shape[0] == 0:
        return no_pairs
    (real_edge,) = xp.nonzero(edge_mask)
    n_real = real_edge.shape[0]
    if n_real == 0:
        return no_pairs
    # group the real edges by center with a stable sort
    real_center = xp.take(dst, real_edge, axis=0)
    by_center = xp.argsort(real_center, stable=True)
    sorted_edge = xp.take(real_edge, by_center, axis=0)  # (R,)
    sorted_center = xp.take(real_center, by_center, axis=0)  # (R,)
    # number of real edges per center, and where each group starts
    unit = xp.ones((n_real,), dtype=xp.int64, device=dev)
    group_size = xp_add_at(
        xp.zeros((n_total,), dtype=xp.int64, device=dev), sorted_center, unit
    )  # (n_total,)
    group_start = xp.cumulative_sum(group_size) - group_size
    fanout = xp.take(group_size, sorted_center, axis=0)  # (R,) pairs per query
    # every sorted position is repeated once per member of its group
    position = xp.arange(n_real, dtype=xp.int64, device=dev)
    query_pos = xp.repeat(position, fanout)  # (P,)
    n_pair = query_pos.shape[0]
    # offset of each pair inside its query's run: 0 .. fanout - 1
    first_pair = xp.cumulative_sum(fanout) - fanout  # (R,)
    within = xp.arange(n_pair, dtype=xp.int64, device=dev) - xp.take(
        first_pair, query_pos, axis=0
    )
    query_center = xp.take(sorted_center, query_pos, axis=0)
    key_pos = xp.take(group_start, query_center, axis=0) + within
    query_edge = xp.take(sorted_edge, query_pos, axis=0)
    key_edge = xp.take(sorted_edge, key_pos, axis=0)
    # apply the include_self / ordered policy by dropping pairs
    keep = xp.ones((n_pair,), dtype=xp.bool, device=dev)
    if not include_self:
        keep = keep & (query_edge != key_edge)
    if not ordered:
        keep = keep & (key_edge >= query_edge)
    (kept,) = xp.nonzero(keep)
    query_edge = xp.take(query_edge, kept, axis=0)
    key_edge = xp.take(key_edge, kept, axis=0)
    pair_mask = xp.ones((query_edge.shape[0],), dtype=xp.bool, device=dev)
    return query_edge, key_edge, pair_mask
def segment_max(data: Array, segment_ids: Array, num_segments: int) -> Array:
    """Per-segment maximum: ``out[s] = max(data[i] for segment_ids[i] == s)``.

    Parameters
    ----------
    data
        Values to reduce, segment axis first.
    segment_ids
        Segment id of each leading-axis entry of ``data``.
    num_segments
        Number of output segments.

    Returns
    -------
    Array
        Shape ``(num_segments, *data.shape[1:])``. The output is initialized
        to ``-inf``, so segments that receive no entry stay ``-inf``.
    """
    xp = array_api_compat.array_namespace(data)
    out = xp.full(
        (num_segments, *tuple(data.shape[1:])),
        -xp.inf,
        dtype=data.dtype,
        device=array_api_compat.device(data),
    )
    return xp_maximum_at(out, segment_ids, data)


def segment_softmax(
    data: Array,
    segment_ids: Array,
    num_segments: int,
    mask: Array | None = None,
) -> Array:
    """Softmax over the entries that share a segment id.

    The per-segment maximum is subtracted before exponentiation for
    numerical stability.

    Parameters
    ----------
    data
        Logits, segment axis first.
    segment_ids
        Segment id of each entry.
    num_segments
        Number of segments.
    mask
        Optional bool mask. Entries where it is False get zero weight and
        are excluded from both the per-segment max and the denominator.

    Returns
    -------
    Array
        Softmax weights with the shape of ``data``. Segments whose
        denominator is zero (empty or fully masked) yield zeros.
    """
    xp = array_api_compat.array_namespace(data)
    if mask is not None:
        # keep masked entries out of the per-segment max: send them to -inf
        neg = xp.full_like(data, -xp.inf)
        data_for_max = xp.where(mask, data, neg)
    else:
        data_for_max = data
    seg_max = segment_max(data_for_max, segment_ids, num_segments)
    # guard -inf (empty / fully-masked segments) so the subtraction below
    # never produces inf - inf
    seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max)
    shifted = data - xp.take(seg_max, segment_ids, axis=0)
    if mask is not None:
        # A masked entry did not take part in the max, so its shifted value
        # can be arbitrarily large: exp() would overflow to inf and the
        # masking product inf * 0 would be NaN, poisoning the whole segment.
        # Neutralize it BEFORE exp; the product below then zeroes it.
        shifted = xp.where(mask, shifted, xp.zeros_like(shifted))
    ex = xp.exp(shifted)
    if mask is not None:
        ex = ex * xp.astype(mask, ex.dtype)
    denom = segment_sum(ex, segment_ids, num_segments)
    denom_e = xp.take(denom, segment_ids, axis=0)
    # a zero denominator means every entry of the segment has ex == 0
    safe = xp.where(denom_e > 0, denom_e, xp.ones_like(denom_e))
    return ex / safe
b/source/tests/common/dpmodel/test_angle_builder.py @@ -0,0 +1,456 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np +import pytest + +from deepmd.dpmodel.utils.neighbor_graph import ( + GraphLayout, + angle_padding_fraction, + angle_to_edge_sum, + angle_to_node_sum, + attach_angles, + build_angle_index, + build_neighbor_graph, + edge_force_virial, + pad_and_guard_angles, +) + + +def test_pad_angles_dynamic_appends_min_guard(): + ai = np.array([[0, 1], [1, 0]], dtype=np.int64) # 2 real angles + out_ai, out_mask = pad_and_guard_angles(ai, angle_capacity=None, min_angles=2) + assert out_ai.shape == (2, 4) # 2 real + 2 guard + np.testing.assert_array_equal(out_mask, [True, True, False, False]) + + +def test_pad_angles_static_capacity(): + ai = np.array([[0, 1], [1, 0]], dtype=np.int64) + out_ai, out_mask = pad_and_guard_angles(ai, angle_capacity=5) + assert out_ai.shape == (2, 5) + assert int(out_mask.sum()) == 2 + + +def test_pad_angles_overflow_raises(): + ai = np.zeros((2, 6), dtype=np.int64) + with pytest.raises(ValueError): + pad_and_guard_angles(ai, angle_capacity=4) + + +def _angle_oracle(dst, evnorm, mask, a_rcut, ordered, include_self): + """All (edge_a, edge_b) sharing a center, both edges within a_rcut.""" + out = set() + for a in range(len(dst)): + if not mask[a] or evnorm[a] >= a_rcut: + continue + for b in range(len(dst)): + if not mask[b] or evnorm[b] >= a_rcut or dst[a] != dst[b]: + continue + if not include_self and a == b: + continue + if not ordered and b < a: + continue + out.add((a, b)) + return out + + +def test_build_angle_index_matches_oracle(): + # 4 edges, all dst=0; norms [0.5, 0.9, 2.5, 0.7], a_rcut=1.0 + # DEFAULT = unordered, no-self => within a_rcut {0,1,3} give {(0,1),(0,3),(1,3)} + edge_index = np.array([[1, 2, 3, 1], [0, 0, 0, 0]], dtype=np.int64) + edge_vec = np.array([[0.5, 0, 0], [0.9, 0, 0], [2.5, 0, 0], [0.7, 0, 0]]) + edge_mask = np.array([True, True, True, True]) + ai, am = build_angle_index(edge_index, 
def test_build_angle_index_masked_edge():
    """A masked-out edge (edge 1 here) must not take part in any angle."""
    dst = [0, 0, 0, 0]
    edge_index = np.array([[1, 2, 3, 1], dst], dtype=np.int64)
    edge_vec = np.array([[0.5, 0, 0], [0.9, 0, 0], [0.3, 0, 0], [0.7, 0, 0]])
    edge_mask = np.array([True, False, True, True])

    ai, am = build_angle_index(edge_index, edge_vec, edge_mask, 4, a_rcut=1.0)
    # keep only the real (unpadded) angle slots
    real_pairs = {
        (int(first), int(second))
        for first, second, keep in zip(ai[0], ai[1], am)
        if keep
    }

    reference = _angle_oracle(
        dst,
        np.linalg.norm(edge_vec, axis=-1),
        edge_mask,
        1.0,
        ordered=False,
        include_self=False,
    )
    assert real_pairs == reference
    # the masked edge never shows up on either side of a pair
    assert all(1 not in pair for pair in real_pairs)
def test_build_angle_index_multi_center():
    """Edges with MIXED centers: angles only pair edges sharing a center.

    Setup: ``src=[1,2,3,4,5]``, ``dst=[0,1,0,1,2]``, all norms within
    ``a_rcut``. With the default flags (unordered, no-self) the expected
    angles per center are:

    - dst=0: edges {0,2} => {(0,2)}
    - dst=1: edges {1,3} => {(1,3)}
    - dst=2: edge  {4}   => no pairs
    """
    edge_index = np.array([[1, 2, 3, 4, 5], [0, 1, 0, 1, 2]], dtype=np.int64)
    edge_vec = np.array(
        [[0.3, 0, 0], [0.4, 0, 0], [0.5, 0, 0], [0.6, 0, 0], [0.7, 0, 0]]
    )
    edge_mask = np.array([True, True, True, True, True])
    ai, am = build_angle_index(edge_index, edge_vec, edge_mask, 6, a_rcut=1.0)
    got = {(int(ai[0, p]), int(ai[1, p])) for p in range(ai.shape[1]) if am[p]}
    evnorm = np.linalg.norm(edge_vec, axis=-1)
    expected = _angle_oracle(
        [0, 1, 0, 1, 2], evnorm, edge_mask, 1.0, ordered=False, include_self=False
    )
    assert got == expected
    # pin the explicit set as well, so the test does not rest on the oracle alone
    assert got == {(0, 2), (1, 3)}
    # Verify no cross-center angles
    assert all(
        edge_index[1, int(ai[0, p])] == edge_index[1, int(ai[1, p])]
        for p in range(ai.shape[1])
        if am[p]
    )
def test_build_angle_index_ordered_no_self():
    """``ordered=True, include_self=False``: symmetric pairs, no diagonal.

    Four edges share center 0; edge 2 (norm 2.5) lies outside ``a_rcut=1.0``,
    so the real angles are all ordered pairs of distinct edges from {0, 1, 3}.
    """
    edge_index = np.array([[1, 2, 3, 1], [0, 0, 0, 0]], dtype=np.int64)
    edge_vec = np.array([[0.5, 0, 0], [0.9, 0, 0], [2.5, 0, 0], [0.7, 0, 0]])
    edge_mask = np.array([True, True, True, True])
    ai, am = build_angle_index(
        edge_index, edge_vec, edge_mask, 4, a_rcut=1.0, ordered=True, include_self=False
    )
    got = {(int(ai[0, p]), int(ai[1, p])) for p in range(ai.shape[1]) if am[p]}
    evnorm = np.linalg.norm(edge_vec, axis=-1)
    expected = _angle_oracle(
        [0, 0, 0, 0], evnorm, edge_mask, 1.0, ordered=True, include_self=False
    )
    assert got == expected
    # explicit pin: every ordered pair of distinct in-range edges
    assert got == {(0, 1), (0, 3), (1, 0), (1, 3), (3, 0), (3, 1)}
    # no self-angles
    assert all(a != b for a, b in got)
    # edge 2 is outside a_rcut and must not appear at all
    assert all(2 not in pair for pair in got)
    # ordered => every pair comes with its reverse
    assert all((b, a) in got for a, b in got)
def test_attach_angles_ordered_include_self():
    """ordered=True, include_self=True produces a superset of default pairs."""
    coord = np.array([[[0.0, 0, 0], [0.5, 0, 0], [0, 0.5, 0]]])
    atype = np.array([[0, 0, 0]])  # (nf, nloc)
    ng = build_neighbor_graph(coord, atype, None, 2.0)
    ng_default = attach_angles(ng, a_rcut=1.5)
    ng_full = attach_angles(ng, a_rcut=1.5, ordered=True, include_self=True)

    def real_pairs(graph):
        # set of (edge_a, edge_b) over the unmasked angle slots
        ai = np.asarray(graph.angle_index)
        am = np.asarray(graph.angle_mask)
        return {(int(ai[0, p]), int(ai[1, p])) for p in range(am.shape[0]) if am[p]}

    pairs_default = real_pairs(ng_default)
    pairs_full = real_pairs(ng_full)
    # both calls index the same edge list, so the pairs are directly comparable:
    # every default pair must also be present in the ordered+self set
    assert pairs_default <= pairs_full
    # ordered+include_self must produce at least as many angles as default
    n_default = int(np.asarray(ng_default.angle_mask).sum())
    n_full = int(np.asarray(ng_full.angle_mask).sum())
    assert n_full >= n_default
def test_angle_aggregation():
    """Check the angle->edge and angle->node reductions on a tiny case.

    Two edges share center 0; three angles ``(a, b)`` = (0,0), (0,1), (1,0)
    carry the values 1, 2 and 4.
    """
    edge_index = np.array([[5, 5], [0, 0]], dtype=np.int64)
    angle_index = np.array([[0, 0, 1], [0, 1, 0]], dtype=np.int64)
    per_angle = np.array([1.0, 2.0, 4.0])

    # grouped by edge_a: edge 0 collects angles 0 and 1 (1+2), edge 1 collects angle 2
    np.testing.assert_allclose(
        angle_to_edge_sum(per_angle, angle_index, 2), [3.0, 4.0]
    )
    # grouped by the center of edge_a: all three angles land on node 0 (1+2+4)
    np.testing.assert_allclose(
        angle_to_node_sum(per_angle, angle_index, edge_index, 1), [7.0]
    )
def _small_graph_with_angles(a_rcut: float, layout: GraphLayout | None = None):
    """Build a 1-frame, 3-atom fixture; return ``(bare_graph, graph_with_angles)``.

    The atoms sit on the x axis at x = 0, 1, 2 inside a 10 x 10 x 10 box; the
    edge graph uses ``rcut=3.0`` and angles are attached with ``a_rcut`` and
    the optional ``layout``.
    """
    # leading axis added afterwards: (nf=1, nloc=3, 3) and (nf=1, 3, 3)
    positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])[
        np.newaxis
    ]
    species = np.array([[0, 1, 0]], dtype=np.int64)
    cell = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]])[
        np.newaxis
    ]
    bare = build_neighbor_graph(positions, species, cell, rcut=3.0)
    return bare, attach_angles(bare, a_rcut, layout=layout)
def test_angle_padding_fraction():
    """angle_padding_fraction equals ``1 - A_real / A_max`` on a static layout.

    A fixed ``angle_capacity`` keeps the padded length at ``A_max``, so the
    expected fraction follows directly from the count of real angles. A graph
    without angles must report 0.0.
    """
    A_max = 20  # static capacity, larger than any real angle count
    bare_graph, angled_graph = _small_graph_with_angles(
        a_rcut=1.5, layout=GraphLayout(angle_capacity=A_max)
    )

    # angles must be attached and non-empty for the check to mean anything
    mask = angled_graph.angle_mask
    assert mask is not None
    A_real = int(np.sum(mask))
    assert 0 < A_real <= A_max, f"Expected 0 < A_real <= {A_max}, got {A_real}"

    got = angle_padding_fraction(angled_graph)
    expected = 1.0 - A_real / A_max
    assert got == pytest.approx(expected), f"got {got}, expected {expected}"

    # graph without angle fields
    assert angle_padding_fraction(bare_graph) == 0.0
def _oracle(dst, mask, include_self, ordered):
    """Brute-force reference: set of edge pairs ``(m, n)`` sharing a center.

    Only edges with a truthy mask entry take part. Self pairs are dropped
    unless ``include_self``; pairs with ``n < m`` are dropped unless
    ``ordered``.
    """
    live = [e for e in range(len(dst)) if mask[e]]
    return {
        (m, n)
        for m in live
        for n in live
        if dst[m] == dst[n]
        and (include_self or m != n)
        and (ordered or n >= m)
    }
class TestShapeStatic:
    """center_edge_pairs on the center-major layout selected by ``static_nnei``."""

    def test_matches_compact(self) -> None:
        # 2 centers x 3 slots each; the last slot of every center is padding
        n_center, nnei = 2, 3
        dst = np.repeat(np.arange(n_center, dtype=np.int64), nnei)
        edge_mask = np.array([True, True, False, True, True, False])
        compact = center_edge_pairs(dst, edge_mask, n_center)
        static = center_edge_pairs(dst, edge_mask, n_center, static_nnei=nnei)
        # pair count depends only on the layout, not on the mask
        assert static[0].shape[0] == n_center * nnei * nnei
        assert _got(*static) == _got(*compact)

    def test_flags_and_masking(self) -> None:
        n_center, nnei = 2, 3
        dst = np.repeat(np.arange(n_center, dtype=np.int64), nnei)
        edge_mask = np.array([True, True, True, True, False, False])
        flag_grid = [(True, True), (True, False), (False, True), (False, False)]
        for include_self, ordered in flag_grid:
            qs, ks, pms = center_edge_pairs(
                dst,
                edge_mask,
                n_center,
                include_self=include_self,
                ordered=ordered,
                static_nnei=nnei,
            )
            # P stays static whatever the flags are
            assert qs.shape[0] == n_center * nnei * nnei
            assert _got(qs, ks, pms) == _oracle(
                dst, edge_mask, include_self, ordered
            ), (include_self, ordered)

    def test_torch_matches_numpy(self) -> None:
        import torch

        dst = np.array([0, 0, 1, 1], dtype=np.int64)
        edge_mask = np.array([True, False, True, True])
        ref = _got(*center_edge_pairs(dst, edge_mask, 2, static_nnei=2))
        torch_out = center_edge_pairs(
            torch.from_numpy(dst),
            torch.from_numpy(edge_mask),
            2,
            static_nnei=2,
        )
        q, k, pm = (t.numpy() for t in torch_out)
        assert _got(q, k, pm) == ref
def _make(
    attn_layer,
    dotr=False,
    smooth=False,
    normalize=False,
    temperature=1.0,
    sel=(20,),
):
    """Build a small float64 ``DescrptDPA1`` for the attention parity tests.

    Only the arguments of this helper vary between tests; every other
    hyper-parameter is held fixed (concat type embedding, seeded weights).
    """
    fixed = dict(
        rcut=4.0,
        rcut_smth=0.5,
        ntypes=2,
        neuron=[6, 12],
        axis_neuron=2,
        attn=8,
        attn_mask=False,
        tebd_input_mode="concat",
        type_one_side=True,
        precision="float64",
        seed=GLOBAL_SEED,
    )
    swept = dict(
        sel=list(sel),
        attn_layer=attn_layer,
        attn_dotr=dotr,
        normalize=normalize,
        # The attention `smooth` switch is driven by smooth_type_embedding
        # (NOT rcut_smth). It is passed explicitly because its default (True)
        # would silently turn the smooth branch on in the non-smooth cases.
        smooth_type_embedding=smooth,
        temperature=temperature,
    )
    return DescrptDPA1(**fixed, **swept)
@pytest.mark.parametrize("dotr", [False, True]) # angular weighting + @pytest.mark.parametrize("smooth", [False, True]) # switch-fn weighting + def test_dotr_smooth(self, dotr, smooth) -> None: + dd = _make(2, dotr=dotr, smooth=smooth, normalize=True, temperature=None) + self._dense_vs_adapter(dd) + + # ── compact (carry-all-form) graph through the BLOCK kernel, non-smooth ── + @pytest.mark.parametrize("attn_layer", [1, 2]) # single + stacked layers + def test_block_compact_graph_no_smooth(self, attn_layer) -> None: + dd = _make(attn_layer, dotr=True, normalize=True) + blk = dd.se_atten + ext_coord, ext_atype, mapping, nlist = self._quartet(dd) + tebd = dd.type_embedding.call() + nf, nall = ext_atype.shape + atype_embd_ext = np.reshape( + np.take(tebd, np.reshape(ext_atype, (-1,)), axis=0), + (nf, nall, dd.tebd_dim), + ) + dense_g, *_ = blk.call( + nlist, + ext_coord, + ext_atype, + atype_embd_ext=atype_embd_ext, + mapping=None, + type_embedding=tebd, + ) + ng = from_dense_quartet(ext_coord, nlist, mapping) # compact=True + graph_g, _ = blk.call_graph( + ng, np.reshape(ext_atype, (-1,)), type_embedding=tebd + ) + np.testing.assert_allclose( + graph_g.reshape(dense_g.shape), dense_g, rtol=1e-12, atol=1e-12 + ) + + # ── smooth on the compact (carry-all) form: CLEAN DIVERGENCE by design ──── + def test_block_compact_graph_smooth_clean_divergence(self) -> None: + """Carry-all smooth attention deliberately DIVERGES from dense. + + The dense smooth branch keeps sel-padding slots in the attention + softmax DENOMINATOR at weight ``exp(-attnw_shift)``, which makes the + dense output depend on ``sel`` itself (same physical neighbors, + different sel => different output, up to ~1e-4). The carry-all graph + drops those phantom terms — the sel-independent math (user decision + 2026-07-03, PR-D). 
Bit-parity (1e-12) is proven on the shape-static + adapter (same padded pairs on both sides, ``test_dotr_smooth``); here + we pin only that the compact form stays CLOSE to dense (the artifact + is a bounded denominator perturbation) while NOT bit-equal. + """ + dd = _make(1, smooth=True) + blk = dd.se_atten + ext_coord, ext_atype, mapping, nlist = self._quartet(dd) + tebd = dd.type_embedding.call() + nf, nall = ext_atype.shape + atype_embd_ext = np.reshape( + np.take(tebd, np.reshape(ext_atype, (-1,)), axis=0), + (nf, nall, dd.tebd_dim), + ) + dense_g, *_ = blk.call( + nlist, + ext_coord, + ext_atype, + atype_embd_ext=atype_embd_ext, + mapping=None, + type_embedding=tebd, + ) + ng = from_dense_quartet(ext_coord, nlist, mapping) + graph_g, _ = blk.call_graph( + ng, np.reshape(ext_atype, (-1,)), type_embedding=tebd + ) + graph_g = graph_g.reshape(dense_g.shape) + # close (the artifact is a small denominator perturbation) ... + np.testing.assert_allclose(graph_g, dense_g, rtol=1e-3, atol=1e-3) + # ... but NOT bit-equal: the phantom-padding terms are really gone + assert np.max(np.abs(graph_g - dense_g)) > 1e-9 + + # ── torch namespace smoke (CLAUDE.md: catches numpy-weight leaks) ──────── + # NB: the smoke runs the BLOCK kernel with a torch type_embedding table; + # the raw dpmodel adapter is numpy-weighted by design (pt_expt wraps it). 
class TestGraphEligibility:
    """Which descriptor configs report ``uses_graph_lower()``."""

    def test_attention_concat_is_graph_eligible(self) -> None:
        # concat type embedding + attention layers => graph route
        descriptor = _make(2)
        assert descriptor.uses_graph_lower()

    def test_strip_mode_stays_dense(self) -> None:
        """se_atten_v2 (tebd_input_mode='strip') must keep the dense route.

        Strip-mode graph support is a later PR, so the descriptor is NOT
        graph-eligible yet (the PR-D plan's 'se_atten_v2 inherits for free'
        did not hold).
        """
        from deepmd.dpmodel.descriptor.se_atten_v2 import (
            DescrptSeAttenV2,
        )

        descriptor = DescrptSeAttenV2(
            rcut=4.0,
            rcut_smth=0.5,
            sel=[20],
            ntypes=2,
            attn_layer=2,
        )
        assert not descriptor.uses_graph_lower()
+ """ + + def test_carry_all_attention_differs_at_binding_sel(self) -> None: + from deepmd.dpmodel.utils.neighbor_graph import ( + build_neighbor_graph, + ) + + rng = np.random.default_rng(GLOBAL_SEED) + nloc = 6 + coord = rng.random((1, nloc, 3)) * 2.0 # dense blob => binding sel=2 + atype = np.array([[0, 1, 0, 1, 1, 0]], dtype=np.int64) + dd = _make(2, dotr=True, sel=(2,)) + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + coord, atype, dd.get_rcut(), dd.get_sel(), mixed_types=True, box=None + ) + assert (nlist >= 0).all(), "fixture must be sel-binding (all slots full)" + tebd = dd.type_embedding.call() + atype_embd_ext = np.reshape( + np.take(tebd, np.reshape(ext_atype, (-1,)), axis=0), + (1, ext_atype.shape[1], dd.tebd_dim), + ) + dense_g, *_ = dd.se_atten.call( + nlist, + ext_coord, + ext_atype, + atype_embd_ext=atype_embd_ext, + mapping=None, + type_embedding=tebd, + ) + graph = build_neighbor_graph(coord, atype, None, dd.get_rcut()) + graph_g, _ = dd.se_atten.call_graph( + graph, atype.reshape(-1), type_embedding=tebd + ) + assert np.max(np.abs(graph_g.reshape(dense_g.shape) - dense_g)) > 1e-6, ( + "carry-all attention must diverge from sel-truncated dense" + ) diff --git a/source/tests/common/dpmodel/test_graph_angle_cos_parity.py b/source/tests/common/dpmodel/test_graph_angle_cos_parity.py new file mode 100644 index 0000000000..8040141919 --- /dev/null +++ b/source/tests/common/dpmodel/test_graph_angle_cos_parity.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for graph_angle_cos vs dpa3 dense cosine_ij and se_t dot form.""" + +from __future__ import ( + annotations, +) + +import numpy as np + +from deepmd.dpmodel.utils.neighbor_graph import ( + attach_angles, + build_neighbor_graph, + graph_angle_cos, +) + +# --------------------------------------------------------------------------- +# Step 1/4: explicit geometry sanity +# 
def test_graph_angle_cos_matches_normalized_dot():
    """Perpendicular neighbors around one center give cos = 0.

    The center sits at the origin with neighbors on the x and y axes at
    distance 1.0. With rcut = a_rcut = 1.1 both neighbors are inside the
    cutoff while the neighbor-neighbor distance (sqrt(2) ~ 1.41) is not, so
    atom 0 is the only center with two edges and exactly ONE unordered angle
    exists.
    """
    cutoff = 1.1
    coord = np.array([[[0.0, 0, 0], [1.0, 0, 0], [0.0, 1.0, 0]]])
    atype = np.array([[0, 0, 0]])
    edges_only = build_neighbor_graph(coord, atype, None, cutoff)
    ng = attach_angles(edges_only, a_rcut=cutoff)

    is_real = np.asarray(ng.angle_mask)
    cos_real = np.asarray(graph_angle_cos(ng.angle_index, ng.edge_vec))[is_real]
    # a single real angle, the one centered on atom 0
    assert int(is_real.sum()) == 1
    # dot of perpendicular vectors is 0; the (1 - eps) scaling keeps it at 0
    np.testing.assert_allclose(cos_real, [0.0], atol=1e-6)
+ eps = 1e-6 + expected = -1.0 / (1 + eps) ** 2 * (1 - eps) + np.testing.assert_allclose(vals, [expected], rtol=1e-6) + + +def test_graph_angle_cos_no_self_angles(): + """No graph angle should have both edge slots pointing to the same edge.""" + coord = np.array([[[0.0, 0, 0], [1.0, 0, 0], [0.0, 1.0, 0], [0.0, 0.0, 1.0]]]) + atype = np.array([[0, 0, 0, 0]]) + ng = attach_angles(build_neighbor_graph(coord, atype, None, 3.0), a_rcut=3.0) + ai = np.asarray(ng.angle_index) + am = np.asarray(ng.angle_mask) + # For every VALID angle, the two edge indices must differ + real_angles = [(int(ai[0, p]), int(ai[1, p])) for p in range(am.shape[0]) if am[p]] + for a, b in real_angles: + assert a != b, f"Self-angle found: edge {a}" + + +# --------------------------------------------------------------------------- +# dpa3 dense parity +# --------------------------------------------------------------------------- + + +def _dense_cosine_ij(coord_ext, nlist, a_rcut, a_sel): + """Faithful numpy transcription of repflows.py:597-649. + + Returns cosine_ij of shape (nf, nloc, a_sel, a_sel). + a_diff convention: coord_r - coord_l = neighbor - center. 
+ """ + nf, nloc, nnei = nlist.shape + # coord_ext: (nf, nall, 3) + # a_nlist: truncate to a_sel and mask beyond a_rcut + diff_full = np.zeros((nf, nloc, nnei, 3), dtype=np.float64) + for f in range(nf): + for i in range(nloc): + for k in range(nnei): + j = nlist[f, i, k] + if j >= 0: + diff_full[f, i, k] = coord_ext[f, j] - coord_ext[f, i] + # a_rcut gate — clip a_sel to actual nnei columns available + nnei = nlist.shape[2] + eff_a_sel = min(a_sel, nnei) + a_dist_mask = (np.linalg.norm(diff_full, axis=-1) < a_rcut)[:, :, :eff_a_sel] + a_nlist = nlist[:, :, :eff_a_sel].copy() + a_nlist = np.where(a_dist_mask, a_nlist, np.full_like(a_nlist, -1)) + # a_diff: shape (nf, nloc, eff_a_sel, 3) + a_diff = np.zeros((nf, nloc, eff_a_sel, 3), dtype=np.float64) + for f in range(nf): + for i in range(nloc): + for k in range(eff_a_sel): + j = a_nlist[f, i, k] + if j >= 0: + a_diff[f, i, k] = coord_ext[f, j] - coord_ext[f, i] + # normalized_diff_i: (nf, nloc, eff_a_sel, 3) + norm = np.linalg.norm(a_diff, axis=-1, keepdims=True) + normalized_diff_i = a_diff / (norm + 1e-6) + # cosine_ij: (nf, nloc, eff_a_sel, eff_a_sel) + cosine_ij = np.matmul(normalized_diff_i, np.swapaxes(normalized_diff_i, -2, -1)) + cosine_ij = cosine_ij * (1 - 1e-6) + return cosine_ij, a_nlist + + +def test_graph_angle_cos_parity_vs_dpa3_dense(): + """Graph unordered/no-self cos values must match dense OFF-DIAGONAL cosine_ij. + + Uses a small 4-atom system with a large a_sel (non-binding) so that the + graph and dense see the same neighbor set. + + - Tolerance: rtol=1e-12, atol=1e-12 (CPU fp64 same-math, identical eps). + - The graph is UNORDERED (no duplicates) and NO-SELF. + - Dense diagonal (j==k, cos≈1) must NOT appear in graph angle set. + - Dense off-diagonal (j!=k) are collected as unordered {j,k} pairs and + matched against graph angles by (center_node, unordered edge-src pair). 
+ """ + rng = np.random.default_rng(42) + # 4 atoms in a box; no PBC => set box to None + # Use a single frame, 4 atoms + nf, nloc = 1, 4 + coord = rng.uniform(-1, 1, (nf, nloc, 3)) + atype = np.zeros((nf, nloc), dtype=np.int32) + rcut = 3.0 + a_rcut = 2.5 + # Choose a_sel equal to nloc-1 (max neighbors) => non-binding + a_sel = nloc - 1 # =3; each atom has at most 3 neighbors => non-binding + + # --- graph side --- + ng = attach_angles(build_neighbor_graph(coord, atype, None, rcut), a_rcut=a_rcut) + cos_graph = np.asarray(graph_angle_cos(ng.angle_index, ng.edge_vec)) + am = np.asarray(ng.angle_mask) + ai = np.asarray(ng.angle_index) + ei = np.asarray(ng.edge_index) # (2, E): [src, dst] + ev = np.asarray(ng.edge_vec) + + # No self-angles + for p in range(am.shape[0]): + if am[p]: + assert int(ai[0, p]) != int(ai[1, p]), "Self-angle found in graph" + + # --- dense side --- + # Build a dense nlist from the same coord (no PBC, full nlist) + # We construct nlist by brute force + coord3 = coord[0] # (nloc, 3) + # For the dense side, coord_ext = coord (nloc=nall, no ghosts) + coord_ext = coord # (1, nloc, 3) + # Build dense nlist: shape (1, nloc, nnei_max) + # For 4 atoms, each atom has at most nloc-1=3 neighbors + nnei = nloc - 1 # max possible neighbors (no self) + dense_nlist = np.full((nf, nloc, nnei), -1, dtype=np.int64) + for i in range(nloc): + k = 0 + for j in range(nloc): + d = np.linalg.norm(coord3[j] - coord3[i]) + if d < rcut and i != j: + dense_nlist[0, i, k] = j + k += 1 + + cosine_ij, a_nlist = _dense_cosine_ij(coord_ext, dense_nlist, a_rcut, a_sel) + # cosine_ij: (1, nloc, a_sel, a_sel) + + # --- match graph angles to dense off-diagonal --- + # graph edge_index: src=edge_index[0], dst=edge_index[1] (center=dst) + # For each valid graph angle p: edge_a=ai[0,p], edge_b=ai[1,p] + # center = ei[1, edge_a] = ei[1, edge_b] (shared center) + # neighbor_a = ei[0, edge_a], neighbor_b = ei[0, edge_b] + + # Build dense lookup: (center, na, nb) -> cos for 
off-diagonal + dense_cos_lookup = {} # (center, unordered frozenset(na, nb)) -> cos + for i in range(nloc): + for j_idx in range(a_sel): + na = int(a_nlist[0, i, j_idx]) + if na < 0: + continue + for k_idx in range(a_sel): + nb = int(a_nlist[0, i, k_idx]) + if nb < 0: + continue + if j_idx == k_idx: # skip diagonal (self-angles = edge channel) + continue + cos_val = float(cosine_ij[0, i, j_idx, k_idx]) + key = (i, frozenset([na, nb])) + # unordered: both (j,k) and (k,j) map to same pair + # store the value for frozenset key (they should be equal by symmetry + # of cos, but we verify) + if key not in dense_cos_lookup: + dense_cos_lookup[key] = cos_val + else: + # cosine is symmetric: both directions should be equal + np.testing.assert_allclose( + dense_cos_lookup[key], + cos_val, + atol=1e-14, + err_msg=f"Dense cosine not symmetric for ({i},{na},{nb})", + ) + + # Now compare each valid graph angle + for p in range(am.shape[0]): + if not am[p]: + continue + ea = int(ai[0, p]) + eb = int(ai[1, p]) + center = int(ei[1, ea]) + assert center == int(ei[1, eb]), "Angle edges don't share center" + na = int(ei[0, ea]) + nb = int(ei[0, eb]) + key = (center, frozenset([na, nb])) + assert key in dense_cos_lookup, ( + f"Graph angle (center={center}, na={na}, nb={nb}) not in dense" + ) + cos_g = float(cos_graph[p]) + cos_d = dense_cos_lookup[key] + np.testing.assert_allclose( + cos_g, + cos_d, + rtol=1e-12, + atol=1e-12, + err_msg=f"cos mismatch at (center={center}, na={na}, nb={nb})", + ) + + +# --------------------------------------------------------------------------- +# se_t dot-product cross-check +# --------------------------------------------------------------------------- + + +def test_matches_se_t_dot_form(): + """Cross-check graph_angle_cos against an independent coordinate-based oracle. + + se_t.py:428-437 computes ``env_ij = sum(rr_i * rr_j, -1)`` where + ``rr_i = sw * diff / r^2`` (the 3-D columns of the env-mat). 
The raw + unnormalized dot product ``va · vb`` (with ``va = r_a - r_center``) is the + numerator that graph_angle_cos normalizes: + + graph_angle_cos = (1 - eps) * (va · vb) / ((|va| + eps) * (|vb| + eps)) + + Inverting: + + graph_angle_cos * (|va| + eps) * (|vb| + eps) / (1 - eps) = va · vb + + **Why sw is factored out**: sw scales each env-mat vector by a scalar. + When all neighbor distances are *below* ``rcut_smth``, the smooth switch + function equals 1 exactly (``sw == 1``), so the sw factor contributes + nothing and ``env_ij`` reduces to the plain geometry. + + **Why this test is not tautological**: the reference ``va``, ``vb``, and + ``env_ij = va · vb`` are computed DIRECTLY FROM COORDINATES in plain numpy, + independent of the graph code path. The |va| and |vb| norms used to unwind + ``cos`` are also recomputed from coordinates, NOT read from ``edge_vec``. + This verifies that the graph stores ``edge_vec = neighbor - center`` + correctly and that ``graph_angle_cos`` faithfully encodes the geometry. + With distances well below ``rcut_smth`` the identity holds to ``rtol=1e-12`` + because it is exact algebra over fp64; eps-induced rounding is negligible + compared to fp64 relative precision. + """ + # All atoms within distance 0.5 of center; rcut_smth = 1.0 so sw == 1 for all. 
+ rng = np.random.default_rng(7) + center = np.array([0.0, 0.0, 0.0]) + r_a = rng.uniform(0.1, 0.4, 3) # distance from center < rcut_smth=1.0 + r_b = rng.uniform(0.1, 0.4, 3) + coord = np.array([[[*center], [*r_a], [*r_b]]]) # (1, 3, 3), single frame + atype = np.array([[0, 0, 0]]) + + rcut = 2.0 + a_rcut = 2.0 + ng = attach_angles(build_neighbor_graph(coord, atype, None, rcut), a_rcut=a_rcut) + ai = np.asarray(ng.angle_index) + am = np.asarray(ng.angle_mask) + ei = np.asarray(ng.edge_index) + cos = np.asarray(graph_angle_cos(ng.angle_index, ng.edge_vec)) + + eps = 1e-6 + + # At least one valid angle must exist (atom 0 is the only center with ≥2 nei) + valid_angles = [p for p in range(am.shape[0]) if am[p]] + assert len(valid_angles) >= 1, "No valid angles found — geometry problem" + + for p in valid_angles: + ea = int(ai[0, p]) + eb = int(ai[1, p]) + center_node = int(ei[1, ea]) + na_node = int(ei[0, ea]) + nb_node = int(ei[0, eb]) + + # Reference: compute difference vectors FROM COORDINATES (independent of graph) + r_center = coord[0, center_node] + r_na = coord[0, na_node] + r_nb = coord[0, nb_node] + va_ref = r_na - r_center # (3,) + vb_ref = r_nb - r_center # (3,) + + # Reference: unnormalized dot product from coordinates (se_t convention) + env_ij_ref = float(np.dot(va_ref, vb_ref)) + + # Reference norms — from coordinates, NOT from edge_vec + na_norm = float(np.linalg.norm(va_ref)) + nb_norm = float(np.linalg.norm(vb_ref)) + + # Graph: unwind graph_angle_cos back to the unnormalized dot product + env_ij_graph = float(cos[p]) * (na_norm + eps) * (nb_norm + eps) / (1 - eps) + + np.testing.assert_allclose( + env_ij_graph, + env_ij_ref, + rtol=1e-12, + err_msg=( + f"se_t dot mismatch at angle {p}: " + f"center={center_node}, na={na_node}, nb={nb_node}, " + f"va={va_ref}, vb={vb_ref}" + ), + ) + + +# --------------------------------------------------------------------------- +# torch namespace smoke test (TID253) +# 
--------------------------------------------------------------------------- + + +def test_graph_angle_cos_torch_matches_numpy(): + """graph_angle_cos on torch tensors matches numpy output (array-API compat).""" + import torch + + coord = np.array([[[0.0, 0, 0], [1.0, 0, 0], [0.0, 1.0, 0], [1.0, 1.0, 0]]]) + atype = np.array([[0, 0, 0, 0]]) + ng = attach_angles(build_neighbor_graph(coord, atype, None, 3.0), a_rcut=3.0) + angle_index_np = np.asarray(ng.angle_index) + edge_vec_np = np.asarray(ng.edge_vec) + + cos_np = np.asarray(graph_angle_cos(angle_index_np, edge_vec_np)) + + angle_index_t = torch.from_numpy(angle_index_np) + edge_vec_t = torch.from_numpy(edge_vec_np) + cos_t = graph_angle_cos(angle_index_t, edge_vec_t) + cos_t_np = cos_t.numpy() + + np.testing.assert_allclose(cos_t_np, cos_np, rtol=1e-14, atol=1e-14) diff --git a/source/tests/common/dpmodel/test_segment_softmax.py b/source/tests/common/dpmodel/test_segment_softmax.py new file mode 100644 index 0000000000..b34ee8efaf --- /dev/null +++ b/source/tests/common/dpmodel/test_segment_softmax.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""segment_max / segment_softmax (NeighborGraph PR-D segment toolkit).""" + +import numpy as np + +from deepmd.dpmodel.utils.neighbor_graph import ( + segment_max, + segment_softmax, +) + + +class TestSegmentMax: + def test_basic(self) -> None: + data = np.array([1.0, 5.0, 2.0, -3.0]) + ids = np.array([0, 0, 2, 2], dtype=np.int64) + out = segment_max(data, ids, 3) + assert out[0] == 5.0 + assert np.isneginf(out[1]) # empty segment + assert out[2] == 2.0 + + def test_trailing_dims(self) -> None: + data = np.array([[1.0, -2.0], [3.0, -4.0], [0.0, 9.0]]) + ids = np.array([1, 1, 0], dtype=np.int64) + out = segment_max(data, ids, 2) + np.testing.assert_allclose(out[0], [0.0, 9.0]) + np.testing.assert_allclose(out[1], [3.0, -2.0]) + + def test_torch_matches_numpy(self) -> None: + import torch + + data = np.array([0.3, 1.2, -0.7, 2.0]) + ids = np.array([0, 0, 
1, 1], dtype=np.int64) + ref = segment_max(data, ids, 2) + out = segment_max(torch.from_numpy(data), torch.from_numpy(ids), 2) + np.testing.assert_allclose(out.numpy(), ref) + + +class TestSegmentSoftmax: + def test_matches_dense(self) -> None: + logits = np.array([1.0, 2.0, 0.5, -1.0]) + ids = np.array([0, 0, 0, 1], dtype=np.int64) + w = segment_softmax(logits, ids, 2) + ref0 = np.exp(np.array([1.0, 2.0, 0.5]) - 2.0) + ref0 = ref0 / ref0.sum() + np.testing.assert_allclose(w[:3], ref0, atol=1e-12) + np.testing.assert_allclose(w[3], 1.0, atol=1e-12) + + def test_stable_large_logits(self) -> None: + logits = np.array([1e30, 1e30 + 1.0]) + ids = np.array([0, 0], dtype=np.int64) + w = segment_softmax(logits, ids, 1) + assert not np.any(np.isnan(w)) + np.testing.assert_allclose(w.sum(), 1.0, atol=1e-12) + + def test_masked_entries_zero(self) -> None: + logits = np.array([1.0, 2.0, 3.0]) + ids = np.array([0, 0, 0], dtype=np.int64) + mask = np.array([True, False, True]) + w = segment_softmax(logits, ids, 1, mask=mask) + assert w[1] == 0.0 + np.testing.assert_allclose(w.sum(), 1.0, atol=1e-12) + # masked entry excluded from the denominator too + ref = np.exp(np.array([1.0, 3.0]) - 3.0) + ref = ref / ref.sum() + np.testing.assert_allclose(w[[0, 2]], ref, atol=1e-12) + + def test_all_masked_segment_is_zero_no_nan(self) -> None: + logits = np.array([1.0, 2.0, 5.0]) + ids = np.array([0, 0, 1], dtype=np.int64) + mask = np.array([True, True, False]) + w = segment_softmax(logits, ids, 2, mask=mask) + assert not np.any(np.isnan(w)) + assert w[2] == 0.0 + + def test_empty_segment_no_nan(self) -> None: + logits = np.array([1.0, 2.0]) + ids = np.array([0, 0], dtype=np.int64) + w = segment_softmax(logits, ids, 3) + assert not np.any(np.isnan(w)) + + def test_torch_matches_numpy(self) -> None: + import torch + + logits = np.array([0.3, 1.2, -0.7, 2.0]) + ids = np.array([0, 0, 1, 1], dtype=np.int64) + mask = np.array([True, True, True, False]) + ref = segment_softmax(logits, ids, 2, 
mask=mask) + out = segment_softmax( + torch.from_numpy(logits), + torch.from_numpy(ids), + 2, + mask=torch.from_numpy(mask), + ) + np.testing.assert_allclose(out.numpy(), ref, atol=1e-12) diff --git a/source/tests/pt_expt/descriptor/test_dpa1.py b/source/tests/pt_expt/descriptor/test_dpa1.py index d7d2718e67..cddd22419f 100644 --- a/source/tests/pt_expt/descriptor/test_dpa1.py +++ b/source/tests/pt_expt/descriptor/test_dpa1.py @@ -311,6 +311,65 @@ def fn(coord_ext, atype_ext, nlist, mapping): atol=atol, ) + @pytest.mark.parametrize("smooth", [False, True]) # smooth attention branch + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx_graph_attn(self, prec, smooth) -> None: + """make_fx (export-readiness) of the GRAPH forward with attention. + + MERGE BLOCKER (NeighborGraph PR-D): pt_expt compiled training routes + eligible models through the graph lower by default, so graph attention + (``attn_layer > 0``) must be fx-traceable — the shape-static + ``center_edge_pairs`` form keeps the pair enumeration ``nonzero``-free. + Covers both the smooth and non-smooth attention branches. 
+ """ + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + attn_dotr=True, + smooth_type_embedding=smooth, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + mapping = torch.tensor(self.mapping, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist, mapping): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist, mapping)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist, mapping) + traced = make_fx(fn)(coord_ext, atype_ext, nlist, mapping) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist, mapping) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level def test_share_params(self, shared_level) -> None: """share_params level 0: share all; level 1: share type_embedding only.""" diff --git a/source/tests/pt_expt/model/test_dpa1_graph_lower.py b/source/tests/pt_expt/model/test_dpa1_graph_lower.py index e274a1bcec..d6cbb51f2f 100644 --- 
a/source/tests/pt_expt/model/test_dpa1_graph_lower.py +++ b/source/tests/pt_expt/model/test_dpa1_graph_lower.py @@ -91,7 +91,7 @@ def setup_method(self) -> None: [[0, 0, 0, 1, 1]], dtype=torch.int64, device=self.device ) - def _make_model(self) -> EnergyModel: + def _make_model(self, attn_layer: int = 0, smooth: bool = False) -> EnergyModel: ds = DescrptDPA1( self.rcut, self.rcut_smth, @@ -100,9 +100,13 @@ def _make_model(self) -> EnergyModel: neuron=[3, 6], axis_neuron=2, attn=4, - attn_layer=0, # graph lower only supports attn_layer == 0 + attn_layer=attn_layer, attn_dotr=True, attn_mask=False, + # smooth attention keeps sel-padding in the dense softmax + # denominator; the carry-all graph drops it BY DESIGN (PR-D), so + # exact graph-vs-dense parity requires smooth=False here. + smooth_type_embedding=smooth, activation_function="tanh", set_davg_zero=False, type_one_side=True, @@ -165,13 +169,16 @@ def _prepare_lower_inputs(self, periodic: bool): mapping_t = torch.tensor(mapping, dtype=torch.int64, device=self.device) return ext_coord, ext_atype, nlist_t, mapping_t + @pytest.mark.parametrize("attn_layer", [0, 2]) # factorizable AND attention @pytest.mark.parametrize("periodic", [True, False]) # PBC vs non-PBC @pytest.mark.parametrize("do_av", [False, True]) # atom-virial off / on - def test_force_virial_parity_vs_legacy(self, periodic, do_av) -> None: + def test_force_virial_parity_vs_legacy(self, periodic, do_av, attn_layer) -> None: """Graph lower energy/force/virial/atom_virial == legacy dense lower on the SAME neighbor set (regime-1 graph from from_dense_quartet). + attn_layer=2 exercises graph attention through model-level autograd + (smooth=False: exact carry-all parity regime, NeighborGraph PR-D). 
""" - model = self._make_model() + model = self._make_model(attn_layer=attn_layer) model.eval() tol = ( {"rtol": 1e-12, "atol": 1e-12} diff --git a/source/tests/pt_expt/model/test_linear_model.py b/source/tests/pt_expt/model/test_linear_model.py index a18cabd9e1..8c95f3e69e 100644 --- a/source/tests/pt_expt/model/test_linear_model.py +++ b/source/tests/pt_expt/model/test_linear_model.py @@ -343,6 +343,12 @@ def test_forward_lower_exportable(self) -> None: "temperature": 1.0, "set_davg_zero": True, "type_one_side": True, + # smooth attention diverges between the graph default (standard model, + # carry-all: no phantom sel-padding softmax terms) and the dense route + # (linear models are graph-ineligible) by design (NeighborGraph PR-D); + # pin smooth off so both routes are exact and the weight-combination + # comparison stays at 1e-10. + "smooth_type_embedding": False, "seed": 1, }, "fitting_net": { diff --git a/source/tests/pt_expt/test_plugin.py b/source/tests/pt_expt/test_plugin.py index a59242e592..ddee7a6876 100644 --- a/source/tests/pt_expt/test_plugin.py +++ b/source/tests/pt_expt/test_plugin.py @@ -22,12 +22,36 @@ def fake_entry_points(*, group=None): return [_FakeEntryPoint(calls)] monkeypatch.setattr(importlib.metadata, "entry_points", fake_entry_points) + + # Snapshot the deepmd.pt_expt module tree BEFORE re-importing. Just popping + # "deepmd.pt_expt" and leaving its submodules cached poisons sys.modules + # for the rest of the pytest process: a later import of a cached submodule + # (e.g. deepmd.pt_expt.infer.deep_eval) re-creates a BARE parent package + # whose submodule attributes (utils/infer/...) are never rebound, and + # mock.patch("deepmd.pt_expt.utils...") then fails with AttributeError on + # py3.10 (shard-order dependent CI failure). 
+ saved = { + k: v + for k, v in sys.modules.items() + if k == "deepmd.pt_expt" or k.startswith("deepmd.pt_expt.") + } + deepmd_pkg = sys.modules.get("deepmd") sys.modules.pop("deepmd.pt_expt", None) try: importlib.import_module("deepmd.pt_expt") finally: - sys.modules.pop("deepmd.pt_expt", None) + # drop everything the fresh import created, then restore the snapshot + # (including the parent-package attribute binding). + for k in [ + m + for m in list(sys.modules) + if m == "deepmd.pt_expt" or m.startswith("deepmd.pt_expt.") + ]: + sys.modules.pop(k, None) + sys.modules.update(saved) + if deepmd_pkg is not None and "deepmd.pt_expt" in saved: + deepmd_pkg.pt_expt = saved["deepmd.pt_expt"] assert groups == ["deepmd.pt_expt"] assert calls == ["load"] diff --git a/source/tests/pt_expt/utils/test_neighbor_list.py b/source/tests/pt_expt/utils/test_neighbor_list.py index 26bf5c19da..ff29ed50c4 100644 --- a/source/tests/pt_expt/utils/test_neighbor_list.py +++ b/source/tests/pt_expt/utils/test_neighbor_list.py @@ -113,6 +113,11 @@ "temperature": 1.0, "set_davg_zero": True, "type_one_side": True, + # smooth attention diverges between the carry-all graph default + # (neighbor_list=None) and the explicit World-1 builders by design + # (NeighborGraph PR-D: dense keeps sel-padding in the attention + # softmax denominator); pin smooth off so all routes are exact. + "smooth_type_embedding": False, "seed": 1, }, "fitting_net": {"neuron": [8, 8], "resnet_dt": True, "seed": 1},