Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,25 @@ def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
return x


def xp_hint_dynamic_size(x: Array) -> None:
    """Register the leading dimension of ``x`` as a valid size for torch.export.

    Under symbolic tracing (``make_fx`` / ``torch.export``) the length of a
    data-dependent array (e.g. the output of ``nonzero`` or a
    tensor-``repeat``) is an UNBACKED SymInt; guarding Python control flow or
    allocations on it raises ``GuardOnDataDependentSymNode``.
    ``torch._check_is_size`` registers the ``>= 0`` size hint that lets the
    tracer treat it as a proper dimension (recorded as a
    ``sym_constrain_range_for_size`` node, preserved by AOTI).

    The call does nothing for arrays that are not torch arrays, so dpmodel
    code may call it unconditionally.

    Parameters
    ----------
    x
        Array whose leading dimension (``x.shape[0]``) is hinted.
    """
    if not array_api_compat.is_torch_array(x):
        # numpy / jax / other namespaces: nothing to register
        return

    # torch is imported lazily so non-torch backends never need it installed
    import torch

    leading_dim = x.shape[0]
    torch._check_is_size(leading_dim)


def xp_maximum_at(x: Array, indices: Array, values: Array) -> Array:
"""Segment max-assign of values into x at the specified indices.
Expand Down
222 changes: 202 additions & 20 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from deepmd.dpmodel.common import (
cast_precision,
get_xp_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -428,15 +429,20 @@ def get_numb_attn_layer(self) -> int:
def uses_graph_lower(self) -> bool:
    """Returns whether this descriptor supports the graph-native lower.

    The graph-native lower (``call_graph``) covers the factorizable path
    AND transformer attention (``attn_layer >= 0``, NeighborGraph PR-D)
    with concat type-embedding and no type exclusion. Remaining ineligible
    configs (``tebd_input_mode == "strip"``, ``exclude_types``) fall back
    to the legacy dense path, so those models keep working unchanged.

    Eligibility does NOT imply numerical interchangeability with the
    dense route for every config: with ``smooth_type_embedding=True``
    the carry-all graph attention is sel-independent by design and
    differs from the dense lower by up to ~1e-4 (see the Notes of
    :meth:`call_graph`).

    Returns
    -------
    bool
        ``True`` when the type-embedding input mode is ``"concat"`` and no
        type pairs are excluded; the number of attention layers no longer
        affects eligibility.
    """
    se_atten = self.se_atten
    # attention layers are handled by the graph path, so only the
    # type-embedding mode and the type-exclusion list gate eligibility
    return se_atten.tebd_input_mode == "concat" and not se_atten.exclude_types

Expand Down Expand Up @@ -643,6 +649,9 @@ def _call_graph_adapter(
graph,
atype_local,
type_embedding=self.type_embedding.call(),
# the adapter graph is shape-static center-major (compact=False):
# keep the attention pair enumeration nonzero-free (traceable)
static_nnei=nnei,
)
# call_graph returns flat (N, ...) node axis; reshape to (nf, nloc, ...)
# for the dense 5-tuple ABI -- this reshape is LOCAL to the adapter shim.
Expand Down Expand Up @@ -727,8 +736,9 @@ def call_graph(
graph: Any,
atype: Array,
type_embedding: Array | None = None,
static_nnei: int | None = None,
) -> tuple[Array, Array]:
"""Descriptor-level graph-native forward (``attn_layer == 0``).
"""Descriptor-level graph-native forward.

Wraps the block kernel
:meth:`DescrptBlockSeAtten.call_graph`, adds the descriptor-level
Expand All @@ -740,6 +750,21 @@ def call_graph(
not produce the dense ``sw`` (that lives in the dense :meth:`call`
adapter, which has the ``nlist``/``coord_ext`` needed to build it).

Notes
-----
**Smooth attention is intentionally sel-independent on the graph
path.** For ``smooth_type_embedding=True`` the legacy dense attention
keeps the sel-padding slots in its softmax DENOMINATOR (phantom
``exp(-attnw_shift)`` terms), which makes dense output depend on the
``sel`` setting by up to ~1e-4 even for identical physical neighbors.
A carry-all graph has no padding slots, so its softmax runs over the
real neighbor pairs only: cleaner, sel-independent semantics that
deliberately DIFFER from the dense route for smooth models. The two
routes agree bit-tight only for ``smooth_type_embedding=False`` (at
non-binding ``sel``), or when this kernel is realized on a dense
layout via ``static_nnei`` (the dense :meth:`call` adapter), which
reproduces the phantom terms for exact backward compatibility.

Parameters
----------
graph
Expand All @@ -757,10 +782,21 @@ def call_graph(
(N, ng, 3) equivariant single-particle representation, flat node
axis.
"""
import dataclasses

