Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
f322d78
feat(dpmodel): segment_max + numerically-stable mask-aware segment_so…
Jul 2, 2026
f4b7141
feat(dpmodel): center_edge_pairs primitive (shared by attention/angles)
Jul 2, 2026
cfb0cb8
feat(dpmodel): graph-native se_atten transformer attention (attn_laye…
Jul 2, 2026
e069d2d
test(pt_expt): graph attention make_fx (merge gate) + model force/vir…
Jul 2, 2026
480cd61
test: binding-sel audit for graph-default attention models
Jul 2, 2026
a1061af
test(pt_expt): pin smooth off in neighbor-list dpa1 fixture (route pa…
Jul 2, 2026
2560148
feat(dpmodel): pad_and_guard_angles angle-axis padder
Jul 3, 2026
0b1b49f
feat(dpmodel): build_angle_index (edge pairs within a_rcut)
Jul 3, 2026
1e134a1
test(dpmodel): build_angle_index multi-center, static-layout, ordered…
Jul 3, 2026
2ccd120
feat(dpmodel): attach_angles post-hoc angle attachment
Jul 3, 2026
f7dc3be
test(dpmodel): attach_angles layout.node_capacity branch
Jul 3, 2026
27ec5cb
feat(dpmodel): angle->edge/node segment aggregation
Jul 3, 2026
a7c5aaa
feat(dpmodel): graph_angle_cos with dpa3 dense-parity
Jul 3, 2026
f39b79d
test(dpmodel): real se_t-convention oracle for graph_angle_cos; fix a…
Jul 3, 2026
66438ba
feat(dpmodel): angle-force invariance test + padding-fraction report
Jul 3, 2026
5e799da
test(dpmodel): angle_padding_fraction total==0 branch
Jul 3, 2026
19f0cf1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2026
ad73de0
test(pt_expt): restore sys.modules snapshot in plugin entry-point test
Jul 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 166 additions & 20 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,15 +428,14 @@ def get_numb_attn_layer(self) -> int:
def uses_graph_lower(self) -> bool:
"""Returns whether this descriptor supports the graph-native lower.

The graph-native energy lower (``call_graph``) currently covers only the
non-attention (``attn_layer == 0``) factorizable path with concat
type-embedding and no type exclusion. Any other config (attention,
``tebd_input_mode == "strip"``, ``exclude_types``) falls back to the
legacy dense path, so those models keep working unchanged.
The graph-native lower (``call_graph``) covers both the factorizable path
(``attn_layer == 0``) and transformer attention (``attn_layer > 0``,
NeighborGraph PR-D), with concat type-embedding and no type exclusion.
The remaining ineligible configs (``tebd_input_mode == "strip"``,
``exclude_types``) fall back to the legacy dense path, so those models
keep working unchanged.
"""
return (
self.se_atten.attn_layer == 0
and self.se_atten.tebd_input_mode == "concat"
self.se_atten.tebd_input_mode == "concat"
and not self.se_atten.exclude_types
)

Expand Down Expand Up @@ -643,6 +642,9 @@ def _call_graph_adapter(
graph,
atype_local,
type_embedding=self.type_embedding.call(),
# the adapter graph is shape-static center-major (compact=False):
# keep the attention pair enumeration nonzero-free (traceable)
static_nnei=nnei,
)
# call_graph returns flat (N, ...) node axis; reshape to (nf, nloc, ...)
# for the dense 5-tuple ABI -- this reshape is LOCAL to the adapter shim.
Expand Down Expand Up @@ -727,8 +729,9 @@ def call_graph(
graph: Any,
atype: Array,
type_embedding: Array | None = None,
static_nnei: int | None = None,
) -> tuple[Array, Array]:
"""Descriptor-level graph-native forward (``attn_layer == 0``).
"""Descriptor-level graph-native forward.

Wraps the block kernel
:meth:`DescrptBlockSeAtten.call_graph`, adds the descriptor-level
Expand Down Expand Up @@ -760,7 +763,7 @@ def call_graph(
xp = array_api_compat.array_namespace(graph.edge_vec)
dev = array_api_compat.device(graph.edge_vec)
grrg, rot_mat = self.se_atten.call_graph(
graph, atype, type_embedding=type_embedding
graph, atype, type_embedding=type_embedding, static_nnei=static_nnei
)
# FLAT node axis (N, ...): no (nf, nloc) reshape -- ragged-native, spec.
if self.concat_output_tebd:
Expand Down Expand Up @@ -1670,12 +1673,15 @@ def call_graph(
graph: Any,
atype: Array,
type_embedding: Array | None = None,
static_nnei: int | None = None,
) -> tuple[Array, Array]:
"""Graph-native forward (``attn_layer=0`` only).
"""Graph-native forward.

Bit-exact analogue of :meth:`call` on the SAME neighbor list, with the
neighbor-axis reduction replaced by a ``segment_sum`` over edge centers
(``dst``). Geometry enters only through ``graph.edge_vec``.
(``dst``) and the dense ``(nnei, nnei)`` transformer attention replaced
by pairs of edges sharing a center (``center_edge_pairs`` +
``segment_softmax``). Geometry enters only through ``graph.edge_vec``.

Parameters
----------
Expand All @@ -1687,6 +1693,12 @@ def call_graph(
(N,) flat node atom types (``N = sum(graph.n_node)``).
type_embedding
(ntypes_with_padding, tebd_dim) type-embedding table.
static_nnei
When the graph uses the shape-static center-major layout
(``from_dense_quartet(compact=False)``, ``E = n_center * nnei``),
pass ``nnei`` so the attention edge-pair enumeration stays
jit/export-traceable (no ``nonzero``). ``None`` (carry-all /
compact graphs) selects the dynamic eager form.

Returns
-------
Expand All @@ -1699,8 +1711,7 @@ def call_graph(

Notes
-----
Known limitations (NeighborGraph PR-A):
- ``attn_layer == 0`` only (attention lands in PR-D);
Known limitations:
- ``tebd_input_mode == "concat"`` only (strip mode lands later);
- ``exclude_types`` is not yet supported and raises (lands in a later PR).
"""
Expand All @@ -1709,11 +1720,6 @@ def call_graph(
segment_sum,
)

if self.attn_layer != 0:
raise NotImplementedError(
"graph path supports attn_layer=0 only (NeighborGraph PR-A); "
"attn_layer>0 lands in PR-D"
)
if self.tebd_input_mode not in ["concat"]:
raise NotImplementedError(
"graph path supports tebd_input_mode='concat' only (NeighborGraph PR-A)"
Expand All @@ -1738,7 +1744,7 @@ def call_graph(
# per-edge env-mat 4-vector, normalized by the center (dst) atom type.
# self.mean/self.stddev are slot-independent (ntypes, nnei, 4); slot 0 is
# the canonical per-type vector.
rr = edge_env_mat(
rr, sw_e = edge_env_mat(
graph.edge_vec,
center_type,
self.mean[:, 0, :],
Expand All @@ -1747,7 +1753,8 @@ def call_graph(
self.rcut_smth,
protection=self.env_protection,
edge_mask=graph.edge_mask,
) # (E, 4)
return_sw=True,
) # (E, 4), (E, 1) sw zeroed on padding
# radial channel
ss = rr[:, 0:1] # (E, 1)
# neighbor / center type embeddings (concat mode); ghost type == owner type
Expand All @@ -1764,6 +1771,13 @@ def call_graph(
ss = xp.concat([ss, atype_embd_nlist], axis=-1)
# embedding net (same weights as the dense path); applies on the last axis
gg = self.embeddings[0].call(ss) # (E, ng)
# transformer attention over each center's edges — mirrors the dense
# self.dpa1_attention(gg, nlist_mask, input_r, sw), which also runs on
# the UNMASKED gg (padding rows are neutralized afterwards).
if self.attn_layer > 0:
gg = self._graph_attention(
gg, rr, dst, n_total, graph.edge_mask, sw_e, static_nnei
)
# zero padding/guard edges BEFORE the segment sum
gg = gg * xp.astype(graph.edge_mask[:, None], gg.dtype)
# outer product (replaces the dense gg[:,:,:,None] * rr[:,:,None,:])
Expand All @@ -1784,6 +1798,138 @@ def call_graph(
rot_mat = gr[:, :, 1:]
return grrg, rot_mat

def _graph_attention(
    self,
    gg: Array,
    rr: Array,
    dst: Array,
    n_total: int,
    edge_mask: Array,
    sw_e: Array,
    static_nnei: int | None,
) -> Array:
    """Apply the stacked attention layers to the edges of every center.

    Ragged counterpart of the dense :class:`NeighborGatedAttention` /
    :class:`GatedAttentionLayer` stack. Instead of a dense ``(nnei, nnei)``
    square per center, attention runs over the edge-pair axis produced by
    ``center_edge_pairs(ordered=True, include_self=True)``: every pair of
    edges that share a center, in both orders, diagonal included. The
    per-layer work is delegated to :meth:`_graph_attention_one_layer`.

    Parameters
    ----------
    gg
        (E, ng) per-edge embedding, passed in UNMASKED (as in the dense
        path).
    rr
        (E, 4) per-edge env-mat vector; columns ``1:4`` are used as the
        direction.
    dst
        (E,) center of each edge.
    n_total
        Number of centers.
    edge_mask
        (E,) real-vs-padding edge mask.
    sw_e
        (E, 1) smooth switch, zeroed on padding edges.
    static_nnei
        ``nnei`` of the shape-static layout, or ``None`` for the compact
        eager form; forwarded unchanged to ``center_edge_pairs``.

    Returns
    -------
    Array
        The per-edge embedding after each layer of
        ``self.dpa1_attention.attention_layers`` has been applied in order.
    """
    from deepmd.dpmodel.utils.neighbor_graph import (
        center_edge_pairs,
    )

    xp = array_api_compat.array_namespace(gg)

    # Unit direction per edge, with the norm floored at 1e-12 so a zero
    # vector divides safely (same recipe as the dense ``input_r``).
    direction = rr[:, 1:4]
    length = safe_for_vector_norm(direction, axis=-1, keepdims=True)
    floor = xp.full_like(length, 1e-12)
    unit_direction = direction / xp.maximum(length, floor)  # (E, 3)

    # Ordered pairs including m == n: q_m . k_n is not symmetric and an edge
    # also attends to itself.
    pairs = center_edge_pairs(
        dst,
        edge_mask,
        n_total,
        include_self=True,
        ordered=True,
        static_nnei=static_nnei,
    )
    query_edge, key_edge, valid_pair = pairs

    # The pair enumeration depends only on the graph, so it is built once and
    # reused by every layer.
    embedding = gg
    for attention_layer in self.dpa1_attention.attention_layers:
        embedding = self._graph_attention_one_layer(
            attention_layer,
            embedding,
            unit_direction,
            sw_e,
            query_edge,
            key_edge,
            valid_pair,
        )
    return embedding

def _graph_attention_one_layer(
    self,
    layer: "NeighborGatedAttentionLayer",
    gg: Array,
    input_r: Array,
    sw_e: Array,
    q_e: Array,
    k_e: Array,
    pair_mask: Array,
) -> Array:
    """Evaluate a single residual attention layer on the edge-pair axis.

    Follows ``NeighborGatedAttentionLayer.call`` step by step (gated
    attention, residual add, LayerNorm) with two structural substitutions:
    the per-center ``q @ k^T`` becomes a per-pair dot product
    ``q_m . k_n``, and the softmax over the key axis becomes a
    ``segment_softmax`` grouped by the query edge.

    In the smooth branch the pair mask is deliberately NOT given to the
    softmax: padding pairs have ``sw = 0``, so they stay in the denominator
    with weight ``exp(-attnw_shift)``, as in the dense smooth branch where
    the switch weighting replaces the ``-inf`` masking.

    Parameters
    ----------
    layer
        The attention layer; its ``attention_layer`` (gated attention) and
        ``attn_layer_norm`` members are used.
    gg
        (E, ng) per-edge embedding entering the layer.
    input_r
        (E, 3) normalized per-edge direction, used only when ``dotr`` is on.
    sw_e
        (E, 1) smooth switch, zeroed on padding edges.
    q_e, k_e
        (P,) query / key edge index of every pair.
    pair_mask
        (P,) pair validity mask, used only by the non-smooth softmax.

    Returns
    -------
    Array
        ``LayerNorm(gg + out_proj(attention(gg)))``.

    Raises
    ------
    NotImplementedError
        If the gated attention layer has ``num_heads != 1``.
    """
    from deepmd.dpmodel.utils.neighbor_graph import (
        segment_softmax,
        segment_sum,
    )

    xp = array_api_compat.array_namespace(gg)

    def rows(table: Array, index: Array) -> Array:
        # Gather per-pair rows of a per-edge table.
        return xp.take(table, index, axis=0)

    n_edge = gg.shape[0]
    attn = layer.attention_layer  # GatedAttentionLayer
    if attn.num_heads != 1:
        raise NotImplementedError(
            "graph attention assumes num_heads == 1 (dpa1 never exposes "
            "num_heads; the dense head_dim QKV slicing relies on it)"
        )
    width = attn.head_dim  # == hidden_dim for num_heads == 1

    # Joint projection, then split with the same head_dim slices as the
    # dense implementation.
    projected = attn.in_proj.call(gg)  # (E, 3 * hidden)
    query = projected[:, 0:width]
    key = projected[:, width : width * 2]
    value = projected[:, width * 2 : width * 3]
    if attn.normalize:
        query, key, value = (
            np_normalize(part, axis=-1) for part in (query, key, value)
        )
    query = query * attn.scaling

    # Per-pair logits q_m . k_n (single head).
    query_rows = rows(query, q_e)
    key_rows = rows(key, k_e)
    scores = xp.sum(query_rows * key_rows, axis=-1)  # (P,)

    if not attn.smooth:
        # Dense sets padding keys to -inf before the softmax, i.e. they are
        # left out of it entirely.
        weights = segment_softmax(scores, q_e, n_edge, mask=pair_mask)
    else:
        # (scores + shift) * sw_m * sw_n - shift, softmax without the pair
        # mask, then the switch product again on the weights.
        attnw_shift = 20.0  # dense GatedAttentionLayer.call default
        switch = sw_e[:, 0]  # (E,)
        sw_q = rows(switch, q_e)
        sw_k = rows(switch, k_e)
        scores = (scores + attnw_shift) * sw_q * sw_k - attnw_shift
        weights = segment_softmax(scores, q_e, n_edge)  # (P,)
        weights = weights * sw_q * sw_k

    if attn.dotr:
        # Modulate by the direction cosine input_r_m . input_r_n.
        cosine = xp.sum(rows(input_r, q_e) * rows(input_r, k_e), axis=-1)  # (P,)
        weights = weights * cosine

    # o_m = sum_n w[m, n] v[n]: accumulate the weighted values per query edge.
    weighted_value = weights[:, None] * rows(value, k_e)  # (P, hd)
    attended = segment_sum(weighted_value, q_e, n_edge)  # (E, hd)
    update = attn.out_proj.call(attended)  # (E, ng)
    return layer.attn_layer_norm.call(gg + update)

def has_message_passing(self) -> bool:
    """Report whether this descriptor block performs message passing.

    Returns
    -------
    bool
        Always ``False``.
    """
    return False
Expand Down
24 changes: 24 additions & 0 deletions deepmd/dpmodel/utils/neighbor_graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
See the design discussion wanghan-iapcm/deepmd-kit#4.
"""

from .angles import (
angle_padding_fraction,
angle_to_edge_sum,
angle_to_node_sum,
attach_angles,
build_angle_index,
graph_angle_cos,
)
from .ase_builder import (
build_neighbor_graph_ase,
)
Expand All @@ -30,25 +38,41 @@
NeighborGraph,
frame_id_from_n_node,
node_validity_mask,
pad_and_guard_angles,
pad_and_guard_edges,
)
from .pairs import (
center_edge_pairs,
)
from .segment import (
segment_max,
segment_mean,
segment_softmax,
segment_sum,
)

__all__ = [
"GraphLayout",
"NeighborGraph",
"angle_padding_fraction",
"angle_to_edge_sum",
"angle_to_node_sum",
"attach_angles",
"build_angle_index",
"build_neighbor_graph",
"build_neighbor_graph_ase",
"center_edge_pairs",
"edge_env_mat",
"edge_force_virial",
"frame_id_from_n_node",
"from_dense_quartet",
"graph_angle_cos",
"neighbor_graph_from_ijs",
"node_validity_mask",
"pad_and_guard_angles",
"pad_and_guard_edges",
"segment_max",
"segment_mean",
"segment_softmax",
"segment_sum",
]
Loading
Loading