diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index 094d4dfd6f..115242edfb 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -211,6 +211,25 @@ def xp_add_at(x: Array, indices: Array, values: Array) -> Array: return x +def xp_hint_dynamic_size(x: Array) -> None: + """Mark a data-dependent leading dimension as a valid size for torch.export. + + Under symbolic tracing (``make_fx`` / ``torch.export``) the length of a + data-dependent array (e.g. the output of ``nonzero`` or a tensor-``repeat``) + is an UNBACKED SymInt; guarding Python control flow or allocations on it + raises ``GuardOnDataDependentSymNode``. ``torch._check_is_size`` registers + the ``>= 0`` size hint that lets the tracer treat it as a proper dimension + (recorded as a ``sym_constrain_range_for_size`` node, preserved by AOTI). + + No-op for numpy / jax / eager-torch concrete shapes — safe to call + unconditionally from dpmodel code (torch imported lazily, torch arrays only). + """ + if array_api_compat.is_torch_array(x): + import torch + + torch._check_is_size(x.shape[0]) + + def xp_maximum_at(x: Array, indices: Array, values: Array) -> Array: """Segment max-assign of values into x at the specified indices. diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 13f8f3e351..30a2d25e38 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -26,6 +26,7 @@ ) from deepmd.dpmodel.common import ( cast_precision, + get_xp_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -428,15 +429,20 @@ def get_numb_attn_layer(self) -> int: def uses_graph_lower(self) -> bool: """Returns whether this descriptor supports the graph-native lower. - The graph-native energy lower (``call_graph``) currently covers only the - non-attention (``attn_layer == 0``) factorizable path with concat - type-embedding and no type exclusion. 
Any other config (attention, - ``tebd_input_mode == "strip"``, ``exclude_types``) falls back to the - legacy dense path, so those models keep working unchanged. + The graph-native lower (``call_graph``) covers the factorizable path + AND transformer attention (``attn_layer >= 0``, NeighborGraph PR-D) + with concat type-embedding and no type exclusion. Remaining ineligible + configs (``tebd_input_mode == "strip"``, ``exclude_types``) fall back + to the legacy dense path, so those models keep working unchanged. + + Eligibility does NOT imply numerical interchangeability with the + dense route for every config: with ``smooth_type_embedding=True`` + the carry-all graph attention is sel-independent by design and + differs from the dense lower by up to ~1e-4 (see the Notes of + :meth:`call_graph`). """ return ( - self.se_atten.attn_layer == 0 - and self.se_atten.tebd_input_mode == "concat" + self.se_atten.tebd_input_mode == "concat" and not self.se_atten.exclude_types ) @@ -643,6 +649,9 @@ def _call_graph_adapter( graph, atype_local, type_embedding=self.type_embedding.call(), + # the adapter graph is shape-static center-major (compact=False): + # keep the attention pair enumeration nonzero-free (traceable) + static_nnei=nnei, ) # call_graph returns flat (N, ...) node axis; reshape to (nf, nloc, ...) # for the dense 5-tuple ABI -- this reshape is LOCAL to the adapter shim. @@ -727,8 +736,9 @@ def call_graph( graph: Any, atype: Array, type_embedding: Array | None = None, + static_nnei: int | None = None, ) -> tuple[Array, Array]: - """Descriptor-level graph-native forward (``attn_layer == 0``). + """Descriptor-level graph-native forward. Wraps the block kernel :meth:`DescrptBlockSeAtten.call_graph`, adds the descriptor-level @@ -740,6 +750,21 @@ def call_graph( not produce the dense ``sw`` (that lives in the dense :meth:`call` adapter, which has the ``nlist``/``coord_ext`` needed to build it). 
+ Notes + ----- + **Smooth attention is intentionally sel-independent on the graph + path.** For ``smooth_type_embedding=True`` the legacy dense attention + keeps the sel-padding slots in its softmax DENOMINATOR (phantom + ``exp(-attnw_shift)`` terms), which makes dense output depend on the + ``sel`` setting by up to ~1e-4 even for identical physical neighbors. + A carry-all graph has no padding slots, so its softmax runs over the + real neighbor pairs only: cleaner, sel-independent semantics that + deliberately DIFFER from the dense route for smooth models. The two + routes agree bit-tight only for ``smooth_type_embedding=False`` (at + non-binding ``sel``), or when this kernel is realized on a dense + layout via ``static_nnei`` (the dense :meth:`call` adapter), which + reproduces the phantom terms for exact backward compatibility. + Parameters ---------- graph @@ -757,10 +782,21 @@ def call_graph( (N, ng, 3) equivariant single-particle representation, flat node axis. """ + import dataclasses + xp = array_api_compat.array_namespace(graph.edge_vec) dev = array_api_compat.device(graph.edge_vec) + # manual @cast_precision: the decorator casts array ARGUMENTS, but the + # graph's only float input (edge_vec) is inside the NeighborGraph + # dataclass, invisible to it. Cast edge_vec down to the descriptor + # precision on entry and the outputs back to the caller's dtype on + # exit (differentiable: grad still flows to the caller's edge_vec leaf). + in_dtype = graph.edge_vec.dtype + prec = get_xp_precision(xp, self.precision) + if in_dtype != prec: + graph = dataclasses.replace(graph, edge_vec=xp.astype(graph.edge_vec, prec)) grrg, rot_mat = self.se_atten.call_graph( - graph, atype, type_embedding=type_embedding + graph, atype, type_embedding=type_embedding, static_nnei=static_nnei ) # FLAT node axis (N, ...): no (nf, nloc) reshape -- ragged-native, spec. 
if self.concat_output_tebd: @@ -772,6 +808,9 @@ def call_graph( atype_local = xp.asarray(atype, device=dev) atype_embd = xp.take(type_embedding, atype_local, axis=0) # (N, tebd_dim) grrg = xp.concat([grrg, atype_embd], axis=-1) + if in_dtype != prec: + grrg = xp.astype(grrg, in_dtype) + rot_mat = xp.astype(rot_mat, in_dtype) return grrg, rot_mat def enable_compression( @@ -1670,12 +1709,15 @@ def call_graph( graph: Any, atype: Array, type_embedding: Array | None = None, + static_nnei: int | None = None, ) -> tuple[Array, Array]: - """Graph-native forward (``attn_layer=0`` only). + """Graph-native forward. Bit-exact analogue of :meth:`call` on the SAME neighbor list, with the neighbor-axis reduction replaced by a ``segment_sum`` over edge centers - (``dst``). Geometry enters only through ``graph.edge_vec``. + (``dst``) and the dense ``(nnei, nnei)`` transformer attention replaced + by pairs of edges sharing a center (``center_edge_pairs`` + + ``segment_softmax``). Geometry enters only through ``graph.edge_vec``. Parameters ---------- @@ -1687,6 +1729,12 @@ def call_graph( (N,) flat node atom types (``N = sum(graph.n_node)``). type_embedding (ntypes_with_padding, tebd_dim) type-embedding table. + static_nnei + When the graph uses the shape-static center-major layout + (``from_dense_quartet(compact=False)``, ``E = n_center * nnei``), + pass ``nnei`` so the attention edge-pair enumeration stays + jit/export-traceable (no ``nonzero``). ``None`` (carry-all / + compact graphs) selects the dynamic eager form. Returns ------- @@ -1699,8 +1747,7 @@ def call_graph( Notes ----- - Known limitations (NeighborGraph PR-A): - - ``attn_layer == 0`` only (attention lands in PR-D); + Known limitations: - ``tebd_input_mode == "concat"`` only (strip mode lands later); - ``exclude_types`` is not yet supported and raises (lands in a later PR). 
""" @@ -1709,11 +1756,6 @@ def call_graph( segment_sum, ) - if self.attn_layer != 0: - raise NotImplementedError( - "graph path supports attn_layer=0 only (NeighborGraph PR-A); " - "attn_layer>0 lands in PR-D" - ) if self.tebd_input_mode not in ["concat"]: raise NotImplementedError( "graph path supports tebd_input_mode='concat' only (NeighborGraph PR-A)" @@ -1738,7 +1780,7 @@ def call_graph( # per-edge env-mat 4-vector, normalized by the center (dst) atom type. # self.mean/self.stddev are slot-independent (ntypes, nnei, 4); slot 0 is # the canonical per-type vector. - rr = edge_env_mat( + rr, sw_e = edge_env_mat( graph.edge_vec, center_type, self.mean[:, 0, :], @@ -1747,7 +1789,8 @@ def call_graph( self.rcut_smth, protection=self.env_protection, edge_mask=graph.edge_mask, - ) # (E, 4) + return_sw=True, + ) # (E, 4), (E, 1) sw zeroed on padding # radial channel ss = rr[:, 0:1] # (E, 1) # neighbor / center type embeddings (concat mode); ghost type == owner type @@ -1764,6 +1807,13 @@ def call_graph( ss = xp.concat([ss, atype_embd_nlist], axis=-1) # embedding net (same weights as the dense path); applies on the last axis gg = self.embeddings[0].call(ss) # (E, ng) + # transformer attention over each center's edges — mirrors the dense + # self.dpa1_attention(gg, nlist_mask, input_r, sw), which also runs on + # the UNMASKED gg (padding rows are neutralized afterwards). + if self.attn_layer > 0: + gg = self._graph_attention( + gg, rr, dst, n_total, graph.edge_mask, sw_e, static_nnei + ) # zero padding/guard edges BEFORE the segment sum gg = gg * xp.astype(graph.edge_mask[:, None], gg.dtype) # outer product (replaces the dense gg[:,:,:,None] * rr[:,:,None,:]) @@ -1784,6 +1834,138 @@ def call_graph( rot_mat = gr[:, :, 1:] return grrg, rot_mat + def _graph_attention( + self, + gg: Array, + rr: Array, + dst: Array, + n_total: int, + edge_mask: Array, + sw_e: Array, + static_nnei: int | None, + ) -> Array: + """Graph-native transformer attention over each center's edges. 
+ + Ragged reproduction of :class:`NeighborGatedAttention` / + :class:`GatedAttentionLayer`: edges sharing a center attend to each + other. The dense ``(nnei, nnei)`` square per center becomes the + edge-pair axis from ``center_edge_pairs(ordered=True, + include_self=True)``; softmax over the key axis becomes + ``segment_softmax`` grouped by the query edge. + + Parameters + ---------- + gg : (E, ng) per-edge embedding (UNMASKED, as in the dense path). + rr : (E, 4) per-edge env-mat vector (``rr[:, 1:4]`` carries direction). + dst : (E,) center of each edge. + n_total : number of centers. + edge_mask : (E,) real-vs-padding edge mask. + sw_e : (E, 1) smooth switch, zeroed on padding edges. + static_nnei : shape-static layout ``nnei`` or ``None`` (compact eager). + """ + from deepmd.dpmodel.utils.neighbor_graph import ( + center_edge_pairs, + ) + + xp = array_api_compat.array_namespace(gg) + # per-edge normalized direction (mirrors the dense input_r, + # rr[..., 1:4] / max(|rr[..., 1:4]|, 1e-12)) + dir3 = rr[:, 1:4] + normed = safe_for_vector_norm(dir3, axis=-1, keepdims=True) + input_r = dir3 / xp.maximum(normed, xp.full_like(normed, 1e-12)) # (E, 3) + # transformer neighbor-pairs: full ordered square incl. the diagonal + # (q_m . k_n is not symmetric and self-attention keeps m == n) + q_e, k_e, pair_mask = center_edge_pairs( + dst, + edge_mask, + n_total, + include_self=True, + ordered=True, + static_nnei=static_nnei, + ) + for layer in self.dpa1_attention.attention_layers: + gg = self._graph_attention_one_layer( + layer, gg, input_r, sw_e, q_e, k_e, pair_mask + ) + return gg + + def _graph_attention_one_layer( + self, + layer: "NeighborGatedAttentionLayer", + gg: Array, + input_r: Array, + sw_e: Array, + q_e: Array, + k_e: Array, + pair_mask: Array, + ) -> Array: + """One residual attention layer, op-for-op vs the dense reference. + + Mirrors ``NeighborGatedAttentionLayer.call`` (residual + + ``GatedAttentionLayer.call`` + LayerNorm). 
Structural translation: + per-center ``q @ k^T`` -> per-pair ``q_m . k_n``; softmax over the key + axis -> ``segment_softmax`` grouped by the query edge. The smooth + branch keeps padding pairs IN the softmax denominator with ``sw = 0`` + (weight ``exp(-attnw_shift)``), exactly like the dense branch, which + replaces the ``-inf`` masking by the switch weighting. + """ + from deepmd.dpmodel.utils.neighbor_graph import ( + segment_softmax, + segment_sum, + ) + + xp = array_api_compat.array_namespace(gg) + e_tot = gg.shape[0] + gal = layer.attention_layer # GatedAttentionLayer + if gal.num_heads != 1: + raise NotImplementedError( + "graph attention assumes num_heads == 1 (dpa1 never exposes " + "num_heads; the dense head_dim QKV slicing relies on it)" + ) + hd = gal.head_dim # == hidden_dim for num_heads == 1 + residual = gg + # in_proj -> Q, K, V; mirror the dense HEAD_DIM slicing exactly + qkv = gal.in_proj.call(gg) # (E, 3 * hidden) + q = qkv[:, 0:hd] + k = qkv[:, hd : hd * 2] + v = qkv[:, hd * 2 : hd * 3] + if gal.normalize: + q = np_normalize(q, axis=-1) + k = np_normalize(k, axis=-1) + v = np_normalize(v, axis=-1) + q = q * gal.scaling + # per-pair logits q_m . k_n (num_heads == 1) + logits = xp.sum( + xp.take(q, q_e, axis=0) * xp.take(k, k_e, axis=0), axis=-1 + ) # (P,) + if gal.smooth: + # (logits + shift) * sw_m * sw_n - shift, then softmax WITHOUT the + # pair mask: padding pairs stay in the denominator at exp(-shift), + # mirroring the dense smooth branch (sw already zeroed on padding). 
+ attnw_shift = 20.0 # dense GatedAttentionLayer.call default + sw_flat = sw_e[:, 0] # (E,) + sw_q = xp.take(sw_flat, q_e, axis=0) + sw_k = xp.take(sw_flat, k_e, axis=0) + logits = (logits + attnw_shift) * sw_q * sw_k - attnw_shift + w = segment_softmax(logits, q_e, e_tot) # (P,) + w = w * sw_q * sw_k + else: + # non-smooth: dense masks padding keys to -inf pre-softmax == + # excluding them from the softmax entirely + w = segment_softmax(logits, q_e, e_tot, mask=pair_mask) + if gal.dotr: + angular = xp.sum( + xp.take(input_r, q_e, axis=0) * xp.take(input_r, k_e, axis=0), + axis=-1, + ) # (P,) = input_r_m . input_r_n + w = w * angular + # o_m = sum_n w[m, n] v[n] -> segment_sum over the query edge + wv = w[:, None] * xp.take(v, k_e, axis=0) # (P, hd) + o = segment_sum(wv, q_e, e_tot) # (E, hd) + out = gal.out_proj.call(o) # (E, ng) + x = residual + out + return layer.attn_layer_norm.call(x) + def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return False diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index b3ce544377..be52bc9f22 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -313,10 +313,14 @@ def call_common( The graph routes (``"dense"``/``"ase"``, and the pt_expt default-flip) require a ``mixed_types`` descriptor with a graph - lower (dpa1 ``attn_layer == 0``). At non-binding ``sel`` the - graph matches the dense path exactly; at binding ``sel`` the - carry-all graph keeps neighbors the dense path truncates, so the - energy intentionally differs. + lower (dpa1/se_atten with concat type embedding and no + ``exclude_types``; attention layers included). 
At non-binding + ``sel`` the graph matches the dense path exactly for the + non-smooth branch; at binding ``sel`` the carry-all graph keeps + neighbors the dense path truncates, and for + ``smooth_type_embedding=True`` the graph drops the dense + layout's sel-padding softmax terms, so the energy intentionally + differs (sel-independent graph semantics). Returns ------- @@ -688,7 +692,7 @@ def call_common_lower_graph( comm_dict: dict | None = None, charge_spin: Array | None = None, ) -> dict[str, Array]: - """Graph-native PUBLIC lower (PR-A: dpa1 ``attn_layer == 0``). + """Graph-native PUBLIC lower (dpa1/se_atten concat-tebd, attention included). The PRIMARY directly-callable graph interface (spec decision #14). Casts inputs/outputs to/from the model precision exactly like the diff --git a/deepmd/dpmodel/utils/neighbor_graph/__init__.py b/deepmd/dpmodel/utils/neighbor_graph/__init__.py index 6e041805b2..24fb090309 100644 --- a/deepmd/dpmodel/utils/neighbor_graph/__init__.py +++ b/deepmd/dpmodel/utils/neighbor_graph/__init__.py @@ -32,8 +32,13 @@ node_validity_mask, pad_and_guard_edges, ) +from .pairs import ( + center_edge_pairs, +) from .segment import ( + segment_max, segment_mean, + segment_softmax, segment_sum, ) @@ -42,6 +47,7 @@ "NeighborGraph", "build_neighbor_graph", "build_neighbor_graph_ase", + "center_edge_pairs", "edge_env_mat", "edge_force_virial", "frame_id_from_n_node", @@ -49,6 +55,8 @@ "neighbor_graph_from_ijs", "node_validity_mask", "pad_and_guard_edges", + "segment_max", "segment_mean", + "segment_softmax", "segment_sum", ] diff --git a/deepmd/dpmodel/utils/neighbor_graph/env.py b/deepmd/dpmodel/utils/neighbor_graph/env.py index 55bbe1b02f..4057cd8640 100644 --- a/deepmd/dpmodel/utils/neighbor_graph/env.py +++ b/deepmd/dpmodel/utils/neighbor_graph/env.py @@ -41,7 +41,8 @@ def edge_env_mat( rcut_smth: float, protection: float = 0.0, edge_mask: Array | None = None, -) -> Array: + return_sw: bool = False, +) -> Array | tuple[Array, Array]: """Compute 
the per-edge environment-matrix 4-vector. Mirrors the math in ``_make_env_mat`` / ``EnvMat.call`` (env_mat.py) @@ -79,6 +80,9 @@ def edge_env_mat( (E, 4) normalized environment-matrix vectors. Padding edges (``edge_vec = 0``) produce nonzero values but are masked by ``NeighborGraph.edge_mask`` downstream. + When ``return_sw`` is True, returns ``(em, sw)`` where ``sw`` is the + (E, 1) smooth switch, zeroed on padding edges (mirrors the dense + ``_make_env_mat`` mask; consumed by the smooth attention branch). """ xp = array_api_compat.array_namespace(edge_vec) dev = array_api_compat.device(edge_vec) @@ -114,4 +118,13 @@ def edge_env_mat( avg = xp.take(xp.asarray(davg, device=dev), center_type, axis=0) # (E, 4) std = xp.take(xp.asarray(dstd, device=dev), center_type, axis=0) # (E, 4) + if return_sw: + # per-edge switch, zeroed on padding edges — mirrors the dense + # ``_make_env_mat`` (``weight = weight * mask``); used by the smooth + # attention branch. + if edge_mask is not None: + sw_out = sw * xp.astype(edge_mask[:, None], sw.dtype) + else: + sw_out = sw + return (em - avg) / std, sw_out return (em - avg) / std diff --git a/deepmd/dpmodel/utils/neighbor_graph/pairs.py b/deepmd/dpmodel/utils/neighbor_graph/pairs.py new file mode 100644 index 0000000000..75f2682f20 --- /dev/null +++ b/deepmd/dpmodel/utils/neighbor_graph/pairs.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Pairs of edges sharing a center (``dst``) — the edge-pair axis. + +Shared primitive: graph-native attention (NeighborGraph PR-D) uses +``(ordered=True, include_self=True)`` = the full transformer neighbor-pair +square per center; 3-body angles (PR-E) use ``(ordered=False, +include_self=False)``. + +Two forms: + +- **compact** (``static_nnei=None``): segment-based enumeration over the + real edges only — sort edge ids by center, expand each center's Cartesian + square via cumsum offsets. 
Dynamic ``P = sum(deg**2)``; memory ``O(P)`` + (same order as dense attention's ``O(nloc * nnei**2)``). The data-dependent + sizes (``nonzero`` output, tensor-``repeat`` output) are registered as + UNBACKED SymInt sizes via :func:`xp_hint_dynamic_size`, so the form traces + through ``make_fx``/``torch.export`` and compiles under AOTI (torch >= 2.6 + unbacked-symint support) — this is what makes the carry-all attention + graph lower exportable to a ``.pt2``. numpy/jax run it eagerly as before + (jax.jit would still need a static realization — deferred to the jax PR). +- **shape-static** (``static_nnei`` set): assumes the center-major static + layout (``E = n_center * static_nnei``, edge ``c * static_nnei + m`` belongs + to center ``c`` — the layout ``from_dense_quartet(compact=False)`` emits). + Pure arange/reshape arithmetic, ``P = n_center * static_nnei**2`` with all + pairs materialized and validity carried by ``pair_mask`` — no data-dependent + ops at all, so it traces with only BACKED symbolic shapes (bit-parity with + the dense layout; used by the dense-quartet adapter). + +A global ``(E, E)`` same-center boolean is deliberately NOT used: with +``E ~ N * nnei`` it costs ``O(N**2 * nnei**2)`` memory. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + Any, +) + +import array_api_compat + +from deepmd.dpmodel.array_api import ( + Array, + xp_add_at, + xp_hint_dynamic_size, +) + + +def center_edge_pairs( + dst: Array, + edge_mask: Array, + n_total: int, + *, + include_self: bool = True, + ordered: bool = True, + static_nnei: int | None = None, +) -> tuple[Array, Array, Array]: + """Enumerate pairs of edges sharing a center. + + Parameters + ---------- + dst : Array + (E,) int64 center of each edge (``edge_index[1]``). + edge_mask : Array + (E,) bool, real (True) vs padding (False) edges. + n_total : int + Number of centers (bounds ``dst``). + include_self : bool + Keep the ``m == n`` diagonal (transformer self-attention needs it). 
+ ordered : bool + Keep both ``(m, n)`` and ``(n, m)`` (attention: yes, ``q_m . k_n`` is + not symmetric). ``False`` keeps only ``n >= m`` (with + ``include_self=False``: ``n > m`` — the angle set). + static_nnei : int | None + ``None`` -> compact eager form. Set -> shape-static form assuming the + center-major layout ``E = n_center * static_nnei``. + + Returns + ------- + query_edge : Array + (P,) int64 edge index of the query (``m``). + key_edge : Array + (P,) int64 edge index of the key (``n``). + pair_mask : Array + (P,) bool; False where either edge is padding or the pair is filtered + by the ``include_self`` / ``ordered`` policy (shape-static form; the + compact form drops such pairs and returns all-True). + """ + xp = array_api_compat.array_namespace(dst) + dev = array_api_compat.device(dst) + if static_nnei is not None: + return _pairs_shape_static( + xp, dev, dst, edge_mask, static_nnei, include_self, ordered + ) + return _pairs_compact(xp, dev, dst, edge_mask, n_total, include_self, ordered) + + +def _pairs_shape_static( + xp: Any, + dev: Any, + dst: Array, + edge_mask: Array, + nn: int, + include_self: bool, + ordered: bool, +) -> tuple[Array, Array, Array]: + e_tot = dst.shape[0] + # (E, nn): every edge queries the nn slots of its own center block + eids = xp.arange(e_tot, dtype=xp.int64, device=dev) + base = (eids // nn) * nn # start of each edge's center block + slots = xp.arange(nn, dtype=xp.int64, device=dev) + q2 = xp.broadcast_to(eids[:, None], (e_tot, nn)) + k2 = base[:, None] + slots[None, :] + query_edge = xp.reshape(q2, (-1,)) + key_edge = xp.reshape(k2, (-1,)) + pair_mask = xp.take(edge_mask, query_edge, axis=0) & xp.take( + edge_mask, key_edge, axis=0 + ) + if not include_self: + pair_mask = pair_mask & (query_edge != key_edge) + if not ordered: + pair_mask = pair_mask & (key_edge >= query_edge) + return query_edge, key_edge, pair_mask + + +def _pairs_compact( + xp: Any, + dev: Any, + dst: Array, + edge_mask: Array, + n_total: int, + 
include_self: bool, + ordered: bool, +) -> tuple[Array, Array, Array]: + empty = ( + xp.zeros((0,), dtype=xp.int64, device=dev), + xp.zeros((0,), dtype=xp.int64, device=dev), + xp.zeros((0,), dtype=xp.bool, device=dev), + ) + # early-return fast paths ONLY when the shape is a concrete Python int: + # under make_fx/torch.export symbolic tracing shape[0] is a SymInt (the + # nonzero output size is UNBACKED), and branching on it raises + # GuardOnDataDependentSymNode. The general code below handles the empty + # case correctly anyway (all downstream arrays come out empty). + if isinstance(dst.shape[0], int) and dst.shape[0] == 0: + return empty + # real edges only, grouped by center (stable sort keeps original order + # within a center — irrelevant for correctness, deterministic for tests) + (real_idx,) = xp.nonzero(edge_mask) + xp_hint_dynamic_size(real_idx) # unbacked size R: register >= 0 size hint + r_tot = real_idx.shape[0] + if isinstance(r_tot, int) and r_tot == 0: + return empty + d_real = xp.take(dst, real_idx, axis=0) + order = xp.argsort(d_real, stable=True) + eid = xp.take(real_idx, order, axis=0) # (R,) edge ids, center-grouped + ds = xp.take(d_real, order, axis=0) # (R,) sorted centers + # per-center degree and group start (over the sorted layout) + ones = xp.ones((r_tot,), dtype=xp.int64, device=dev) + counts = xp_add_at( + xp.zeros((n_total,), dtype=xp.int64, device=dev), ds, ones + ) # (n_total,) + csum = xp.cumulative_sum(counts) + start = csum - counts # (n_total,) group start per center + deg = xp.take(counts, ds, axis=0) # (R,) degree of each edge's center + # iota over the (unbacked-size) R and P axes via cumsum(ones) - 1 instead + # of xp.arange: the array_api_compat arange wrapper branches on the length + # in Python, which raises GuardOnDataDependentSymNode for unbacked SymInts. 
+ iota_r = xp.cumulative_sum(ones) - 1 # (R,) = arange(r_tot) + # each sorted edge t emits deg[t] pairs; P = sum(deg**2) + query_sorted = xp.repeat(iota_r, deg) # (P,) + xp_hint_dynamic_size(query_sorted) # unbacked size P = sum(deg**2) + # within each query's block, a 0..deg-1 ramp indexes the key group + pair_off = xp.cumulative_sum(deg) - deg # (R,) exclusive prefix of deg + p_tot = query_sorted.shape[0] + iota_p = xp.cumulative_sum(xp.ones_like(query_sorted)) - 1 # (P,) = arange(p_tot) + ramp = iota_p - xp.take(pair_off, query_sorted, axis=0) + key_sorted = xp.take(start, xp.take(ds, query_sorted, axis=0), axis=0) + ramp + query_edge = xp.take(eid, query_sorted, axis=0) + key_edge = xp.take(eid, key_sorted, axis=0) + if include_self and ordered: + # no policy filter (the attention default): every enumerated pair is + # real. Skipping the compression nonzero here keeps the attention + # graph-lower traceable with a single unbacked size (P). + pair_mask = xp.ones((p_tot,), dtype=xp.bool, device=dev) + return query_edge, key_edge, pair_mask + keep = xp.ones((p_tot,), dtype=xp.bool, device=dev) + if not include_self: + keep = keep & (query_edge != key_edge) + if not ordered: + keep = keep & (key_edge >= query_edge) + (kept,) = xp.nonzero(keep) + xp_hint_dynamic_size(kept) # unbacked size: policy-filtered pair count + query_edge = xp.take(query_edge, kept, axis=0) + key_edge = xp.take(key_edge, kept, axis=0) + pair_mask = xp.ones((query_edge.shape[0],), dtype=xp.bool, device=dev) + return query_edge, key_edge, pair_mask diff --git a/deepmd/dpmodel/utils/neighbor_graph/segment.py b/deepmd/dpmodel/utils/neighbor_graph/segment.py index 45d64af08c..f95671d05c 100644 --- a/deepmd/dpmodel/utils/neighbor_graph/segment.py +++ b/deepmd/dpmodel/utils/neighbor_graph/segment.py @@ -9,6 +9,7 @@ from deepmd.dpmodel.array_api import ( Array, xp_add_at, + xp_maximum_at, ) @@ -37,3 +38,59 @@ def segment_mean(data: Array, segment_ids: Array, num_segments: int) -> Array: # broadcast 
counts over the trailing dims of summed shape = (num_segments,) + (1,) * (summed.ndim - 1) return summed / xp.reshape(safe, shape) + + +def segment_max(data: Array, segment_ids: Array, num_segments: int) -> Array: + """out[s] = max of data[i] over i with segment_ids[i] == s. + + Shape ``(num_segments, *data.shape[1:])``; empty segments are ``-inf`` + (neutral element — callers guard with masks before consuming them). + """ + xp = array_api_compat.array_namespace(data) + out = xp.full( + (num_segments, *tuple(data.shape[1:])), + -xp.inf, + dtype=data.dtype, + device=array_api_compat.device(data), + ) + return xp_maximum_at(out, segment_ids, data) + + +def segment_softmax( + data: Array, + segment_ids: Array, + num_segments: int, + mask: Array | None = None, +) -> Array: + """Softmax over entries sharing a segment id, numerically stable. + + Mirrors the dense ``np_softmax`` max-subtraction trick with a PER-SEGMENT + max. ``mask`` (bool, per entry) removes masked entries from the softmax + entirely (zero weight AND excluded from the denominator). Empty or + fully-masked segments produce all-zero weights (no NaN). + """ + xp = array_api_compat.array_namespace(data) + if mask is not None: + # keep masked entries out of the per-segment max: send them to -inf + neg = xp.full_like(data, -xp.inf) + data_for_max = xp.where(mask, data, neg) + else: + data_for_max = data + seg_max = segment_max(data_for_max, segment_ids, num_segments) + # guard -inf (empty / fully-masked segments) so gather doesn't yield inf-inf + seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max) + # shift data_for_max (masked entries already -inf), NOT the raw data: + # a masked entry whose raw value exceeds the unmasked per-segment max by + # more than the exp overflow threshold (~709 fp64 / ~88 fp32) would give + # exp(+big) = inf, and the post-hoc inf * 0 mask multiply = nan, poisoning + # the WHOLE segment through the denominator. exp(-inf) = 0 exactly. 
+ shifted = data_for_max - xp.take(seg_max, segment_ids, axis=0) + ex = xp.exp(shifted) + if mask is not None: + # defensive no-op after the -inf shift (exp(-inf) == 0); kept so the + # zero-weight guarantee never depends on the shift implementation + ex = ex * xp.astype(mask, ex.dtype) + denom = segment_sum(ex, segment_ids, num_segments) + denom_e = xp.take(denom, segment_ids, axis=0) + safe = xp.where(denom_e > 0, denom_e, xp.ones_like(denom_e)) + return ex / safe diff --git a/deepmd/pt_expt/entrypoints/main.py b/deepmd/pt_expt/entrypoints/main.py index d21547fbef..eeb97dcd2c 100644 --- a/deepmd/pt_expt/entrypoints/main.py +++ b/deepmd/pt_expt/entrypoints/main.py @@ -498,10 +498,16 @@ def freeze( lower_kind : str Lower-level export form: ``"nlist"`` (default, dense neighbor-list lower) or ``"graph"`` (NeighborGraph edge-list lower). ``"graph"`` is only valid - for graph-eligible models (``mixed_types`` and ``uses_graph_lower``, - currently dpa1 with ``attn_layer == 0``) and selects the C++ graph - inference path; the per-atom virial is enabled for it (near-free in the - graph path: one extra scatter off the shared single backward). + for graph-eligible models (``mixed_types`` and ``uses_graph_lower``: + dpa1/se_atten with concat type embedding and no ``exclude_types``, + attention layers included) and selects the C++ graph inference path; + the per-atom virial is enabled for it (near-free in the graph path: + one extra scatter off the shared single backward). NOTE: for + ``smooth_type_embedding=True`` the carry-all graph attention + intentionally drops the dense layout's sel-padding terms from the + softmax denominator, so graph-form results are sel-independent and + differ from the legacy dense lower by up to ~1e-4 (see + ``DescrptDPA1.call_graph``). """ import torch @@ -563,10 +569,12 @@ def freeze( m.eval() - # The graph lower is opt-in and only valid for graph-eligible models (dpa1 - # attn_layer==0 today). 
Fail fast with a clear message rather than emitting a - # broken .pt2. Enable the per-atom virial for the graph form -- it is - # near-free there (one extra scatter off the single shared backward). + # The graph lower is opt-in and only valid for graph-eligible models + # (dpa1 with concat tebd and no type exclusion; attention layers included + # -- the carry-all pair enumeration exports via unbacked SymInts). Fail + # fast with a clear message rather than emitting a broken .pt2. Enable the + # per-atom virial for the graph form -- it is near-free there (one extra + # scatter off the single shared backward). do_atomic_virial = False if lower_kind == "graph": from deepmd.pt_expt.train.training import ( @@ -577,8 +585,8 @@ def freeze( raise ValueError( "lower_kind='graph' requires a graph-eligible model " "(mixed_types and a descriptor exposing uses_graph_lower()==True, " - "currently dpa1 with attn_layer==0). Use lower_kind='nlist' for " - "this model." + "currently dpa1 with tebd_input_mode='concat' and no " + "exclude_types). Use lower_kind='nlist' for this model." ) do_atomic_virial = True diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 928149ca94..ae2e83eada 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -292,7 +292,7 @@ def forward_common_lower_graph( aparam: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: - """Graph-native lower with autograd force/virial (PR-A: dpa1 ``attn_layer==0``). + """Graph-native lower with autograd force/virial (dpa1/se_atten concat-tebd, attention included). 
OUTPUT-AGNOSTIC: runs the graph descriptor + fitting forward with ``edge_vec`` as the autograd leaf (via the inherited diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 1ae51fd483..d2868a4082 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -587,7 +587,8 @@ def _model_uses_graph_lower(model: torch.nn.Module) -> bool: :meth:`~deepmd.pt_expt.model.make_model.make_model..CM._resolve_graph_method` for ``neighbor_graph_method is None`` (the training default): a model is graph-eligible iff it is ``mixed_types`` AND its single descriptor reports - ``uses_graph_lower() == True`` (currently only dpa1 ``attn_layer == 0``). + ``uses_graph_lower() == True`` (dpa1/se_atten with concat type embedding + and no ``exclude_types``; attention layers included). When True the compiled lower must be the GRAPH ``forward_common_lower_graph`` so the compiled path matches eager training (which already default-flips to @@ -721,8 +722,10 @@ def _trace_and_compile_graph( # float precision and device; optional tensors match the actual call. from deepmd.pt_expt.utils.serialization import ( build_synthetic_graph_inputs, + check_graph_trace_torch_version, ) + check_graph_trace_torch_version(model) sample = build_synthetic_graph_inputs( model, e_max=e_max, @@ -906,7 +909,7 @@ def forward( nframes, nloc = atype.shape[:2] rcut = self.original_model.get_rcut() - # Graph-eligible models (dpa1 attn_layer==0) default-flip to the carry-all + # Graph-eligible models (dpa1 concat-tebd, incl. attention) default-flip to the carry-all # GRAPH forward in eager training; the compiled lower must be the GRAPH # lower too, otherwise the eager (graph) and compiled (dense) backward # gradients diverge at fp64 accumulation and the optimizer amplifies it. 
diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index b9d1531769..9317e87f4d 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -141,6 +141,51 @@ def _needs_with_comm_artifact(model: torch.nn.Module) -> bool: return False +def check_graph_trace_torch_version(model: torch.nn.Module) -> None: + """Fail fast when the graph trace needs unbacked-SymInt support torch lacks. + + The compact ``center_edge_pairs`` realization used by graph attention + (``attn_layer > 0``) relies on unbacked-SymInt tracing + (``torch._check_is_size`` hints on ``nonzero`` / tensor-``repeat`` outputs, + see ``deepmd/dpmodel/utils/neighbor_graph/pairs.py``), which is only solid + from torch >= 2.6. On older torch the trace dies deep inside + ``make_fx``/AOTI with an obscure ``GuardOnDataDependentSymNode`` (or an + ``AttributeError`` on ``_check_is_size``), so both graph trace sites (the + ``.pt2`` export below and the training compile in + ``training._trace_and_compile_graph``) call this guard first. Factorizable + models (``attn_layer == 0``) trace with backed symbols only and are not + restricted. + + Parameters + ---------- + model + The graph-eligible model about to be traced. The attention depth is + read from ``model.atomic_model.descriptor.get_numb_attn_layer()``; + models without a single descriptor (linear/zbl/frozen) pass the + check (they take the dense route anyway). + + Raises + ------ + RuntimeError + If the descriptor has ``attn_layer > 0`` and the running torch is + older than 2.6. 
+ """ + desc = getattr(getattr(model, "atomic_model", None), "descriptor", None) + get_n_attn = getattr(desc, "get_numb_attn_layer", None) + n_attn = get_n_attn() if get_n_attn is not None else 0 + if n_attn <= 0: + return + version = torch.__version__.split("+")[0] + major_minor = tuple(int(p) for p in version.split(".")[:2] if p.isdigit()) + if len(major_minor) == 2 and major_minor < (2, 6): + raise RuntimeError( + f"graph-form tracing of attention layers (attn_layer={n_attn}) " + f"requires torch >= 2.6 (unbacked-SymInt support for the compact " + f"center_edge_pairs realization); found torch {torch.__version__}. " + "Upgrade torch, set 'attn_layer: 0', or use the dense (nlist) path." + ) + + # Module-level cache for the trace-time sendlist buffer. The pointer # value embedded in ``send_list_tensor`` references this numpy array's # data; the array must outlive the trace + export call. Caching here @@ -889,6 +934,7 @@ def _trace_and_export( if lower_kind == "graph": import math + check_graph_trace_torch_version(model) if is_spin: raise NotImplementedError( "graph-form .pt2 export is not supported for spin models" diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index e504207ac2..177c652bed 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -157,6 +157,11 @@ In other backends, type embedding is within this descriptor with the {ref}`tebd_ TensorFlow and other backends have different implementations for {ref}`smooth_type_embedding `. The results are inconsistent when `smooth_type_embedding` is `true`. +In the pt_expt backend, graph-eligible descriptors (mixed types, `tebd_input_mode` `"concat"`, no descriptor-level `exclude_types` or compression) are evaluated by default through the carry-all neighbor-graph path instead of the legacy dense neighbor list. +The graph path considers all neighbors within the cutoff, so its result does not depend on {ref}`sel `. 
+When `smooth_type_embedding` is `true` and {ref}`attn_layer ` is larger than 0 (the defaults), the dense path keeps `sel`-padding phantom terms in the attention softmax denominator while the graph path drops them, so checkpoints trained under the dense semantics shift by up to about 1e-4 in energy when evaluated on the graph path. +Passing `neighbor_graph_method="legacy"` to the model forward (or the corresponding evaluation option) restores the dense-path numbers exactly. + In the TensorFlow backend, {ref}`scaling_factor ` cannot set to a value other than `1.0`; {ref}`normalize ` cannot be set to `false`; {ref}`temperature ` cannot be set; diff --git a/source/tests/common/dpmodel/test_center_edge_pairs.py b/source/tests/common/dpmodel/test_center_edge_pairs.py new file mode 100644 index 0000000000..aa018b6bb8 --- /dev/null +++ b/source/tests/common/dpmodel/test_center_edge_pairs.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""center_edge_pairs: pairs of edges sharing a center (NeighborGraph PR-D/E).""" + +import numpy as np + +from deepmd.dpmodel.utils.neighbor_graph import ( + center_edge_pairs, +) + + +def _oracle(dst, mask, include_self, ordered): + pairs = [] + for m in range(len(dst)): + if not mask[m]: + continue + for n in range(len(dst)): + if not mask[n] or dst[m] != dst[n]: + continue + if not include_self and m == n: + continue + if not ordered and n < m: + continue + pairs.append((m, n)) + return set(pairs) + + +def _got(q, k, pm): + return {(int(q[p]), int(k[p])) for p in range(q.shape[0]) if pm[p]} + + +class TestCompact: + def test_transformer_all_ordered_with_self(self) -> None: + # 3 edges: dst = [0, 0, 1]; center 0 has edges {0,1}, center 1 has {2} + dst = np.array([0, 0, 1], dtype=np.int64) + edge_mask = np.array([True, True, True]) + q, k, pm = center_edge_pairs(dst, edge_mask, 2) + assert _got(q, k, pm) == _oracle([0, 0, 1], [1, 1, 1], True, True) + # center 0: (0,0),(0,1),(1,0),(1,1); center 1: (2,2) => 5 pairs + 
assert len(_got(q, k, pm)) == 5 + + def test_unordered_no_self_is_angle_set(self) -> None: + dst = np.array([0, 0, 0], dtype=np.int64) + edge_mask = np.array([True, True, True]) + q, k, pm = center_edge_pairs( + dst, edge_mask, 1, include_self=False, ordered=False + ) + assert _got(q, k, pm) == {(0, 1), (0, 2), (1, 2)} + + def test_ignores_padding_edges(self) -> None: + dst = np.array([0, 0, 0], dtype=np.int64) + edge_mask = np.array([True, True, False]) # 3rd is a guard edge + q, k, pm = center_edge_pairs(dst, edge_mask, 1) + assert _got(q, k, pm) == {(0, 0), (0, 1), (1, 0), (1, 1)} + + def test_non_contiguous_center_order(self) -> None: + # edges NOT sorted by center: dst = [1, 0, 1, 0] + dst = np.array([1, 0, 1, 0], dtype=np.int64) + edge_mask = np.array([True, True, True, True]) + q, k, pm = center_edge_pairs(dst, edge_mask, 2) + assert _got(q, k, pm) == _oracle([1, 0, 1, 0], [1] * 4, True, True) + + def test_empty(self) -> None: + dst = np.zeros((0,), dtype=np.int64) + edge_mask = np.zeros((0,), dtype=bool) + q, k, pm = center_edge_pairs(dst, edge_mask, 3) + assert q.shape[0] == 0 and k.shape[0] == 0 and pm.shape[0] == 0 + + def test_random_vs_oracle(self) -> None: + rng = np.random.default_rng(7) + dst = rng.integers(0, 5, size=23).astype(np.int64) + edge_mask = rng.random(23) > 0.3 + for include_self in (True, False): + for ordered in (True, False): + q, k, pm = center_edge_pairs( + dst, edge_mask, 5, include_self=include_self, ordered=ordered + ) + assert _got(q, k, pm) == _oracle( + dst, edge_mask, include_self, ordered + ), (include_self, ordered) + + def test_torch_matches_numpy(self) -> None: + import torch + + dst = np.array([0, 0, 1, 1, 1], dtype=np.int64) + edge_mask = np.array([True, False, True, True, True]) + ref = _got(*center_edge_pairs(dst, edge_mask, 2)) + q, k, pm = center_edge_pairs( + torch.from_numpy(dst), torch.from_numpy(edge_mask), 2 + ) + assert _got(q.numpy(), k.numpy(), pm.numpy()) == ref + + +class TestShapeStatic: + def 
test_matches_compact(self) -> None: + # center-major static layout: 2 centers x static_nnei=3, edges 2,5 padded + dst = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) + edge_mask = np.array([True, True, False, True, True, False]) + qc, kc, pmc = center_edge_pairs(dst, edge_mask, 2) + qs, ks, pms = center_edge_pairs(dst, edge_mask, 2, static_nnei=3) + assert qs.shape[0] == 2 * 3 * 3 # static P, data-independent + assert _got(qs, ks, pms) == _got(qc, kc, pmc) + + def test_flags_and_masking(self) -> None: + dst = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) + edge_mask = np.array([True, True, True, True, False, False]) + for include_self in (True, False): + for ordered in (True, False): + qs, ks, pms = center_edge_pairs( + dst, + edge_mask, + 2, + include_self=include_self, + ordered=ordered, + static_nnei=3, + ) + assert qs.shape[0] == 2 * 3 * 3 # P static regardless of flags + assert _got(qs, ks, pms) == _oracle( + dst, edge_mask, include_self, ordered + ), (include_self, ordered) + + def test_torch_matches_numpy(self) -> None: + import torch + + dst = np.array([0, 0, 1, 1], dtype=np.int64) + edge_mask = np.array([True, False, True, True]) + ref = _got(*center_edge_pairs(dst, edge_mask, 2, static_nnei=2)) + q, k, pm = center_edge_pairs( + torch.from_numpy(dst), + torch.from_numpy(edge_mask), + 2, + static_nnei=2, + ) + assert _got(q.numpy(), k.numpy(), pm.numpy()) == ref diff --git a/source/tests/common/dpmodel/test_dpa1_call_graph_block.py b/source/tests/common/dpmodel/test_dpa1_call_graph_block.py index e8930101dd..9a984a30f3 100644 --- a/source/tests/common/dpmodel/test_dpa1_call_graph_block.py +++ b/source/tests/common/dpmodel/test_dpa1_call_graph_block.py @@ -90,11 +90,8 @@ def test_block_graph_equals_dense_any_sel(self, sel, type_one_side) -> None: atol=1e-12, ) - def test_attn_layer_gt0_raises(self) -> None: - """The graph block kernel fail-fasts for attn_layer > 0 (unsupported).""" - dd = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=[20], ntypes=2, attn_layer=2) - 
with pytest.raises(NotImplementedError): - dd.se_atten.call_graph(None, np.array([0], dtype=np.int64)) + # attn_layer > 0 is supported since NeighborGraph PR-D; parity is covered + # by test_dpa1_graph_attention_parity.py (the fail-fast test was removed). def test_exclude_types_raises(self) -> None: """The graph block kernel fail-fasts for exclude_types (not yet applied).""" diff --git a/source/tests/common/dpmodel/test_dpa1_graph_attention_parity.py b/source/tests/common/dpmodel/test_dpa1_graph_attention_parity.py new file mode 100644 index 0000000000..97d044862f --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa1_graph_attention_parity.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Graph-native se_atten attention (attn_layer > 0) vs the dense reference. + +Regime-1 parity (NeighborGraph PR-D): the graph is built FROM the same dense +nlist (``from_dense_quartet``), so the neighbor sets are identical and the +graph attention must reproduce ``GatedAttentionLayer``/``NeighborGatedAttention`` +bit-exactly (CPU rtol 1e-12) for ANY sel — binding or not. + +The smooth branch needs the SHAPE-STATIC graph (``compact=False`` + +``static_nnei``): dense smooth keeps padding slots in the softmax DENOMINATOR +(weight ``exp(-attnw_shift)`` since ``sw = 0``), so bit-parity requires the +same padded pairs on the graph side. The compact (carry-all-like) form drops +padding pairs and is exercised on the non-smooth branch only. 
+""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.dpmodel.utils.neighbor_graph import ( + from_dense_quartet, +) +from deepmd.dpmodel.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +GLOBAL_SEED = 20260703 + + +def _make( + attn_layer, + dotr=False, + smooth=False, + normalize=False, + temperature=1.0, + sel=(20,), +): + # attention `smooth` is wired to smooth_type_embedding (NOT rcut_smth); + # pass it explicitly — its default (True) would silently enable the + # smooth branch in the non-smooth cases. + return DescrptDPA1( + rcut=4.0, + rcut_smth=0.5, + sel=list(sel), + ntypes=2, + neuron=[6, 12], + axis_neuron=2, + attn=8, + attn_layer=attn_layer, + attn_dotr=dotr, + attn_mask=False, + normalize=normalize, + smooth_type_embedding=smooth, + temperature=temperature, + tebd_input_mode="concat", + type_one_side=True, + precision="float64", + seed=GLOBAL_SEED, + ) + + +class TestGraphAttentionParity: + def setup_method(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + self.nloc = 5 + self.coord = rng.normal(size=(1, self.nloc, 3)) * 1.5 + self.atype = np.array([[0, 1, 0, 1, 1]], dtype=np.int64) + + def _quartet(self, dd): + return extend_input_and_build_neighbor_list( + self.coord, + self.atype, + dd.get_rcut(), + dd.get_sel(), + mixed_types=dd.mixed_types(), + box=None, + ) + + def _dense_vs_adapter(self, dd, rtol=1e-12): + """Descriptor-level: legacy dense body vs the graph adapter (shape-static).""" + ext_coord, ext_atype, mapping, nlist = self._quartet(dd) + dense = dd._call_dense(ext_coord, ext_atype, nlist) + graph = dd._call_graph_adapter(ext_coord, ext_atype, nlist, mapping) + np.testing.assert_allclose( + graph[0], dense[0], rtol=rtol, atol=rtol, err_msg="descriptor" + ) + np.testing.assert_allclose( + graph[1], dense[1], rtol=rtol, atol=rtol, err_msg="rot_mat" + ) + + # ── Task 5a/5b/5c/5e: full matrix on the shape-static adapter path ────── + 
@pytest.mark.parametrize("sel", [(20,), (3,)]) # non-binding AND binding + @pytest.mark.parametrize("attn_layer", [1, 2]) # single + stacked layers + def test_core_layers_sel(self, attn_layer, sel) -> None: + dd = _make(attn_layer, sel=sel) + self._dense_vs_adapter(dd) + + @pytest.mark.parametrize("normalize", [False, True]) # q/k/v np_normalize + @pytest.mark.parametrize("temperature", [None, 1.0]) # scaling source + def test_normalize_temperature(self, normalize, temperature) -> None: + dd = _make(1, normalize=normalize, temperature=temperature) + self._dense_vs_adapter(dd) + + @pytest.mark.parametrize("dotr", [False, True]) # angular weighting + @pytest.mark.parametrize("smooth", [False, True]) # switch-fn weighting + def test_dotr_smooth(self, dotr, smooth) -> None: + dd = _make(2, dotr=dotr, smooth=smooth, normalize=True, temperature=None) + self._dense_vs_adapter(dd) + + # ── compact (carry-all-form) graph through the BLOCK kernel, non-smooth ── + @pytest.mark.parametrize("attn_layer", [1, 2]) # single + stacked layers + def test_block_compact_graph_no_smooth(self, attn_layer) -> None: + dd = _make(attn_layer, dotr=True, normalize=True) + blk = dd.se_atten + ext_coord, ext_atype, mapping, nlist = self._quartet(dd) + tebd = dd.type_embedding.call() + nf, nall = ext_atype.shape + atype_embd_ext = np.reshape( + np.take(tebd, np.reshape(ext_atype, (-1,)), axis=0), + (nf, nall, dd.tebd_dim), + ) + dense_g, *_ = blk.call( + nlist, + ext_coord, + ext_atype, + atype_embd_ext=atype_embd_ext, + mapping=None, + type_embedding=tebd, + ) + ng = from_dense_quartet(ext_coord, nlist, mapping) # compact=True + graph_g, _ = blk.call_graph( + ng, np.reshape(ext_atype, (-1,)), type_embedding=tebd + ) + np.testing.assert_allclose( + graph_g.reshape(dense_g.shape), dense_g, rtol=1e-12, atol=1e-12 + ) + + # ── smooth on the compact (carry-all) form: CLEAN DIVERGENCE by design ──── + def test_block_compact_graph_smooth_clean_divergence(self) -> None: + """Carry-all smooth attention 
deliberately DIVERGES from dense. + + The dense smooth branch keeps sel-padding slots in the attention + softmax DENOMINATOR at weight ``exp(-attnw_shift)``, which makes the + dense output depend on ``sel`` itself (same physical neighbors, + different sel => different output, up to ~1e-4). The carry-all graph + drops those phantom terms — the sel-independent math (user decision + 2026-07-03, PR-D). Bit-parity (1e-12) is proven on the shape-static + adapter (same padded pairs on both sides, ``test_dotr_smooth``); here + we pin only that the compact form stays CLOSE to dense (the artifact + is a bounded denominator perturbation) while NOT bit-equal. + """ + dd = _make(1, smooth=True) + blk = dd.se_atten + ext_coord, ext_atype, mapping, nlist = self._quartet(dd) + tebd = dd.type_embedding.call() + nf, nall = ext_atype.shape + atype_embd_ext = np.reshape( + np.take(tebd, np.reshape(ext_atype, (-1,)), axis=0), + (nf, nall, dd.tebd_dim), + ) + dense_g, *_ = blk.call( + nlist, + ext_coord, + ext_atype, + atype_embd_ext=atype_embd_ext, + mapping=None, + type_embedding=tebd, + ) + ng = from_dense_quartet(ext_coord, nlist, mapping) + graph_g, _ = blk.call_graph( + ng, np.reshape(ext_atype, (-1,)), type_embedding=tebd + ) + graph_g = graph_g.reshape(dense_g.shape) + # close (the artifact is a small denominator perturbation) ... + np.testing.assert_allclose(graph_g, dense_g, rtol=1e-3, atol=1e-3) + # ... but NOT bit-equal: the phantom-padding terms are really gone + assert np.max(np.abs(graph_g - dense_g)) > 1e-9 + + # ── torch namespace smoke (CLAUDE.md: catches numpy-weight leaks) ──────── + # NB: the smoke runs the BLOCK kernel with a torch type_embedding table; + # the raw dpmodel adapter is numpy-weighted by design (pt_expt wraps it). 
+ def test_torch_block_matches_numpy(self) -> None: + import torch + + dd = _make(2, dotr=True, smooth=True, normalize=True, temperature=None) + blk = dd.se_atten + ext_coord, ext_atype, mapping, nlist = self._quartet(dd) + tebd = dd.type_embedding.call() + ng = from_dense_quartet(ext_coord, nlist, mapping, compact=False) + ref, _ = blk.call_graph( + ng, + np.reshape(ext_atype, (-1,)), + type_embedding=tebd, + static_nnei=nlist.shape[2], + ) + ng_t = from_dense_quartet( + torch.from_numpy(ext_coord), + torch.from_numpy(nlist), + torch.from_numpy(mapping), + compact=False, + ) + out, _ = blk.call_graph( + ng_t, + torch.from_numpy(np.reshape(ext_atype, (-1,))), + type_embedding=torch.from_numpy(tebd), + static_nnei=nlist.shape[2], + ) + np.testing.assert_allclose(out.numpy(), ref, rtol=1e-12, atol=1e-12) + + +class TestGraphEligibility: + def test_attention_concat_is_graph_eligible(self) -> None: + assert _make(2).uses_graph_lower() + + def test_strip_mode_stays_dense(self) -> None: + """se_atten_v2 (tebd_input_mode='strip') is NOT graph-eligible yet: + strip-mode graph support is a later PR; it must keep the dense route + (the PR-D plan's 'se_atten_v2 inherits for free' did not hold). + """ + from deepmd.dpmodel.descriptor.se_atten_v2 import ( + DescrptSeAttenV2, + ) + + dd = DescrptSeAttenV2(rcut=4.0, rcut_smth=0.5, sel=[20], ntypes=2, attn_layer=2) + assert not dd.uses_graph_lower() + + +class TestBindingSelDivergence: + """At BINDING sel the carry-all graph attends over MORE neighbors than the + sel-truncated dense path — outputs must differ (sanity, not parity; + spec decision #17). 
+ """ + + def test_carry_all_attention_differs_at_binding_sel(self) -> None: + from deepmd.dpmodel.utils.neighbor_graph import ( + build_neighbor_graph, + ) + + rng = np.random.default_rng(GLOBAL_SEED) + nloc = 6 + coord = rng.random((1, nloc, 3)) * 2.0 # dense blob => binding sel=2 + atype = np.array([[0, 1, 0, 1, 1, 0]], dtype=np.int64) + dd = _make(2, dotr=True, sel=(2,)) + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + coord, atype, dd.get_rcut(), dd.get_sel(), mixed_types=True, box=None + ) + assert (nlist >= 0).all(), "fixture must be sel-binding (all slots full)" + tebd = dd.type_embedding.call() + atype_embd_ext = np.reshape( + np.take(tebd, np.reshape(ext_atype, (-1,)), axis=0), + (1, ext_atype.shape[1], dd.tebd_dim), + ) + dense_g, *_ = dd.se_atten.call( + nlist, + ext_coord, + ext_atype, + atype_embd_ext=atype_embd_ext, + mapping=None, + type_embedding=tebd, + ) + graph = build_neighbor_graph(coord, atype, None, dd.get_rcut()) + graph_g, _ = dd.se_atten.call_graph( + graph, atype.reshape(-1), type_embedding=tebd + ) + assert np.max(np.abs(graph_g.reshape(dense_g.shape) - dense_g)) > 1e-6, ( + "carry-all attention must diverge from sel-truncated dense" + ) diff --git a/source/tests/common/dpmodel/test_segment_softmax.py b/source/tests/common/dpmodel/test_segment_softmax.py new file mode 100644 index 0000000000..a97bd9c1aa --- /dev/null +++ b/source/tests/common/dpmodel/test_segment_softmax.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""segment_max / segment_softmax (NeighborGraph PR-D segment toolkit).""" + +import numpy as np + +from deepmd.dpmodel.utils.neighbor_graph import ( + segment_max, + segment_softmax, +) + + +class TestSegmentMax: + def test_basic(self) -> None: + data = np.array([1.0, 5.0, 2.0, -3.0]) + ids = np.array([0, 0, 2, 2], dtype=np.int64) + out = segment_max(data, ids, 3) + assert out[0] == 5.0 + assert np.isneginf(out[1]) # empty segment + assert out[2] == 2.0 + + def 
test_trailing_dims(self) -> None: + data = np.array([[1.0, -2.0], [3.0, -4.0], [0.0, 9.0]]) + ids = np.array([1, 1, 0], dtype=np.int64) + out = segment_max(data, ids, 2) + np.testing.assert_allclose(out[0], [0.0, 9.0]) + np.testing.assert_allclose(out[1], [3.0, -2.0]) + + def test_torch_matches_numpy(self) -> None: + import torch + + data = np.array([0.3, 1.2, -0.7, 2.0]) + ids = np.array([0, 0, 1, 1], dtype=np.int64) + ref = segment_max(data, ids, 2) + out = segment_max(torch.from_numpy(data), torch.from_numpy(ids), 2) + np.testing.assert_allclose(out.numpy(), ref) + + +class TestSegmentSoftmax: + def test_matches_dense(self) -> None: + logits = np.array([1.0, 2.0, 0.5, -1.0]) + ids = np.array([0, 0, 0, 1], dtype=np.int64) + w = segment_softmax(logits, ids, 2) + ref0 = np.exp(np.array([1.0, 2.0, 0.5]) - 2.0) + ref0 = ref0 / ref0.sum() + np.testing.assert_allclose(w[:3], ref0, atol=1e-12) + np.testing.assert_allclose(w[3], 1.0, atol=1e-12) + + def test_stable_large_logits(self) -> None: + logits = np.array([1e30, 1e30 + 1.0]) + ids = np.array([0, 0], dtype=np.int64) + w = segment_softmax(logits, ids, 1) + assert not np.any(np.isnan(w)) + np.testing.assert_allclose(w.sum(), 1.0, atol=1e-12) + + def test_masked_entries_zero(self) -> None: + logits = np.array([1.0, 2.0, 3.0]) + ids = np.array([0, 0, 0], dtype=np.int64) + mask = np.array([True, False, True]) + w = segment_softmax(logits, ids, 1, mask=mask) + assert w[1] == 0.0 + np.testing.assert_allclose(w.sum(), 1.0, atol=1e-12) + # masked entry excluded from the denominator too + ref = np.exp(np.array([1.0, 3.0]) - 3.0) + ref = ref / ref.sum() + np.testing.assert_allclose(w[[0, 2]], ref, atol=1e-12) + + def test_all_masked_segment_is_zero_no_nan(self) -> None: + logits = np.array([1.0, 2.0, 5.0]) + ids = np.array([0, 0, 1], dtype=np.int64) + mask = np.array([True, True, False]) + w = segment_softmax(logits, ids, 2, mask=mask) + assert not np.any(np.isnan(w)) + assert w[2] == 0.0 + + def 
test_empty_segment_no_nan(self) -> None: + logits = np.array([1.0, 2.0]) + ids = np.array([0, 0], dtype=np.int64) + w = segment_softmax(logits, ids, 3) + assert not np.any(np.isnan(w)) + + def test_torch_matches_numpy(self) -> None: + import torch + + logits = np.array([0.3, 1.2, -0.7, 2.0]) + ids = np.array([0, 0, 1, 1], dtype=np.int64) + mask = np.array([True, True, True, False]) + ref = segment_softmax(logits, ids, 2, mask=mask) + out = segment_softmax( + torch.from_numpy(logits), + torch.from_numpy(ids), + 2, + mask=torch.from_numpy(mask), + ) + np.testing.assert_allclose(out.numpy(), ref, atol=1e-12) + + +def test_masked_entry_larger_than_unmasked_max_no_nan() -> None: + """A masked entry FAR ABOVE the unmasked max must not poison the segment. + + Regression (CodeRabbit #5715): shifting the raw data let a huge masked + logit overflow exp() to inf, and inf * 0 (mask multiply) = nan summed into + the denominator, contaminating every entry of the segment. The shift must + use the masked (-inf) values so masked entries exp() to exactly zero. 
+ """ + data = np.array([1.0, 2.0, 1e5], dtype=np.float64) # 1e5 - 2 >> 709 + ids = np.zeros(3, dtype=np.int64) + mask = np.array([True, True, False]) + out = segment_softmax(data, ids, 1, mask=mask) + assert np.all(np.isfinite(out)) + ref = np.exp([1.0, 2.0]) / np.exp([1.0, 2.0]).sum() + np.testing.assert_allclose(out[:2], ref, rtol=1e-12) + assert out[2] == 0.0 diff --git a/source/tests/pt_expt/descriptor/test_dpa1.py b/source/tests/pt_expt/descriptor/test_dpa1.py index d7d2718e67..cddd22419f 100644 --- a/source/tests/pt_expt/descriptor/test_dpa1.py +++ b/source/tests/pt_expt/descriptor/test_dpa1.py @@ -311,6 +311,65 @@ def fn(coord_ext, atype_ext, nlist, mapping): atol=atol, ) + @pytest.mark.parametrize("smooth", [False, True]) # smooth attention branch + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx_graph_attn(self, prec, smooth) -> None: + """make_fx (export-readiness) of the GRAPH forward with attention. + + MERGE BLOCKER (NeighborGraph PR-D): pt_expt compiled training routes + eligible models through the graph lower by default, so graph attention + (``attn_layer > 0``) must be fx-traceable — the shape-static + ``center_edge_pairs`` form keeps the pair enumeration ``nonzero``-free. + Covers both the smooth and non-smooth attention branches. 
+ """ + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + attn_dotr=True, + smooth_type_embedding=smooth, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + mapping = torch.tensor(self.mapping, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist, mapping): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist, mapping)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist, mapping) + traced = make_fx(fn)(coord_ext, atype_ext, nlist, mapping) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist, mapping) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level def test_share_params(self, shared_level) -> None: """share_params level 0: share all; level 1: share type_embedding only.""" diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 4444d5e9a1..856bc0724a 100644 --- 
a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -40,6 +40,21 @@ ) +def _assert_repeatable(a, b) -> None: + """Two evals of the same inputs must agree. + + On CPU the reduction is bit-exact, so require exact equality. On CUDA the + graph descriptor's ``segment_sum`` lowers to ``torch.index_add``, whose + atomicAdd ordering varies run-to-run (1-2 fp64 ULP); require agreement to + the documented CUDA fp64 tolerance instead. Genuine non-determinism + (unseeded dropout/sampling) is O(1e-3+) and still fails this bound. + """ + if DEVICE.type == "cpu": + np.testing.assert_array_equal(a, b) + else: + np.testing.assert_allclose(a, b, rtol=1e-10, atol=1e-10) + + class TestDeepEvalEner(unittest.TestCase): """Test pt_expt inference for energy models.""" @@ -1762,11 +1777,16 @@ def test_descriptor_deterministic_sea(self) -> None: np.testing.assert_array_equal(d1, d2) def test_descriptor_deterministic_dpa1(self) -> None: - """Calling eval_descriptor twice gives same result for DPA1.""" + """Calling eval_descriptor twice gives same result for DPA1. + + DPA1 (mixed_types) routes through the carry-all graph path, whose + ``segment_sum`` is bit-exact on CPU but 1-2 ULP non-deterministic on + CUDA (index_add atomics); see ``_assert_repeatable``. + """ coords, cells, atom_types = self._make_inputs() d1 = self.dp_dpa1.deep_eval.eval_descriptor(coords, cells, atom_types) d2 = self.dp_dpa1.deep_eval.eval_descriptor(coords, cells, atom_types) - np.testing.assert_array_equal(d1, d2) + _assert_repeatable(d1, d2) def test_descriptor_with_fparam(self) -> None: """eval_descriptor works with fparam.""" @@ -1920,7 +1940,12 @@ def test_fitting_ll_deterministic_sea(self) -> None: np.testing.assert_array_equal(fit_ll1, fit_ll2) def test_fitting_ll_deterministic_dpa1(self) -> None: - """Verify calling twice gives the same result for DPA1.""" + """Verify calling twice gives the same result for DPA1. 
+ + DPA1 (mixed_types) routes through the carry-all graph path, whose + ``segment_sum`` is bit-exact on CPU but 1-2 ULP non-deterministic on + CUDA (index_add atomics); see ``_assert_repeatable``. + """ coords, cells, atom_types = self._make_inputs() fit_ll1 = self.dp_dpa1.deep_eval.eval_fitting_last_layer( coords, cells, atom_types @@ -1928,7 +1953,7 @@ def test_fitting_ll_deterministic_dpa1(self) -> None: fit_ll2 = self.dp_dpa1.deep_eval.eval_fitting_last_layer( coords, cells, atom_types ) - np.testing.assert_array_equal(fit_ll1, fit_ll2) + _assert_repeatable(fit_ll1, fit_ll2) def test_fitting_ll_with_fparam_aparam(self) -> None: """eval_fitting_last_layer works with fparam and aparam.""" diff --git a/source/tests/pt_expt/infer/test_graph_deepeval.py b/source/tests/pt_expt/infer/test_graph_deepeval.py index 4e7b929391..07649d9e77 100644 --- a/source/tests/pt_expt/infer/test_graph_deepeval.py +++ b/source/tests/pt_expt/infer/test_graph_deepeval.py @@ -153,11 +153,18 @@ def _eager_dense_reference( return {k: v.detach().cpu().numpy() for k, v in out.items()} -@pytest.fixture(scope="module") -def graph_pt2(): - """Build a dpa1(attn_layer=0) model and export it to a graph-form ``.pt2``. - - The AOTI compile is slow (~90 s), so it is done once per module. The eager +@pytest.fixture(scope="module", params=[0, 2], ids=["attn0", "attn2"]) +def graph_pt2(request): + """Build a dpa1 model and export it to a graph-form ``.pt2``. + + Parametrized over ``attn_layer``: 0 exercises the factorizable graph lower; + 2 exercises the carry-all ATTENTION graph lower, whose compact pair + enumeration exports via unbacked SymInts (``xp_hint_dynamic_size``). + ``smooth_type_embedding`` stays False: the smooth dense reference keeps + sel-padding in its softmax denominator, so dense==carry-all parity holds + only for the non-smooth branch (PR-D divergence decision). + + The AOTI compile is slow (~90 s), so it is done once per param. 
def test_graph_pt2_single_atom_no_edges(graph_pt2) -> None:
    """Evaluate one isolated atom (zero real edges) through the ``.pt2``.

    With a single atom the graph builder emits only masked guard edges, so
    at runtime the compact pair enumeration sees ``R == 0`` real edges -- the
    empty extreme of the unbacked-SymInt sizes carried by the attention
    export. The energy has to reproduce the eager dense reference and the
    force has to be (numerically) zero.
    """
    archive_path, eager_model = graph_pt2
    # shape (1, 1, 3): a single atom, type index 0
    lone_coord = np.array([[[9.0, 9.0, 9.0]]])
    lone_type = np.array([0], dtype=np.int32)

    evaluator = DeepPot(archive_path)
    # the third returned array is unpacked but not checked by this test
    energy, force, _third = evaluator.eval(lone_coord, None, lone_type)[:3]
    dense_ref = _eager_dense_reference(eager_model, lone_coord, None, lone_type)

    energy_tol = {"rtol": 1e-10, "atol": 1e-10}
    np.testing.assert_allclose(
        energy.reshape(-1), dense_ref["energy"].reshape(-1), **energy_tol
    )
    np.testing.assert_allclose(force.reshape(-1), 0.0, atol=1e-12)
@pytest.mark.parametrize("attn_layer", [0, 2])  # factorizable AND attention
def test_graph_lower_symbolic_trace(self, attn_layer) -> None:
    """Symbolically trace the graph lower and compare it with eager.

    ``forward_lower_graph_exportable`` is traced with
    ``tracing_mode="symbolic"`` for both the factorizable (``attn_layer=0``)
    and the attention (``attn_layer=2``) graph lower; the traced module must
    reproduce the eager ``forward_common_lower_graph`` outputs to 1e-12.

    For ``attn_layer > 0`` this covers the carry-all compact pair
    enumeration (``center_edge_pairs`` with ``static_nnei=None``) under
    make_fx symbolic tracing: its ``nonzero``/tensor-``repeat`` output sizes
    are unbacked SymInts, registered via ``xp_hint_dynamic_size`` -- the
    mechanism that makes the attention graph lower ``.pt2``-exportable.
    """
    from deepmd.pt_expt.utils.serialization import (
        build_synthetic_graph_inputs,
    )

    # Trace on CPU, mirroring the real .pt2 export (``deserialize_to_file``
    # does ``model.to("cpu")`` and builds CPU synthetic inputs). Keeping the
    # parameters and the traced inputs on one device avoids the FakeTensor
    # device-propagation error for aten.index_select that a CUDA runner would
    # hit when CUDA params meet CPU graph tensors.
    cpu_model = self._make_model(attn_layer=attn_layer).to("cpu")
    cpu_model.eval()
    synthetic = build_synthetic_graph_inputs(
        cpu_model,
        e_max=175,
        nframes=2,
        nloc=7,
        dtype=torch.float64,
        device=torch.device("cpu"),
    )
    atype, n_node, ei, ev, em, fparam, aparam, charge_spin = synthetic
    graph_inputs = (atype, n_node, ei, ev, em)

    traced = cpu_model.forward_lower_graph_exportable(
        *graph_inputs,
        fparam=fparam,
        aparam=aparam,
        do_atomic_virial=True,
        charge_spin=charge_spin,
        tracing_mode="symbolic",
        _allow_non_fake_inputs=True,
    )
    traced_out = traced(*graph_inputs, fparam, aparam, charge_spin)
    eager_out = cpu_model.forward_common_lower_graph(
        *graph_inputs, fparam=fparam, aparam=aparam, do_atomic_virial=True
    )

    tight = {"rtol": 1e-12, "atol": 1e-12}
    torch.testing.assert_close(
        traced_out["energy"], eager_out["energy_redu"], **tight
    )
    # force and virial are compared after reshaping the eager result to the
    # traced output's shape
    for traced_key, eager_key in (
        ("force", "energy_derv_r"),
        ("virial", "energy_derv_c_redu"),
    ):
        got = traced_out[traced_key]
        torch.testing.assert_close(
            got, eager_out[eager_key].reshape(got.shape), **tight
        )
@pytest.mark.parametrize("attn_layer", [0, 2])  # factorizable AND attention
def test_graph_route_float32(self, attn_layer) -> None:
    """Run a float32 model on the graph route and compare with the dense route.

    The descriptor-level ``call_graph`` casts ``edge_vec`` to the descriptor
    precision by hand, because ``@cast_precision`` cannot see inside the
    NeighborGraph dataclass. Without that cast an fp32 model crashes on the
    graph route with a double-vs-float matmul while the dense route works.

    Both routes are evaluated on the same model via ``call_common``
    (``neighbor_graph_method="dense"`` for the graph route, ``"legacy"`` for
    the dense route). The tolerance (``rtol=1e-5`` / ``atol=1e-6``) is bounded
    by fp32 accumulation-order differences.
    """
    from deepmd.pt_expt.descriptor.dpa1 import DescrptDPA1 as _Dpa1
    from deepmd.pt_expt.fitting import InvarFitting as _Invar

    # settings shared by the descriptor and the fitting net
    fp32 = {"precision": "float32", "seed": GLOBAL_SEED}

    descriptor = _Dpa1(
        self.rcut,
        self.rcut_smth,
        self.sel,
        self.nt,
        neuron=[3, 6],
        axis_neuron=2,
        attn=4,
        attn_layer=attn_layer,
        attn_dotr=True,
        smooth_type_embedding=False,
        **fp32,
    ).to(self.device)
    fitting = _Invar(
        "energy",
        self.nt,
        descriptor.get_dim_out(),
        1,
        mixed_types=True,
        **fp32,
    ).to(self.device)
    model = EnergyModel(descriptor, fitting, type_map=self.type_map).to(
        self.device
    )
    model.eval()

    def run_route(method):
        # every route gets its own fresh, grad-enabled copy of the coordinates
        return model.call_common(
            self.coord.clone().requires_grad_(True),
            self.atype,
            self.cell.reshape(1, 9),
            neighbor_graph_method=method,
        )

    graph_out = run_route("dense")
    dense_out = run_route("legacy")

    fp32_tol = {"rtol": 1e-5, "atol": 1e-6}
    for key in ("energy_redu", "energy_derv_r"):
        torch.testing.assert_close(graph_out[key], dense_out[key], **fp32_tol)
class _FakeDesc:
    """Descriptor stub that only reports a fixed attention-layer count."""

    def __init__(self, n_attn: int) -> None:
        self._n_attn = n_attn

    def get_numb_attn_layer(self) -> int:
        """Return the attention-layer count given at construction."""
        return self._n_attn


class _FakeAtomicModel:
    """Atomic-model stub whose ``descriptor`` attribute is a ``_FakeDesc``."""

    def __init__(self, n_attn: int) -> None:
        stub_descriptor = _FakeDesc(n_attn)
        self.descriptor = stub_descriptor


class _FakeModel:
    """Model stub exposing ``atomic_model.descriptor.get_numb_attn_layer()``."""

    def __init__(self, n_attn: int) -> None:
        stub_atomic_model = _FakeAtomicModel(n_attn)
        self.atomic_model = stub_atomic_model
@pytest.mark.parametrize(
    ("version", "n_attn"),
    [
        ("2.5.1", 0),  # old torch OK without attention (backed symbols only)
        ("2.6.0", 2),  # floor version with attention
        ("2.10.0+cu126", 2),  # current torch with attention, local suffix
    ],
)
def test_graph_trace_version_guard_passes(monkeypatch, version, n_attn) -> None:
    """The guard stays silent for no-attention models and for torch >= 2.6.

    ``torch.__version__`` is patched to ``version`` and the guard is called on
    a stub model reporting ``n_attn`` attention layers; the test passes when
    the call returns without raising.
    """
    import torch

    from deepmd.pt_expt.utils.serialization import (
        check_graph_trace_torch_version as version_guard,
    )

    stub_model = _FakeModel(n_attn)
    monkeypatch.setattr(torch, "__version__", version)
    version_guard(stub_model)
+ "smooth_type_embedding": False, "seed": 1, }, "fitting_net": {"neuron": [8, 8], "resnet_dt": True, "seed": 1},