xp = array_api_compat.array_namespace(graph.edge_vec)
dev = array_api_compat.device(graph.edge_vec)
# manual @cast_precision: the decorator casts array ARGUMENTS, but the
# graph's only float input (edge_vec) is inside the NeighborGraph
# dataclass, invisible to it. Cast edge_vec down to the descriptor
# precision on entry and the outputs back to the caller's dtype on
# exit (differentiable: grad still flows to the caller's edge_vec leaf).
in_dtype = graph.edge_vec.dtype
prec = get_xp_precision(xp, self.precision)
if in_dtype != prec:
graph = dataclasses.replace(graph, edge_vec=xp.astype(graph.edge_vec, prec))
grrg, rot_mat = self.se_atten.call_graph(
graph, atype, type_embedding=type_embedding
graph, atype, type_embedding=type_embedding, static_nnei=static_nnei
)
# FLAT node axis (N, ...): no (nf, nloc) reshape -- ragged-native, spec.
if self.concat_output_tebd:
Expand All @@ -772,6 +808,9 @@ def call_graph(
atype_local = xp.asarray(atype, device=dev)
atype_embd = xp.take(type_embedding, atype_local, axis=0) # (N, tebd_dim)
grrg = xp.concat([grrg, atype_embd], axis=-1)
if in_dtype != prec:
grrg = xp.astype(grrg, in_dtype)
rot_mat = xp.astype(rot_mat, in_dtype)
return grrg, rot_mat

def enable_compression(
Expand Down Expand Up @@ -1670,12 +1709,15 @@ def call_graph(
graph: Any,
atype: Array,
type_embedding: Array | None = None,
static_nnei: int | None = None,
) -> tuple[Array, Array]:
"""Graph-native forward (``attn_layer=0`` only).
"""Graph-native forward.

Bit-exact analogue of :meth:`call` on the SAME neighbor list, with the
neighbor-axis reduction replaced by a ``segment_sum`` over edge centers
(``dst``). Geometry enters only through ``graph.edge_vec``.
(``dst``) and the dense ``(nnei, nnei)`` transformer attention replaced
by pairs of edges sharing a center (``center_edge_pairs`` +
``segment_softmax``). Geometry enters only through ``graph.edge_vec``.

Parameters
----------
Expand All @@ -1687,6 +1729,12 @@ def call_graph(
(N,) flat node atom types (``N = sum(graph.n_node)``).
type_embedding
(ntypes_with_padding, tebd_dim) type-embedding table.
static_nnei
When the graph uses the shape-static center-major layout
(``from_dense_quartet(compact=False)``, ``E = n_center * nnei``),
pass ``nnei`` so the attention edge-pair enumeration stays
jit/export-traceable (no ``nonzero``). ``None`` (carry-all /
compact graphs) selects the dynamic eager form.

Returns
-------
Expand All @@ -1699,8 +1747,7 @@ def call_graph(

Notes
Comment thread
OutisLi marked this conversation as resolved.
-----
Known limitations (NeighborGraph PR-A):
- ``attn_layer == 0`` only (attention lands in PR-D);
Known limitations:
- ``tebd_input_mode == "concat"`` only (strip mode lands later);
- ``exclude_types`` is not yet supported and raises (lands in a later PR).
"""
Expand All @@ -1709,11 +1756,6 @@ def call_graph(
segment_sum,
)

if self.attn_layer != 0:
raise NotImplementedError(
"graph path supports attn_layer=0 only (NeighborGraph PR-A); "
"attn_layer>0 lands in PR-D"
)
if self.tebd_input_mode not in ["concat"]:
raise NotImplementedError(
"graph path supports tebd_input_mode='concat' only (NeighborGraph PR-A)"
Expand All @@ -1738,7 +1780,7 @@ def call_graph(
# per-edge env-mat 4-vector, normalized by the center (dst) atom type.
# self.mean/self.stddev are slot-independent (ntypes, nnei, 4); slot 0 is
# the canonical per-type vector.
rr = edge_env_mat(
rr, sw_e = edge_env_mat(
graph.edge_vec,
center_type,
self.mean[:, 0, :],
Expand All @@ -1747,7 +1789,8 @@ def call_graph(
self.rcut_smth,
protection=self.env_protection,
edge_mask=graph.edge_mask,
) # (E, 4)
return_sw=True,
) # (E, 4), (E, 1) sw zeroed on padding
# radial channel
ss = rr[:, 0:1] # (E, 1)
# neighbor / center type embeddings (concat mode); ghost type == owner type
Expand All @@ -1764,6 +1807,13 @@ def call_graph(
ss = xp.concat([ss, atype_embd_nlist], axis=-1)
# embedding net (same weights as the dense path); applies on the last axis
gg = self.embeddings[0].call(ss) # (E, ng)
# transformer attention over each center's edges — mirrors the dense
# self.dpa1_attention(gg, nlist_mask, input_r, sw), which also runs on
# the UNMASKED gg (padding rows are neutralized afterwards).
if self.attn_layer > 0:
gg = self._graph_attention(
gg, rr, dst, n_total, graph.edge_mask, sw_e, static_nnei
)
# zero padding/guard edges BEFORE the segment sum
gg = gg * xp.astype(graph.edge_mask[:, None], gg.dtype)
# outer product (replaces the dense gg[:,:,:,None] * rr[:,:,None,:])
Expand All @@ -1784,6 +1834,138 @@ def call_graph(
rot_mat = gr[:, :, 1:]
return grrg, rot_mat

def _graph_attention(
    self,
    gg: Array,
    rr: Array,
    dst: Array,
    n_total: int,
    edge_mask: Array,
    sw_e: Array,
    static_nnei: int | None,
) -> Array:
    """Run the transformer attention stack over the edges of each center.

    Ragged counterpart of :class:`NeighborGatedAttention` /
    :class:`GatedAttentionLayer`: edges that share a center attend to each
    other. The dense ``(nnei, nnei)`` square per center is replaced by the
    edge-pair axis returned by ``center_edge_pairs(ordered=True,
    include_self=True)``, and the softmax over the key axis is replaced by
    a ``segment_softmax`` grouped by the query edge.

    Parameters
    ----------
    gg : (E, ng) per-edge embedding (UNMASKED, as in the dense path).
    rr : (E, 4) per-edge env-mat vector (``rr[:, 1:4]`` carries direction).
    dst : (E,) center of each edge.
    n_total : number of centers.
    edge_mask : (E,) real-vs-padding edge mask.
    sw_e : (E, 1) smooth switch, zeroed on padding edges.
    static_nnei : shape-static layout ``nnei`` or ``None`` (compact eager).

    Returns
    -------
    Array
        The per-edge embedding after every attention layer in
        ``self.dpa1_attention.attention_layers`` has been applied in order.
    """
    from deepmd.dpmodel.utils.neighbor_graph import (
        center_edge_pairs,
    )

    xp = array_api_compat.array_namespace(gg)

    # unit direction per edge; the 1e-12 floor mirrors the dense input_r,
    # rr[..., 1:4] / max(|rr[..., 1:4]|, 1e-12)
    direction = rr[:, 1:4]
    length = safe_for_vector_norm(direction, axis=-1, keepdims=True)
    floor = xp.full_like(length, 1e-12)
    input_r = direction / xp.maximum(length, floor)  # (E, 3)

    # full ordered square including the diagonal: q_m . k_n is not
    # symmetric, and self-attention keeps the m == n pair
    query_edge, key_edge, pair_mask = center_edge_pairs(
        dst,
        edge_mask,
        n_total,
        include_self=True,
        ordered=True,
        static_nnei=static_nnei,
    )

    embedding = gg
    for attn_layer in self.dpa1_attention.attention_layers:
        embedding = self._graph_attention_one_layer(
            attn_layer, embedding, input_r, sw_e, query_edge, key_edge, pair_mask
        )
    return embedding

def _graph_attention_one_layer(
    self,
    layer: "NeighborGatedAttentionLayer",
    gg: Array,
    input_r: Array,
    sw_e: Array,
    q_e: Array,
    k_e: Array,
    pair_mask: Array,
) -> Array:
    """Apply one residual attention layer on the edge-pair axis.

    Mirrors ``NeighborGatedAttentionLayer.call`` (residual +
    ``GatedAttentionLayer.call`` + LayerNorm), op-for-op against the dense
    reference. Structural translation: the per-center ``q @ k^T`` becomes
    a per-pair ``q_m . k_n``, and the softmax over the key axis becomes a
    ``segment_softmax`` grouped by the query edge. In the smooth branch the
    padding pairs stay IN the softmax denominator with ``sw = 0`` (weight
    ``exp(-attnw_shift)``), exactly like the dense branch, which replaces
    the ``-inf`` masking by the switch weighting.

    Parameters
    ----------
    layer : attention layer providing ``attention_layer`` and
        ``attn_layer_norm``.
    gg : (E, ng) per-edge embedding entering the layer.
    input_r : (E, 3) per-edge normalized direction.
    sw_e : (E, 1) smooth switch, zeroed on padding edges.
    q_e : (P,) query-edge index of each pair.
    k_e : (P,) key-edge index of each pair.
    pair_mask : (P,) pair mask, used by the non-smooth softmax only.

    Returns
    -------
    Array
        (E, ng) layer-normalized ``gg + out_proj(attention output)``.

    Raises
    ------
    NotImplementedError
        If the wrapped attention layer has ``num_heads != 1``.
    """
    from deepmd.dpmodel.utils.neighbor_graph import (
        segment_softmax,
        segment_sum,
    )

    xp = array_api_compat.array_namespace(gg)
    n_edge = gg.shape[0]
    attn = layer.attention_layer  # GatedAttentionLayer
    if attn.num_heads != 1:
        raise NotImplementedError(
            "graph attention assumes num_heads == 1 (dpa1 never exposes "
            "num_heads; the dense head_dim QKV slicing relies on it)"
        )
    head_dim = attn.head_dim  # == hidden_dim for num_heads == 1

    # in_proj -> Q, K, V, sliced by head_dim exactly like the dense path
    projected = attn.in_proj.call(gg)  # (E, 3 * hidden)
    query = projected[:, 0:head_dim]
    key = projected[:, head_dim : head_dim * 2]
    value = projected[:, head_dim * 2 : head_dim * 3]
    if attn.normalize:
        query, key, value = (
            np_normalize(part, axis=-1) for part in (query, key, value)
        )
    query = query * attn.scaling

    # per-pair logits q_m . k_n (num_heads == 1)
    score = xp.sum(
        xp.take(query, q_e, axis=0) * xp.take(key, k_e, axis=0), axis=-1
    )  # (P,)

    if not attn.smooth:
        # non-smooth: dense masks padding keys to -inf pre-softmax, which
        # is the same as excluding them from the softmax entirely
        weight = segment_softmax(score, q_e, n_edge, mask=pair_mask)
    else:
        # (logits + shift) * sw_m * sw_n - shift, then softmax WITHOUT the
        # pair mask: padding pairs stay in the denominator at exp(-shift),
        # mirroring the dense smooth branch (sw already zeroed on padding)
        attnw_shift = 20.0  # dense GatedAttentionLayer.call default
        sw_flat = sw_e[:, 0]  # (E,)
        sw_query = xp.take(sw_flat, q_e, axis=0)
        sw_key = xp.take(sw_flat, k_e, axis=0)
        score = (score + attnw_shift) * sw_query * sw_key - attnw_shift
        weight = segment_softmax(score, q_e, n_edge)  # (P,)
        weight = weight * sw_query * sw_key

    if attn.dotr:
        # (P,) = input_r_m . input_r_n
        direction_dot = xp.sum(
            xp.take(input_r, q_e, axis=0) * xp.take(input_r, k_e, axis=0),
            axis=-1,
        )
        weight = weight * direction_dot

    # o_m = sum_n w[m, n] v[n] -> segment_sum over the query edge
    weighted_value = weight[:, None] * xp.take(value, k_e, axis=0)  # (P, hd)
    attended = segment_sum(weighted_value, q_e, n_edge)  # (E, hd)
    projected_out = attn.out_proj.call(attended)  # (E, ng)
    return layer.attn_layer_norm.call(gg + projected_out)

def has_message_passing(self) -> bool:
    """Report whether this descriptor block performs message passing.

    Returns
    -------
    bool
        Always ``False``.
    """
    return False
Expand Down
14 changes: 9 additions & 5 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,14 @@ def call_common(

The graph routes (``"dense"``/``"ase"``, and the pt_expt
default-flip) require a ``mixed_types`` descriptor with a graph
lower (dpa1 ``attn_layer == 0``). At non-binding ``sel`` the
graph matches the dense path exactly; at binding ``sel`` the
carry-all graph keeps neighbors the dense path truncates, so the
energy intentionally differs.
lower (dpa1/se_atten with concat type embedding and no
``exclude_types``; attention layers included). At non-binding
``sel`` the graph matches the dense path exactly for the
non-smooth branch; at binding ``sel`` the carry-all graph keeps
neighbors the dense path truncates, and for
``smooth_type_embedding=True`` the graph drops the dense
layout's sel-padding softmax terms, so the energy intentionally
differs (sel-independent graph semantics).

Returns
-------
Expand Down Expand Up @@ -688,7 +692,7 @@ def call_common_lower_graph(
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> dict[str, Array]:
"""Graph-native PUBLIC lower (PR-A: dpa1 ``attn_layer == 0``).
"""Graph-native PUBLIC lower (dpa1/se_atten concat-tebd, attention included).

The PRIMARY directly-callable graph interface (spec decision #14).
Casts inputs/outputs to/from the model precision exactly like the
Expand Down
Loading
Loading