Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
450 changes: 337 additions & 113 deletions deepmd/dpmodel/descriptor/dpa4.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa4_nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
EnvironmentInitialEmbedding,
GeometricInitialEmbedding,
SeZMTypeEmbedding,
SpinEmbedding,
)
from .ffn import (
EquivariantFFN,
Expand Down Expand Up @@ -158,6 +159,7 @@
"ScalarRMSNorm",
"SeZMInteractionBlock",
"SeZMTypeEmbedding",
"SpinEmbedding",
"SwiGLU",
"WignerDCalculator",
"apply_lora_to_sezm",
Expand Down
33 changes: 21 additions & 12 deletions deepmd/dpmodel/descriptor/dpa4_nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ class GatedActivation(NativeOP):
Whether to use bias in the gate linear layer.
layout
Tensor layout convention. ``"nfdc"`` means input shape (N, F, D, C);
``"ndfc"`` means input shape (N, D, F, C).
``"ndfc"`` means input shape (N, D, F, C); ``"fndc"`` means input shape
(F, N, D, C), the focus-major layout used by the SO(2) mixing stack.
trainable
Whether parameters are trainable.
seed
Expand Down Expand Up @@ -125,8 +126,8 @@ def __init__(
self.precision = precision
self.mlp_bias = bool(mlp_bias)
self.layout = str(layout).lower()
if self.layout not in {"nfdc", "ndfc"}:
raise ValueError("`layout` must be either 'nfdc' or 'ndfc'")
if self.layout not in {"nfdc", "ndfc", "fndc"}:
raise ValueError("`layout` must be one of 'nfdc', 'ndfc', or 'fndc'")

self.activation_function = str(activation_function)
self.scalar_act = get_activation_fn(activation_function)
Expand Down Expand Up @@ -170,7 +171,8 @@ def call(self, x: Any, gate: Any = None) -> Any:
----------
x
Value features. Shape is (N, F, D, C) when ``layout='nfdc'``,
or (N, D, F, C) when ``layout='ndfc'``.
(N, D, F, C) when ``layout='ndfc'``, or (F, N, D, C) when
``layout='fndc'``.
gate
Optional gate features with the same layout as ``x``.
When provided, enables GLU mode:
Expand All @@ -184,6 +186,10 @@ def call(self, x: Any, gate: Any = None) -> Any:
Gated features with the same layout as ``x``.
"""
xp = array_api_compat.array_namespace(x)
# ``ndfc`` carries the degree axis at position 1; ``nfdc`` and the
# focus-major ``fndc`` carry it at position 2. Every select/narrow/reshape
# below is expressed against this single degree axis, so the three layouts
# share one code path apart from the per-focus gate projection.
degree_axis = 1 if self.layout == "ndfc" else 2

scalar_idx = tuple(
Expand Down Expand Up @@ -211,14 +217,17 @@ def call(self, x: Any, gate: Any = None) -> Any:
return x0

input_dtype = gate_scalar_source.dtype
gating_scalars = xp.astype(
xp_sigmoid(
self.gate_linear(
xp.astype(gate_scalar_source, get_xp_precision(xp, self.precision))
)
),
input_dtype,
)
gate_src = xp.astype(gate_scalar_source, get_xp_precision(xp, self.precision))
if self.layout == "fndc":
# The scalar source is focus-major (F, N, C). ``FocusLinear`` mixes
# channels with the focus stream on axis 1, so present it in the shared
# (N, F, C) convention and restore the focus-major orientation.
gate_logits = xp.permute_dims(
self.gate_linear(xp.permute_dims(gate_src, (1, 0, 2))), (1, 0, 2)
)
else:
gate_logits = self.gate_linear(gate_src)
gating_scalars = xp.astype(xp_sigmoid(gate_logits), input_dtype)
gating_scalars = xp.reshape(
gating_scalars,
(x.shape[0], gate_scalar_source.shape[1], self.lmax, self.channels),
Expand Down
Loading
Loading