diff --git a/deepmd/dpmodel/descriptor/dpa4.py b/deepmd/dpmodel/descriptor/dpa4.py index 013164cc51..d7b4287199 100644 --- a/deepmd/dpmodel/descriptor/dpa4.py +++ b/deepmd/dpmodel/descriptor/dpa4.py @@ -49,6 +49,7 @@ ) from deepmd.dpmodel.array_api import ( xp_asarray_nodetach, + xp_scatter_sum, xp_take_first_n, ) from deepmd.dpmodel.common import ( @@ -92,6 +93,7 @@ EnvironmentInitialEmbedding, GeometricInitialEmbedding, SeZMTypeEmbedding, + SpinEmbedding, ) from .dpa4_nn.ffn import ( EquivariantFFN, @@ -231,7 +233,12 @@ class DescrptDPA4(NativeOP, BaseDescriptor): The node degree of block `i` is `l_schedule[i] + extra_node_l`, while SO(2) message passing still uses `l_schedule[i]`. n_blocks - Number of blocks (only used when `l_schedule` is None). + Number of blocks (only used when `l_schedule` is None). ``0`` disables + the interaction blocks and builds the zero-block descriptor: type + embedding, optional env FiLM and geometric initial embedding, then the + final SO(3) read-out. The backbone degree is taken from `lmax` + (plus `extra_node_l`). Geometry then enters only through the GIE, which + is active when `use_env_seed=True` and `lmax + extra_node_l > 0`. so2_norm If True, apply intermediate ReducedEquivariantRMSNorm between SO(2) mixing layers. When False (default), no normalization is applied between layers. @@ -371,8 +378,15 @@ class DescrptDPA4(NativeOP, BaseDescriptor): interaction block, driven by the SO(3) Wigner-D grid, so ``l>0`` geometry is folded into ``l=0`` before the scalar is extracted. The value selects the quadratic grid product (``"glu"``) or the polynomial point-wise grid - MLP (``"mlp"``). The Wigner-D frame order follows ``kmax``. The residual - stays on the ``l=0`` channel. + MLP (``"mlp"``). The Wigner-D frame order follows ``kmax``. + readout_layers + Number of stacked equivariant residual read-out FFNs (default ``1``). + Every layer is an ``x + FFN(x)`` residual block sharing the read-out + degree; intermediate layers keep the full SO(3) tensor so high-degree + geometry is folded into ``l=0`` repeatedly, and only the final layer + slices the ``l=0`` channel from its residual sum. With ``so3_readout`` of + ``"none"`` the stack is a degree-0 scalar residual MLP on the ``l=0`` + slice. lebedev_quadrature Either one boolean applied to both S2 branches, or two booleans ``[so2_enabled, ffn_enabled]`` aligned with ``s2_activation``. If @@ -482,6 +496,7 @@ def __init__( message_node_s2: bool = False, message_node_so3: bool = False, so3_readout: str = "none", + readout_layers: int = 1, lebedev_quadrature: bool | list[bool] | None = True, activation_function: str = "silu", glu_activation: bool = True, @@ -496,6 +511,7 @@ def __init__( inner_clamp_r_outer: float | None = None, add_chg_spin_ebd: bool = False, default_chg_spin: list[float] | None = None, + use_spin: list[bool] | None = None, **kwargs: Any, ) -> None: self.version = float(self.LATEST_VERSION) @@ -565,6 +581,9 @@ def __init__( self.so3_readout = str(so3_readout).lower() if self.so3_readout not in {"none", "glu", "mlp"}: raise ValueError("`so3_readout` must be one of 'none', 'glu', or 'mlp'") + self.readout_layers = int(readout_layers) + if self.readout_layers < 1: + raise ValueError("`readout_layers` must be >= 1") if lebedev_quadrature is None: lebedev_quadrature = [True, True] elif isinstance(lebedev_quadrature, bool): @@ -605,7 +624,7 @@ def __init__( ) self.mlp_bias = bool(mlp_bias) self.layer_scale = bool(layer_scale) - self.use_amp = bool(use_amp) # and self.training + self.use_amp = bool(use_amp) self.trainable = bool(trainable) self.seed = seed self.random_gamma = bool(random_gamma) @@ -618,6 +637,12 @@ def __init__( None if default_chg_spin is None else [float(x) for x in default_chg_spin] ) + # === Native per-atom spin embedding === + # The spin vector enters the descriptor as an l=0 magnitude scalar plus + # an l=1 direction feature (see ``SpinEmbedding``). Providing per-type + # ``use_spin`` flags enables the native spin embedding. + self.use_spin = None if use_spin is None else [bool(x) for x in use_spin] + # === Zone bridging: InnerClamp + Source Freeze Propagation Gate === # Both the geometry clamp (``InnerClamp``) and the message-passing # switch (``BridgingSwitch``) are activated together on the same @@ -668,6 +693,7 @@ def __init__( seed_full_attn = child_seed(self.seed, 5) seed_block_attn = child_seed(self.seed, 6) seed_charge_spin = child_seed(self.seed, 7) + seed_spin_embedding = child_seed(self.seed, 8) # === L/M schedules === self._init_lm_schedules(lmax, n_blocks, l_schedule, mmax, m_schedule) @@ -676,7 +702,6 @@ def __init__( raise ValueError("`kmax` must be non-negative") if self.kmax > self.lmax: raise ValueError("`kmax` must be <= `lmax`") - self.ebed_dims = [get_so3_dim_of_lmax(l) for l in self.l_schedule] self._init_node_l_schedules(extra_node_l) self.rad_sizes_per_block = [l + 1 for l in self.l_schedule] @@ -784,6 +809,26 @@ def __init__( else: self.charge_spin_embedding = None + if self.use_spin is not None: + if self.node_init_lmax < 1: + raise ValueError( + "`use_spin` requires a node degree >= 1 " + "(lmax + extra_node_l) to host the l=1 spin feature." + ) + self.spin_embedding: SpinEmbedding | None = SpinEmbedding( + ntypes=self.ntypes, + channels=self.channels, + use_spin=self.use_spin, + activation_function=self.activation_function, + precision=self.compute_precision, # force fp32+ + seed=seed_spin_embedding, + trainable=self.trainable, + ) + # Packed rows hosting the l=1 spin coefficients (m = -1, 0, +1). + self._spin_l1_rows = np.arange(1, 4, dtype=np.int64) + else: + self.spin_embedding = None + # === Env FiLM embedding (optional) === if self.use_env_seed: self.env_seed_embedding: EnvironmentInitialEmbedding | None = ( @@ -798,6 +843,7 @@ def __init__( mlp_bias=self.mlp_bias, activation_function=self.activation_function, eps=self.eps, + use_spin=self.use_spin, precision=self.compute_precision, # force fp32+ trainable=self.trainable, seed=seed_env_seed, @@ -849,7 +895,7 @@ def __init__( # GIE and truncated for each SO2Conv block. # radial_mlp specifies hidden layer sizes; input/output layers are prepended/appended. # Use fp32+ precision (same as RBF output) for numerical stability. - radial_out_dim = (self.node_l_schedule[0] + 1) * self.channels + radial_out_dim = (self.node_init_lmax + 1) * self.channels radial_mlp_layers = [self.n_radial, *self.radial_mlp, radial_out_dim] self.radial_embedding = RadialMLP( radial_mlp_layers, @@ -874,22 +920,22 @@ def __init__( ] self._need_full_wigner = not all(block_edge_cartesian) self.wigner_calc = WignerDCalculator( - lmax=self.l_schedule[0], + lmax=self.mp_init_lmax, eps=self.eps, precision=self.compute_precision, # force fp32+ ) - self.use_gie = self.use_env_seed and self.node_l_schedule[0] > 0 + self.use_gie = self.use_env_seed and self.node_init_lmax > 0 if self.use_gie: self.gie = GeometricInitialEmbedding( - lmax=self.node_l_schedule[0], + lmax=self.node_init_lmax, channels=self.channels, precision=self.compute_precision, # force fp32+ ) if self.extra_node_l > 0: self.gie_zonal_wigner_calc: WignerDCalculator | None = ( WignerDCalculator( - lmax=self.node_l_schedule[0], + lmax=self.node_init_lmax, eps=self.eps, precision=self.compute_precision, ) @@ -991,28 +1037,33 @@ def __init__( seed=child_seed(seed_block_attn, 2000), ) - # === Final FFN for l=0 output mixing === - # ``so3_readout="none"`` runs a degree-0 scalar FFN on the l=0 slice. - # ``"glu"``/``"mlp"`` run a full FFN at the last block's node degree whose - # SO(3) Wigner-D grid folds l>0 geometry into l=0; the value selects the - # quadratic grid product or the point-wise grid MLP. - readout_lmax = self.node_l_schedule[-1] - self.output_ffn = EquivariantFFN( - lmax=0 if self.so3_readout == "none" else readout_lmax, - channels=self.channels, - hidden_channels=self.out_ffn_neurons, - kmax=min(self.kmax, readout_lmax), - grid_mlp=self.so3_readout == "mlp", - grid_branch=0, - precision=self.compute_precision, - s2_activation=False, - ffn_so3_grid=self.so3_readout != "none", - activation_function=self.out_activation_function, - glu_activation=self.out_glu_activation, - mlp_bias=self.mlp_bias, - trainable=self.trainable, - seed=seed_out, - ) + # === Final FFN stack for l=0 output mixing === + # ``readout_layers`` residual blocks run in sequence (see + # ``_apply_readout``): ``readout_pre_layers`` keep the full SO(3) tensor + # and only the final ``output_ffn`` slices l=0. The final layer keeps the + # ``output_ffn`` name and ``seed_out`` so a single-layer read-out matches + # the single-module checkpoint layout. + readout_lmax = self.node_readout_lmax + readout_ffn_kwargs = { + "lmax": 0 if self.so3_readout == "none" else readout_lmax, + "channels": self.channels, + "hidden_channels": self.out_ffn_neurons, + "kmax": min(self.kmax, readout_lmax), + "grid_mlp": self.so3_readout == "mlp", + "grid_branch": 0, + "precision": self.compute_precision, + "s2_activation": False, + "ffn_so3_grid": self.so3_readout != "none", + "activation_function": self.out_activation_function, + "glu_activation": self.out_glu_activation, + "mlp_bias": self.mlp_bias, + "trainable": self.trainable, + } + self.readout_pre_layers = [ + EquivariantFFN(**readout_ffn_kwargs, seed=child_seed(seed_out, layer_index)) + for layer_index in range(self.readout_layers - 1) + ] + self.output_ffn = EquivariantFFN(**readout_ffn_kwargs, seed=seed_out) # === Statistics buffers (interface compatibility) === self.stats: dict[str, Any] | None = None @@ -1032,6 +1083,7 @@ def call( fparam: Array | None = None, force_embedding: Array | None = None, charge_spin: Array | None = None, + spin: Array | None = None, ) -> tuple[ Array, Array | None, @@ -1069,7 +1121,7 @@ def call( force_embedding Optional precomputed equivariant force embedding with shape ``(nf * nloc, D, 1, channels)``, where - ``D = (node_l_schedule[0] + 1) ** 2``. This tensor is added to the + ``D = (node_init_lmax + 1) ** 2``. This tensor is added to the initial SO(3) backbone state before the interaction blocks. charge_spin Frame-level charge and spin conditions with shape (nf, 2). @@ -1110,6 +1162,7 @@ def call( edge_mask=edge_mask, force_embedding=force_embedding, charge_spin=charge_spin, + spin=spin, ) return ( descriptor, @@ -1153,6 +1206,14 @@ def call( nloc=nloc, ) + # Native spin: condition the l=0 type features on the spin magnitude + # and hold the l=1 direction coefficients for the backbone seed. + spin_vec = None + if self.spin_embedding is not None and spin is not None: + type_ebed, spin_vec = self._apply_spin_embedding( + type_ebed, spin, xp.reshape(atype_loc, (-1,)), n_nodes=n_nodes + ) + # === Step 4. Build edge cache once (geometry + RBF + Wigner-D) === # Zone bridging (InnerClamp + SFPG + ZBL) is not routed through the # standard DeePMD path: bridging only makes physical sense when @@ -1177,26 +1238,32 @@ def call( build_wigner=self._need_full_wigner, ) - ebed_dim_0 = self.node_ebed_dims[0] # (node_lmax+1)^2 + ebed_dim_0 = self.node_init_dim # (node_init_lmax+1)^2 x0 = type_ebed # (N, C) x0_out = x0 # (N, C) # === Step 5. Compute radial features once (fp32+) === - # Shape: (E, (node_lmax+1)*C) -> (E, node_lmax+1, C) + # Shape: (E, (node_init_lmax+1)*C) -> (E, node_init_lmax+1, C) radial_feat = xp.reshape( self.radial_embedding(edge_cache.edge_rbf), - (-1, self.node_l_schedule[0] + 1, self.channels), - ) # (E, lmax+1, C) + (-1, self.node_init_lmax + 1, self.channels), + ) # (E, node_init_lmax+1, C) if self.version >= 1.1: radial_feat = radial_feat * xp.reshape(edge_cache.edge_env, (-1, 1, 1)) # === Step 6. Env FiLM conditioning (optional, fp32+) === if self.use_env_seed: atype_flat = xp.reshape(atype_loc, (-1,)) # (N,) + spin_flat = ( + xp.reshape(spin, (n_nodes, 3)) + if (self.spin_embedding is not None and spin is not None) + else None + ) film = self.env_seed_embedding( edge_cache=edge_cache, atype_flat=atype_flat, n_nodes=n_nodes, + spin=spin_flat, ) # (N, 2*C) scale_logits = film[:, : self.channels] # (N, C) shift_logits = film[:, self.channels :] # (N, C) @@ -1229,10 +1296,19 @@ def call( axis=1, ) # (N, D, 1, C) - # === Step 8. Geometric Initial Embedding (fp32+) === + # === Step 8. Geometric Initial Embedding (+ neighbor spin l=1) === if self.use_gie: # GIE only needs l>=1, slice radial_feat[:, 1:, :] zonal_coupling = self._build_gie_zonal_coupling(edge_cache) + spin_l1_message = ( + self.spin_embedding.edge_l1( + xp.reshape(spin, (n_nodes, 3)), + xp.reshape(atype_loc, (-1,)), + edge_cache, + ) + if (self.spin_embedding is not None and spin is not None) + else None + ) x = ( x + self.gie( @@ -1240,10 +1316,22 @@ def call( edge_cache=edge_cache, radial_feat=radial_feat[:, 1:, :], zonal_coupling=zonal_coupling, + spin_l1_message=spin_l1_message, )[:, :, None, :] ) - # === Step 9. Fuse edge type features into radial features (fp32+) === + # === Step 9. Add the on-site native spin l=1 to the backbone === + # The neighbor-spin l=1 is aggregated inside GIE (degree-normalized like + # the geometry); the atom's own spin direction is added here, un-normalized. + if spin_vec is not None: + spin_l1_rows = xp_asarray_nodetach(xp, self._spin_l1_rows, device=device) + spin_l1_src = spin_vec[:, :, None, :] # (N, 3, 1, C) + scatter_index = xp.broadcast_to( + xp.reshape(spin_l1_rows, (1, 3, 1, 1)), spin_l1_src.shape + ) + x = xp_scatter_sum(x, 1, scatter_index, spin_l1_src) + + # === Step 10. Fuse edge type features into radial features (fp32+) === radial_feat = radial_feat + xp.reshape( edge_cache.edge_type_feat, (-1, 1, self.channels) ) @@ -1252,38 +1340,23 @@ def call( radial_feat[:, :rad_len, :] for rad_len in self.rad_sizes_per_block ] # list of (E, lmax+1, C) - # === Step 10. Convert to self.dtype and run blocks === + # === Step 11. Convert to self.dtype and run blocks === + # The block stage is skipped entirely when there are no interaction + # blocks (zero-block descriptor) or no valid edges, sparing the working + # edge-cache dtype cast that only the blocks consume. x = xp.astype(x, get_xp_precision(xp, self.precision)) # (N, D, 1, C) if force_embedding is not None: x = x + xp.astype(force_embedding, get_xp_precision(xp, self.precision)) - edge_cache = edge_cache_to_dtype( - edge_cache, get_xp_precision(xp, self.precision) - ) - x = self._forward_blocks(x, edge_cache, rad_feat_per_block) - - # === Step 11. Final l=0 output mixing === - # ``none`` feeds the l=0 slice only; ``glu``/``mlp`` feed the full - # (N, D, 1, C) node tensor so the SO(3) grid folds l>0 into l=0. The - # residual is added on the full coefficient tensor before extracting - # l=0: slicing the summed tensor rather than the FFN output keeps the - # saved degree-axis stride static under torch.compile dynamic shapes. - ffn_in = ( - xp.astype( - xp.reshape(x[:, 0:1, :, :], (n_nodes, 1, 1, self.channels)), - get_xp_precision(xp, self.compute_precision), + if self.blocks and edge_cache.src.shape[0] > 0: + edge_cache = edge_cache_to_dtype( + edge_cache, get_xp_precision(xp, self.precision) ) - if self.so3_readout == "none" - # truncate to the final node degree: the empty-edge path - # skips the blocks, leaving x at node_ebed_dims[0]; output_ffn - # is built for node_ebed_dims[-1]. No-op when blocks ran. - else xp.astype( - x[:, : self.node_ebed_dims[-1], :, :], - get_xp_precision(xp, self.compute_precision), - ) - ) - x_scalar = (ffn_in + self.output_ffn(ffn_in))[:, 0:1, :, :] + x = self._forward_blocks(x, edge_cache, rad_feat_per_block) - # === Step 12. Reshape to (nf, nloc, channels) and return === + # === Step 12. Final l=0 output mixing === + x_scalar = self._apply_readout(x, n_nodes) + + # === Step 13. Reshape to (nf, nloc, channels) and return === descriptor = xp.reshape(x_scalar, (nf, nloc, self.channels)) # (nf, nloc, C) return ( xp.astype(descriptor, get_xp_precision(xp, "global")), @@ -1303,6 +1376,7 @@ def call_with_edges( edge_mask: Array, force_embedding: Array | None = None, charge_spin: Array | None = None, + spin: Array | None = None, comm_dict: dict[str, Array] | None = None, nloc: int | None = None, ) -> tuple[Array, Array]: @@ -1336,7 +1410,7 @@ def call_with_edges( force_embedding Optional precomputed equivariant force embedding with shape ``(nf * nloc, D, 1, channels)``, where - ``D = (node_l_schedule[0] + 1) ** 2``. This tensor is added to the + ``D = (node_init_lmax + 1) ** 2``. This tensor is added to the initial SO(3) backbone state before the interaction blocks. charge_spin Frame-level charge and spin conditions with shape (nf, 2). @@ -1385,6 +1459,14 @@ def call_with_edges( ) n_nodes = type_ebed.shape[0] + # Native spin: condition the l=0 type features on the spin magnitude + # and hold the l=1 direction coefficients for the backbone seed. + spin_vec = None + if self.spin_embedding is not None and spin is not None: + type_ebed, spin_vec = self._apply_spin_embedding( + type_ebed, spin, atype_flat, n_nodes=n_nodes + ) + # === Step 3. Build edge cache once (sparse edges) === edge_cache = build_edge_cache_from_edges( type_ebed=type_ebed, @@ -1408,7 +1490,7 @@ def call_with_edges( build_wigner=self._need_full_wigner, ) - ebed_dim_0 = self.node_ebed_dims[0] # (node_lmax+1)^2 + ebed_dim_0 = self.node_init_dim # (node_init_lmax+1)^2 x0 = type_ebed # (N, C) x0_out = x0 # (N, C) @@ -1418,19 +1500,25 @@ def call_with_edges( radial_feat_flat, ( radial_feat_flat.shape[0], - self.node_l_schedule[0] + 1, + self.node_init_lmax + 1, self.channels, ), - ) # (E, lmax+1, C) + ) # (E, node_init_lmax+1, C) if self.version >= 1.1: radial_feat = radial_feat * xp.reshape(edge_cache.edge_env, (-1, 1, 1)) # === Step 5. Env FiLM conditioning (optional, fp32+) === if self.use_env_seed: + spin_flat = ( + xp.reshape(spin, (n_nodes, 3)) + if (self.spin_embedding is not None and spin is not None) + else None + ) film = self.env_seed_embedding( edge_cache=edge_cache, atype_flat=atype_flat, n_nodes=n_nodes, + spin=spin_flat, ) # (N, 2*C) scale_logits = film[:, : self.channels] # (N, C) shift_logits = film[:, self.channels :] # (N, C) @@ -1463,9 +1551,16 @@ def call_with_edges( axis=1, ) # (N, D, 1, C) - # === Step 7. Geometric Initial Embedding (fp32+) === + # === Step 7. Geometric Initial Embedding (+ neighbor spin l=1) === if self.use_gie: zonal_coupling = self._build_gie_zonal_coupling(edge_cache) + spin_l1_message = ( + self.spin_embedding.edge_l1( + xp.reshape(spin, (n_nodes, 3)), atype_flat, edge_cache + ) + if (self.spin_embedding is not None and spin is not None) + else None + ) x = ( x + self.gie( @@ -1473,10 +1568,22 @@ def call_with_edges( edge_cache=edge_cache, radial_feat=radial_feat[:, 1:, :], zonal_coupling=zonal_coupling, + spin_l1_message=spin_l1_message, )[:, :, None, :] ) - # === Step 8. Fuse edge type features into radial features (fp32+) === + # === Step 8. Add the on-site native spin l=1 to the backbone === + # The neighbor-spin l=1 is aggregated inside GIE; the + # atom's own spin direction is added here, un-normalized. + if spin_vec is not None: + spin_l1_rows = xp_asarray_nodetach(xp, self._spin_l1_rows, device=device) + spin_l1_src = spin_vec[:, :, None, :] # (N, 3, 1, C) + scatter_index = xp.broadcast_to( + xp.reshape(spin_l1_rows, (1, 3, 1, 1)), spin_l1_src.shape + ) + x = xp_scatter_sum(x, 1, scatter_index, spin_l1_src) + + # === Step 9. Fuse edge type features into radial features (fp32+) === radial_feat = xp.astype(radial_feat, get_xp_precision(xp, self.precision)) radial_feat = radial_feat + xp.reshape( xp.astype(edge_cache.edge_type_feat, get_xp_precision(xp, self.precision)), @@ -1486,16 +1593,21 @@ def call_with_edges( radial_feat[:, :rad_len, :] for rad_len in self.rad_sizes_per_block ] - # === Step 9. Convert to self.dtype and run blocks === + # === Step 10. Convert to self.dtype and run blocks === + # The block stage is skipped entirely for the zero-block descriptor, + # sparing the working edge-cache dtype cast that only the blocks consume. x = xp.astype(x, get_xp_precision(xp, self.precision)) # (N, D, 1, C) if force_embedding is not None: x = x + xp.astype(force_embedding, get_xp_precision(xp, self.precision)) - edge_cache = edge_cache_to_dtype( - edge_cache, get_xp_precision(xp, self.precision) - ) - x = self._forward_blocks(x, edge_cache, rad_feat_per_block, comm_dict=comm_dict) + if self.blocks: + edge_cache = edge_cache_to_dtype( + edge_cache, get_xp_precision(xp, self.precision) + ) + x = self._forward_blocks( + x, edge_cache, rad_feat_per_block, comm_dict=comm_dict + ) - # === Step 10. Keep the owned-atom rows for the read-out === + # === Step 11. Keep the owned-atom rows for the read-out === # ``n_out_nodes`` is the owned-node count in the flattened layout # (``nf * nloc``). Single-domain: ``out_nloc == n_per_frame``, so this # equals the whole node set and the slice is a no-op. Parallel @@ -1504,29 +1616,10 @@ def call_with_edges( n_out_nodes = nf * out_nloc x = x[:n_out_nodes] - # === Step 11. Final l=0 output mixing === - # ``none`` feeds the l=0 slice only; ``glu``/``mlp`` feed the full - # (N, D, 1, C) node tensor so the SO(3) grid folds l>0 into l=0. The - # residual is added on the full coefficient tensor before extracting - # l=0: slicing the summed tensor rather than the FFN output keeps the - # saved degree-axis stride static under torch.compile dynamic shapes. - ffn_in = ( - xp.astype( - xp.reshape(x[:, 0:1, :, :], (n_out_nodes, 1, 1, self.channels)), - get_xp_precision(xp, self.compute_precision), - ) - if self.so3_readout == "none" - # truncate to the final node degree: the empty-edge path - # skips the blocks, leaving x at node_ebed_dims[0]; output_ffn - # is built for node_ebed_dims[-1]. No-op when blocks ran. - else xp.astype( - x[:, : self.node_ebed_dims[-1], :, :], - get_xp_precision(xp, self.compute_precision), - ) - ) - x_scalar = (ffn_in + self.output_ffn(ffn_in))[:, 0:1, :, :] + # === Step 12. Final l=0 output mixing === + x_scalar = self._apply_readout(x, n_out_nodes) - # === Step 12. Reshape to (nf, nloc, channels) and return === + # === Step 13. Reshape to (nf, nloc, channels) and return === descriptor = xp.reshape( x_scalar, (nf, out_nloc, self.channels) ) # (nf, nloc, C) @@ -1606,7 +1699,7 @@ def node_l0_extractor(v: Array) -> Array: x = block_output # === Step 3. Final aggregation over all completed unit representations === - final_dim = self.node_ebed_dims[-1] + final_dim = self.node_readout_dim final_sources = [source[:, :final_dim, :, :] for source in unit_history] x = xp.astype( self.final_full_attn_res( @@ -1640,7 +1733,7 @@ def node_l0_extractor(v: Array) -> Array: x = block_output # === Step 3. Final aggregation over all completed block summaries === - final_dim = self.node_ebed_dims[-1] + final_dim = self.node_readout_dim final_sources = [source[:, :final_dim, :, :] for source in block_history] x = xp.astype( self.final_block_attn_res( @@ -1652,6 +1745,49 @@ def node_l0_extractor(v: Array) -> Array: ) return x + def _apply_readout(self, x: Array, n_rows: int) -> Array: + """Fold the node tensor into the scalar (``l=0``) descriptor. + + Runs the ``readout_layers`` stack of equivariant residual read-out FFNs. + ``so3_readout="none"`` feeds only the ``l=0`` slice; ``"glu"``/``"mlp"`` + feed the full ``(N, D, 1, C)`` node tensor so the SO(3) grid folds + ``l>0`` geometry into ``l=0``. Each layer is an ``x + FFN(x)`` residual: + the ``readout_pre_layers`` keep the full tensor so the geometry keeps + folding, while the final ``output_ffn`` slices the ``l=0`` channel from + its residual sum. Slicing the summed tensor rather than the FFN output + keeps the saved degree-axis stride static under ``torch.compile`` dynamic + shapes. + + Parameters + ---------- + x + Node features with shape ``(n_rows, D, 1, channels)``. With the + blocks skipped (zero-block or empty-edge path) ``D`` is the initial + degree; otherwise the pyramid has shrunk it, so the read-out slice to + ``node_readout_dim`` is a no-op there. + n_rows + Number of node rows fed to the read-out. + + Returns + ------- + Array + Scalar descriptor with shape ``(n_rows, 1, 1, channels)``. + """ + xp = array_api_compat.array_namespace(x) + if self.so3_readout == "none": + x_ro = xp.astype( + xp.reshape(x[:, 0:1, :, :], (n_rows, 1, 1, self.channels)), + get_xp_precision(xp, self.compute_precision), + ) + else: + x_ro = xp.astype( + x[:, : self.node_readout_dim, :, :], + get_xp_precision(xp, self.compute_precision), + ) + for layer in self.readout_pre_layers: + x_ro = x_ro + layer(x_ro) + return (x_ro + self.output_ffn(x_ro))[:, 0:1, :, :] + def _edge_quaternion(self, edge_cache: EdgeCache) -> Array: """ Return the cached global->local edge quaternion, rebuilding if absent. @@ -1700,7 +1836,7 @@ def _build_gie_zonal_coupling( return None xp = array_api_compat.array_namespace(edge_cache.Dt_full) device = array_api_compat.device(edge_cache.Dt_full) - mp_row_count = self.ebed_dims[0] - 1 + mp_row_count = self.mp_init_dim - 1 mp_row_index = self.gie.non_scalar_row_index[:mp_row_count] mp_m0_col_index = self.gie.zonal_m0_col_index_for_row[:mp_row_count] dim_full = edge_cache.Dt_full.shape[-1] @@ -1749,6 +1885,45 @@ def _apply_charge_spin_embedding( condition = xp.broadcast_to(condition[:, None, :], (nf, nloc, self.channels)) return type_ebed + xp.reshape(condition, type_ebed.shape) + def _apply_spin_embedding( + self, + type_ebed: Array, + spin: Array, + atype_flat: Array, + *, + n_nodes: int, + ) -> tuple[Array, Array]: + """ + Inject the per-atom spin embedding into the node features. + + The l=0 magnitude scalar is added to the flattened type embedding so it + propagates into the scalar backbone, the per-edge type features, and + every block's radial features (exactly like the type embedding). The l=1 + direction coefficients are returned for the caller to add to the + equivariant backbone after the geometric initial embedding. + + Parameters + ---------- + type_ebed + Flattened type embedding with shape (N, channels). + spin + Per-atom spin vectors with shape (nf, nloc, 3) or (N, 3). + atype_flat + Flattened local atom types with shape (N,). + n_nodes + Number of local nodes ``N = nf * nloc``. + + Returns + ------- + tuple[Array, Array] + The l=0-conditioned type embedding with shape (N, channels) and the + packed l=1 direction coefficients with shape (N, 3, channels). + """ + xp = array_api_compat.array_namespace(type_ebed, spin) + scalar, vector = self.spin_embedding(xp.reshape(spin, (n_nodes, 3)), atype_flat) + type_ebed = type_ebed + xp.astype(scalar, type_ebed.dtype) + return type_ebed, vector + def _edge_type_keep_mask( self, atype_flat: Array, @@ -1837,14 +2012,19 @@ def _init_lm_schedules( mmax: int | None, m_schedule: list[int] | None, ) -> None: - """Parse and validate L/M schedules, setting self.l_schedule/m_schedule/lmax/mmax.""" + """Parse and validate L/M schedules, setting self.l_schedule/m_schedule/lmax/mmax. + + An empty schedule (``n_blocks=0`` or ``l_schedule=[]``) is valid and + selects the zero-block descriptor: no interaction blocks are built, only + the initial SO(3) backbone (type embedding, optional env FiLM and GIE) + followed by the final read-out. The backbone degree then derives from + the configured ``lmax``/``mmax`` instead of the schedule endpoints. + """ # === L schedule === if l_schedule is None: self.l_schedule = [int(lmax)] * int(n_blocks) else: self.l_schedule = [int(x) for x in l_schedule] - if len(self.l_schedule) == 0: - raise ValueError("`l_schedule` must be non-empty") if any(x < 0 for x in self.l_schedule): raise ValueError("`l_schedule` entries must be non-negative") if any( @@ -1853,7 +2033,9 @@ def _init_lm_schedules( ): raise ValueError("`l_schedule` must be non-increasing (pyramid schedule)") - self.lmax = int(self.l_schedule[0]) + # The first entry sets the maximum degree; with zero blocks the backbone + # degree falls back to the configured ``lmax``. + self.lmax = int(self.l_schedule[0]) if self.l_schedule else int(lmax) self.n_blocks = len(self.l_schedule) # === M schedule === @@ -1867,8 +2049,6 @@ def _init_lm_schedules( self.m_schedule = [min(mmax_i, int(l)) for l in self.l_schedule] else: self.m_schedule = [int(x) for x in m_schedule] - if len(self.m_schedule) == 0: - raise ValueError("`m_schedule` must be non-empty") if len(self.m_schedule) != len(self.l_schedule): raise ValueError("`m_schedule` must have the same length as `l_schedule`") if any(x < 0 for x in self.m_schedule): @@ -1878,10 +2058,30 @@ def _init_lm_schedules( "`m_schedule` entries must satisfy `m_schedule[i] <= l_schedule[i]`" ) - self.mmax = int(self.m_schedule[0]) + self.mmax = ( + int(self.m_schedule[0]) + if self.m_schedule + else (int(mmax) if mmax is not None else int(self.lmax)) + ) def _init_node_l_schedules(self, extra_node_l: int) -> None: - """Parse node degree schedules derived from message-passing schedules.""" + """Parse node degree schedules and resolve the canonical backbone degrees. + + The descriptor references three backbone degrees that must stay valid + even with zero interaction blocks, so they are resolved here into + scalars rather than indexed off the (possibly empty) schedules: + + - ``mp_init_lmax`` : message-passing degree at initialization, driving + the Wigner-D calculator and the GIE message-passing coupling rows. + - ``node_init_lmax`` : node backbone degree at initialization, driving + the radial-embedding width, the initial state dimension, and GIE. + - ``node_readout_lmax`` : node backbone degree fed to the read-out FFN. + + With blocks these equal ``l_schedule[0]``, ``node_l_schedule[0]`` and + ``node_l_schedule[-1]``; with zero blocks all three collapse onto the + configured ``lmax`` (plus ``extra_node_l`` on the node side), so the + pyramid endpoints are never read from an empty list. + """ self.extra_node_l = int(extra_node_l) if self.extra_node_l < 0: raise ValueError("`extra_node_l` must be non-negative") @@ -1891,8 +2091,16 @@ def _init_node_l_schedules(self, extra_node_l: int) -> None: self.node_ebed_dims = [ get_so3_dim_of_lmax(l_value) for l_value in self.node_l_schedule ] - self.node_lmax = int(self.node_l_schedule[0]) - self.node_ebed_dim = int(self.node_ebed_dims[0]) + + # === Canonical backbone degrees (valid for any block count) === + self.mp_init_lmax = int(self.lmax) + self.node_init_lmax = int(self.lmax) + self.extra_node_l + self.node_readout_lmax = ( + int(self.node_l_schedule[-1]) if self.n_blocks > 0 else self.node_init_lmax + ) + self.mp_init_dim = get_so3_dim_of_lmax(self.mp_init_lmax) + self.node_init_dim = get_so3_dim_of_lmax(self.node_init_lmax) + self.node_readout_dim = get_so3_dim_of_lmax(self.node_readout_lmax) def _canonicalize_charge_spin( self, @@ -2178,6 +2386,10 @@ def _variables(self) -> dict[str, np.ndarray]: "@variables" ].items(): variables[f"charge_spin_embedding.{key}"] = value + # === Native per-atom spin embedding (optional) === + if self.spin_embedding is not None: + for key, value in self.spin_embedding.serialize()["@variables"].items(): + variables[f"spin_embedding.{key}"] = value # === Environment FiLM stack (optional) === if self.use_env_seed: for key, value in self.env_seed_embedding.serialize()["@variables"].items(): @@ -2238,6 +2450,10 @@ def _variables(self) -> dict[str, np.ndarray]: "@variables" ].items(): variables[f"final_block_attn_res.{key}"] = value + # === Read-out pre-layers (optional) === + for i, layer in enumerate(self.readout_pre_layers): + for key, value in layer._variables().items(): + variables[f"readout_pre_layers.{i}.{key}"] = value # === Output FFN === for key, value in self.output_ffn._variables().items(): variables[f"output_ffn.{key}"] = value @@ -2272,6 +2488,9 @@ def load(module: Any, prefix: str) -> Any: self.charge_spin_embedding = load( self.charge_spin_embedding, "charge_spin_embedding." ) + # === Native per-atom spin embedding (optional) === + if self.spin_embedding is not None: + self.spin_embedding = load(self.spin_embedding, "spin_embedding.") # === Environment FiLM stack (optional) === if self.use_env_seed: self.env_seed_embedding = load( @@ -2300,6 +2519,9 @@ def load(module: Any, prefix: str) -> Any: self.final_block_attn_res = load( self.final_block_attn_res, "final_block_attn_res." ) + # === Read-out pre-layers (optional) === + for i, layer in enumerate(self.readout_pre_layers): + layer._load_variables(take_prefix(f"readout_pre_layers.{i}.")) # === Output FFN === self.output_ffn._load_variables(take_prefix("output_ffn.")) @@ -2355,6 +2577,7 @@ def serialize(self) -> dict[str, Any]: "message_node_s2": self.message_node_s2, "message_node_so3": self.message_node_so3, "so3_readout": self.so3_readout, + "readout_layers": self.readout_layers, "lebedev_quadrature": self.lebedev_quadrature, "activation_function": self.activation_function, "glu_activation": self.glu_activation, @@ -2368,6 +2591,7 @@ def serialize(self) -> dict[str, Any]: "inner_clamp_r_outer": self.inner_clamp_r_outer, "add_chg_spin_ebd": self.add_chg_spin_ebd, "default_chg_spin": self.default_chg_spin, + "use_spin": self.use_spin, }, "@variables": self._variables(), "env_mat": EnvMat(self.rcut, self.rcut, self.eps).serialize(), diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/__init__.py b/deepmd/dpmodel/descriptor/dpa4_nn/__init__.py index d2847b6965..54d1e3045c 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/__init__.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/__init__.py @@ -41,6 +41,7 @@ EnvironmentInitialEmbedding, GeometricInitialEmbedding, SeZMTypeEmbedding, + SpinEmbedding, ) from .ffn import ( EquivariantFFN, @@ -158,6 +159,7 @@ "ScalarRMSNorm", "SeZMInteractionBlock", "SeZMTypeEmbedding", + "SpinEmbedding", "SwiGLU", "WignerDCalculator", "apply_lora_to_sezm", diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/activation.py b/deepmd/dpmodel/descriptor/dpa4_nn/activation.py index 85f947b1fc..f232813885 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/activation.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/activation.py @@ -92,7 +92,8 @@ class GatedActivation(NativeOP): Whether to use bias in the gate linear layer. layout Tensor layout convention. ``"nfdc"`` means input shape (N, F, D, C); - ``"ndfc"`` means input shape (N, D, F, C). + ``"ndfc"`` means input shape (N, D, F, C); ``"fndc"`` means input shape + (F, N, D, C), the focus-major layout used by the SO(2) mixing stack. trainable Whether parameters are trainable. seed @@ -125,8 +126,8 @@ def __init__( self.precision = precision self.mlp_bias = bool(mlp_bias) self.layout = str(layout).lower() - if self.layout not in {"nfdc", "ndfc"}: - raise ValueError("`layout` must be either 'nfdc' or 'ndfc'") + if self.layout not in {"nfdc", "ndfc", "fndc"}: + raise ValueError("`layout` must be one of 'nfdc', 'ndfc', or 'fndc'") self.activation_function = str(activation_function) self.scalar_act = get_activation_fn(activation_function) @@ -170,7 +171,8 @@ def call(self, x: Any, gate: Any = None) -> Any: ---------- x Value features. Shape is (N, F, D, C) when ``layout='nfdc'``, - or (N, D, F, C) when ``layout='ndfc'``. + (N, D, F, C) when ``layout='ndfc'``, or (F, N, D, C) when + ``layout='fndc'``. gate Optional gate features with the same layout as ``x``. When provided, enables GLU mode: @@ -184,6 +186,10 @@ def call(self, x: Any, gate: Any = None) -> Any: Gated features with the same layout as ``x``. """ xp = array_api_compat.array_namespace(x) + # ``ndfc`` carries the degree axis at position 1; ``nfdc`` and the + # focus-major ``fndc`` carry it at position 2. Every select/narrow/reshape + # below is expressed against this single degree axis, so the three layouts + # share one code path apart from the per-focus gate projection. degree_axis = 1 if self.layout == "ndfc" else 2 scalar_idx = tuple( @@ -211,14 +217,17 @@ def call(self, x: Any, gate: Any = None) -> Any: return x0 input_dtype = gate_scalar_source.dtype - gating_scalars = xp.astype( - xp_sigmoid( - self.gate_linear( - xp.astype(gate_scalar_source, get_xp_precision(xp, self.precision)) - ) - ), - input_dtype, - ) + gate_src = xp.astype(gate_scalar_source, get_xp_precision(xp, self.precision)) + if self.layout == "fndc": + # The scalar source is focus-major (F, N, C). ``FocusLinear`` mixes + # channels with the focus stream on axis 1, so present it in the shared + # (N, F, C) convention and restore the focus-major orientation. + gate_logits = xp.permute_dims( + self.gate_linear(xp.permute_dims(gate_src, (1, 0, 2))), (1, 0, 2) + ) + else: + gate_logits = self.gate_linear(gate_src) + gating_scalars = xp.astype(xp_sigmoid(gate_logits), input_dtype) gating_scalars = xp.reshape( gating_scalars, (x.shape[0], gate_scalar_source.shape[1], self.lmax, self.channels), diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py b/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py index 9d16db2d14..238f904a8c 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py @@ -30,8 +30,10 @@ from deepmd.dpmodel.array_api import ( xp_add_at, xp_asarray_nodetach, + xp_scatter_sum, ) from deepmd.dpmodel.common import ( + get_xp_precision, to_numpy_array, ) from deepmd.dpmodel.utils.network import ( @@ -44,6 +46,9 @@ check_version_compatibility, ) +from .cartesian import ( + build_cartesian_basis, +) from .indexing import ( build_gie_zonal_index, get_so3_dim_of_lmax, @@ -220,6 +225,10 @@ def __init__( self.non_scalar_row_index = node_row_index self.zonal_m0_col_index_for_row = node_zonal_m0_col_index self.radial_slot_index_for_row = node_radial_l_index + # The l=1 coefficients (packed rows 1..3) are the first three entries of + # the non-scalar sequence ``node_row_index = [1, 2, ..., D-1]``, so the + # native neighbor-spin l=1 message folds in at these local positions. + self.l1_local_index = np.arange(3, dtype=np.int64) def call( self, @@ -228,6 +237,7 @@ def call( edge_cache: EdgeCache, radial_feat: Any, zonal_coupling: Any = None, + spin_l1_message: Any = None, ) -> Any: """ Parameters @@ -241,6 +251,12 @@ def call( zonal_coupling Optional precomputed zonal coupling with shape (E, D-1). If None, it is gathered from ``edge_cache.Dt_full``. + spin_l1_message + Optional per-edge neighbor-spin l=1 message with shape (E, 3, C) for + the native spin scheme (built by ``SpinEmbedding.edge_l1``). It is + added to the l=1 rows of the per-edge message, so it shares this + module's source gate, scatter and degree normalization with the + geometric message. Returns ------- @@ -286,6 +302,19 @@ def call( zonal_coupling[:, :, None] * radial_value_for_row ) # (E, D-1, C) + # === Step 3b. Fold in the neighbor-spin l=1 message (native spin) === + # The l=1 coefficients occupy the first three packed non-scalar rows, so + # the neighbor-spin message joins the geometric message there and then + # shares the source gate, scatter and degree normalization below. + if spin_l1_message is not None: + l1_local_index = xp_asarray_nodetach(xp, self.l1_local_index, device=device) + scatter_index = xp.broadcast_to( + xp.reshape(l1_local_index, (1, 3, 1)), spin_l1_message.shape + ) + non_scalar_message = xp_scatter_sum( + non_scalar_message, 1, scatter_index, spin_l1_message + ) + # === Step 4. Source Freeze Propagation Gate (optional) === # Mute messages emitted by nodes whose local neighborhood enters # the frozen zone. ``edge_src_gate`` is ``None`` outside bridging @@ -397,6 +426,12 @@ class EnvironmentInitialEmbedding(NativeOP): Activation function for G network hidden layer. eps : float Small epsilon for numerical stability. + use_spin : list[bool] | None + Per-type spin flags (native spin scheme). When provided, the neighbor + spin is appended as extra coordinate channels of the environment matrix, + so the inner product ``D = M^T M`` additionally yields the neighbor + spin-spin invariants. A per-type mask gates the channel, so a + non-magnetic neighbor contributes zero and carries zero magnetic force. precision : str Parameter precision. trainable : bool @@ -418,6 +453,7 @@ def __init__( mlp_bias: bool = False, activation_function: str = "silu", eps: float = 1e-7, + use_spin: list[bool] | None = None, precision: str = DEFAULT_PRECISION, trainable: bool = True, seed: int | list[int] | None = None, @@ -438,8 +474,16 @@ def __init__( self.mlp_bias = bool(mlp_bias) self.activation_function = str(activation_function) self.eps = float(eps) + self.spin_flags = None if use_spin is None else [bool(x) for x in use_spin] + if self.spin_flags is not None and len(self.spin_flags) != int(ntypes): + raise ValueError("`use_spin` length must equal `ntypes`") self.precision = precision self.trainable = bool(trainable) + # The environment matrix carries the 4 geometric channels ``[s, s*r_hat]`` + # plus, for the native spin scheme, the 3 envelope-gated neighbor-spin + # components, so the inner product ``D = M^T M`` yields the neighbor + # spin-spin invariants alongside the geometric ones. + self.coord_dim = 4 + (3 if self.spin_flags is not None else 0) # === RBF projection: n_radial -> rbf_out_dim (two-layer MLP) === # rbf_out_dim = max(32, embed_dim - 2*type_dim) to align G-network width to embed_dim @@ -517,12 +561,29 @@ def __init__( dtype=PRECISION_DICT[self.precision.lower()], ) + # === Native spin: per-type mask and isotropic channel scale === + # The mask gates the neighbor-spin channel by source type, so a + # non-magnetic neighbor contributes zero and (critically) carries zero + # magnetic force ``-dE/ds``. The single scalar scale (shared across + # x/y/z) keeps the spin coordinates transforming with the geometry, so + # the env-matrix invariant stays SO(3)-invariant; ``output_proj`` is + # zero-initialized, so the spin contribution starts neutral regardless. + if self.spin_flags is not None: + self.spin_mask = np.array( + [1.0 if flag else 0.0 for flag in self.spin_flags], + dtype=PRECISION_DICT[self.precision.lower()], + ) + self.spin_scale = np.ones( + (1,), dtype=PRECISION_DICT[self.precision.lower()] + ) + def call( self, *, edge_cache: EdgeCache, atype_flat: Any, n_nodes: int, + spin: Any = None, ) -> Any: """ Compute environment FiLM logits for l=0 conditioning. @@ -535,6 +596,12 @@ def call( Flattened atom types with shape (N,), where N = nf * nloc. n_nodes : int Number of nodes (N = nf * nloc). + spin : Array | None + Per-node spin vectors with shape (N, 3) for the native spin scheme. + Used only when ``use_spin`` is set; the source (neighbor) spin is + appended to the environment matrix as an envelope-gated coordinate + channel. When ``None`` the spin channels are zero-padded so the + coordinate dimension stays fixed. Returns ------- @@ -556,6 +623,37 @@ def call( r_hat = edge_vec * inv_r # (E, 3) r_tilde = xp.concat([s, s * r_hat], axis=-1) # (E, 4) + # === Step 1b. Append neighbor spin as extra coordinate channels === + # The source (neighbor) spin enters the environment matrix gated by the + # same C^3 envelope as the geometry, so it decays smoothly at rcut and a + # non-magnetic neighbor (s_j = 0) contributes exactly zero. The linear + # form keeps the magnetic force continuous at s = 0. + if self.spin_flags is not None: + device = array_api_compat.device(edge_vec) + if spin is not None: + src_i = xp.astype(src, xp.int64) + spin_src = xp.astype( + xp.take(spin, src_i, axis=0), r_tilde.dtype + ) # (E, 3) + # Gate by source type: a non-magnetic neighbor must not enter + # the energy, so its magnetic force ``-dE/ds`` stays exactly zero. + spin_mask = xp_asarray_nodetach(xp, self.spin_mask[...], device=device) + mask = xp.take( + spin_mask, + xp.take(xp.astype(atype_flat, xp.int64), src_i, axis=0), + axis=0, + )[:, None] # (E, 1) + spin_scale = xp.astype( + xp_asarray_nodetach(xp, self.spin_scale[...], device=device), + r_tilde.dtype, + ) + spin_chan = edge_env * spin_scale * spin_src * mask # (E, 3) + else: + spin_chan = xp.zeros( + (r_tilde.shape[0], 3), dtype=r_tilde.dtype, device=device + ) + r_tilde = xp.concat([r_tilde, spin_chan], axis=-1) # (E, coord_dim) + # === Step 2. Compute G network input and output === # Use independent type embeddings (decoupled from main type embedding) atype_src = xp.take(atype_flat, xp.astype(src, xp.int64), axis=0) # (E,) @@ -574,8 +672,10 @@ def call( # === Step 3. Aggregate outer product by destination node === # outer = r_tilde[:, :, None] * g[:, None, :], einsum "ei,ej->eij". - outer = r_tilde[:, :, None] * g[:, None, :] # (E, 4, embed_dim) - outer_flat = xp.reshape(outer, (n_edge, 4 * self.embed_dim)) # (E, 4*embed_dim) + outer = r_tilde[:, :, None] * g[:, None, :] # (E, coord_dim, embed_dim) + outer_flat = xp.reshape( + outer, (n_edge, self.coord_dim * self.embed_dim) + ) # (E, coord_dim*embed_dim) # Source Freeze Propagation Gate: mute the outer-product contribution # of any edge whose source node has a neighbor in the frozen zone. src_gate = edge_cache.edge_src_gate @@ -595,14 +695,16 @@ def call( ) env_agg = xp_add_at( xp.zeros( - (n_nodes, 4 * self.embed_dim), + (n_nodes, self.coord_dim * self.embed_dim), dtype=outer_flat.dtype, device=array_api_compat.device(outer_flat), ), dst, outer_flat, - ) # (N, 4*embed_dim) - env_agg = xp.reshape(env_agg, (n_nodes, 4, self.embed_dim)) # (N, 4, embed_dim) + ) # (N, coord_dim*embed_dim) + env_agg = xp.reshape( + env_agg, (n_nodes, self.coord_dim, self.embed_dim) + ) # (N, coord_dim, embed_dim) # === Step 4. Smooth normalization by envelope-squared degree === # Reuse the cache's inverse-sqrt degree so the version-aware @@ -610,8 +712,11 @@ def call( env_agg = env_agg * xp.astype(edge_cache.inv_sqrt_deg, env_agg.dtype) # === Step 5. D matrix construction: D = env_agg^T @ env_agg[:,:,:axis_dim] === - env_agg_t = xp.permute_dims(env_agg, (0, 2, 1)) # (N, embed_dim, 4) - env_agg_axis = env_agg[:, :, : self.axis_dim] # (N, 4, axis_dim) + # Summing over the coordinate axis makes D invariant to a joint rotation + # of the geometry and the spin channels; with the spin channels present, + # D additionally carries the neighbor spin-spin invariants. + env_agg_t = xp.permute_dims(env_agg, (0, 2, 1)) # (N, embed_dim, coord_dim) + env_agg_axis = env_agg[:, :, : self.axis_dim] # (N, coord_dim, axis_dim) D = xp.matmul(env_agg_t, env_agg_axis) # (N, embed_dim, axis_dim) # === Step 6. Output projection for FiLM logits === @@ -637,6 +742,8 @@ def _variables(self) -> dict[str, np.ndarray]: variables["rbf_proj_layer2.bias"] = to_numpy_array(self.rbf_proj_layer2.b) variables["g_layer1.bias"] = to_numpy_array(self.g_layer1.b) variables["g_layer2.bias"] = to_numpy_array(self.g_layer2.b) + if self.spin_flags is not None: + variables["spin_scale"] = to_numpy_array(self.spin_scale) return variables def _load_variables(self, variables: dict[str, Any]) -> None: @@ -663,6 +770,8 @@ def _load_variables(self, variables: dict[str, Any]) -> None: ) self.g_layer1.b = np.asarray(variables["g_layer1.bias"], dtype=prec) self.g_layer2.b = np.asarray(variables["g_layer2.bias"], dtype=prec) + if self.spin_flags is not None: + self.spin_scale = np.asarray(variables["spin_scale"], dtype=prec) def serialize(self) -> dict[str, Any]: return { @@ -679,6 +788,7 @@ def serialize(self) -> dict[str, Any]: "mlp_bias": self.mlp_bias, "activation_function": self.activation_function, "eps": self.eps, + "use_spin": self.spin_flags, "precision": np.dtype(PRECISION_DICT[self.precision]).name, "trainable": self.trainable, "seed": None, @@ -836,3 +946,297 @@ def deserialize(cls, data: dict[str, Any]) -> ChargeSpinEmbedding: obj = cls(**config) obj._load_variables(variables) return obj + + +class SpinEmbedding(NativeOP): + """ + Per-atom spin embedding for the native spin scheme. + + The per-atom spin vector ``s`` is injected as an equivariant extension of + the type embedding, producing two additive contributions to the descriptor + node features: + + - **l = 0 (invariant):** a small network of the squared magnitude ``|s|^2`` + yields a per-channel scalar added to the scalar type embedding. The + squared magnitude is used (rather than ``|s|``) so the feature is smooth + at ``s = 0`` and its gradient there vanishes, keeping the magnetic force + continuous as a spin crosses zero. + - **l = 1 (equivariant):** the Cartesian spin vector is mapped to the packed + ``l = 1`` coefficients through the SeZM Wigner-D convention (derived from + :func:`build_cartesian_basis`), then scaled by a per-type per-channel + weight. The map is linear in ``s``, so the contribution vanishes at + ``s = 0`` and rotates as an ``l = 1`` object under SO(3), i.e. + ``cart_to_l1(R s) = D^1(R) cart_to_l1(s)``. + + Both contributions are gated by a per-type spin mask, so atom types without + spin contribute exactly zero regardless of their (nominally zero) input. + + Parameters + ---------- + ntypes + Number of (real) atom types. + channels + Number of channels per (l, m) coefficient. + use_spin + Per-type boolean flags marking which atom types carry spin. + activation_function + Activation used by the magnitude network. + precision + Parameter precision. + seed + Random seed for initialization. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + ntypes: int, + channels: int, + use_spin: list[bool], + activation_function: str = "silu", + precision: str = DEFAULT_PRECISION, + seed: int | list[int] | None = None, + trainable: bool = True, + ) -> None: + self.ntypes = int(ntypes) + self.channels = int(channels) + self.activation_function = str(activation_function) + self.precision = precision + self.trainable = bool(trainable) + if self.ntypes <= 0: + raise ValueError("`ntypes` must be positive") + if self.channels <= 0: + raise ValueError("`channels` must be positive") + if len(use_spin) != self.ntypes: + raise ValueError("`use_spin` length must equal `ntypes`") + prec = PRECISION_DICT[self.precision.lower()] + self.spin_flags = [bool(flag) for flag in use_spin] + + # === Per-type spin gate === + # Non-persistent: rebuilt from config on construction and moved with the + # module, so the deterministic mask never enters the serialized state. + self.spin_mask = np.array( + [1.0 if bool(flag) else 0.0 for flag in use_spin], dtype=prec + ) + + # === Cartesian -> packed l=1 projection === + # Derived from the SeZM packed basis so a spin vector rotates with the + # same Wigner-D block as the geometry. Non-persistent constant. + self.cart_to_l1 = self._build_cart_to_l1_matrix() + + # === l=0 magnitude network: |s|^2 -> channels === + # The leading ``1 -> channels`` layer carries a singleton input + # dimension that HybridMuon routes to its Adam path automatically. + seed_scalar = child_seed(seed, 0) + self.mag_layer1 = NativeLayer( + 1, + self.channels, + bias=False, + activation_function=self.activation_function, + precision=self.precision, + seed=child_seed(seed_scalar, 0), + trainable=self.trainable, + ) + self.mag_layer2 = NativeLayer( + self.channels, + self.channels, + bias=False, + activation_function=None, + precision=self.precision, + seed=child_seed(seed_scalar, 1), + trainable=self.trainable, + ) + + # === l=1 per-type per-channel weight === + # ``adam_`` prefix routes the table to Adam in HybridMuon, matching the + # type-embedding treatment for per-type lookup parameters. + init_std = 1.0 / math.sqrt(float(self.ntypes + self.channels)) + rng_vec = np.random.default_rng(child_seed(seed, 1)) + self.adam_spin_vec_weight = rng_vec.normal( + 0.0, init_std, size=(self.ntypes, self.channels) + ).astype(prec) + + # === l=1 per-source-type per-channel weight for neighbor aggregation === + # Separate from the on-site weight: this scales the neighbor's spin + # direction before it is aggregated into the center node's l=1 seed. + rng_nbr = np.random.default_rng(child_seed(seed, 2)) + self.adam_spin_nbr_weight = rng_nbr.normal( + 0.0, init_std, size=(self.ntypes, self.channels) + ).astype(prec) + + def call(self, spin: Any, atype: Any) -> tuple[Any, Any]: + """ + Compute the l=0 and l=1 spin contributions. + + Parameters + ---------- + spin + Per-atom spin vectors with shape (N, 3). + atype + Per-atom types with shape (N,). + + Returns + ------- + tuple[Array, Array] + ``(scalar, vector)`` where ``scalar`` has shape (N, channels) for + the l=0 contribution and ``vector`` has shape (N, 3, channels) for + the packed l=1 contribution (orders m = -1, 0, +1). Both are exactly + zero for atom types without spin. + """ + xp = array_api_compat.array_namespace(spin) + device = array_api_compat.device(spin) + dtype = get_xp_precision(xp, self.precision) + spin = xp.astype(spin, dtype) + index = xp.astype(atype, xp.int64) + spin_mask = xp_asarray_nodetach(xp, self.spin_mask[...], device=device) + mask = xp.take(spin_mask, index, axis=0)[:, None] # (N, 1) + + # === l=0: smooth invariant magnitude embedding === + mag2 = xp.sum(spin * spin, axis=-1, keepdims=True) # (N, 1) + scalar = self.mag_layer2(self.mag_layer1(mag2)) * mask # (N, C) + + # === l=1: equivariant direction embedding (linear in spin) === + cart_to_l1 = xp.astype( + xp_asarray_nodetach(xp, self.cart_to_l1[...], device=device), dtype + ) + # einsum "dk,nk->nd" as a matmul against the transposed projection. + l1 = xp.matmul(spin, xp.permute_dims(cart_to_l1, (1, 0))) # (N, 3) + weight_table = xp_asarray_nodetach( + xp, self.adam_spin_vec_weight[...], device=device + ) + weight = xp.take(weight_table, index, axis=0) # (N, C) + vector = l1[:, :, None] * weight[:, None, :] * mask[:, :, None] # (N, 3, C) + + return scalar, vector + + def edge_l1( + self, + spin: Any, + atype: Any, + edge_cache: EdgeCache, + ) -> Any: + """ + Build the per-edge neighbor-spin l=1 message for the GIE aggregation. + + Each edge carries the packed ``l = 1`` coefficients of the source + (neighbor) spin, scaled by a per-source-type per-channel weight and + gated by the C^3 envelope. The message is returned per edge; the + geometric initial embedding folds it into the l=1 rows and applies the + shared source gate, scatter and degree normalization, so a neighbor's + spin direction enters an atom's l=1 backbone before any interaction + block (the spin analogue of the geometric initial embedding). + + Parameters + ---------- + spin + Per-node spin vectors with shape (N, 3). + atype + Per-node types with shape (N,). + edge_cache + Edge cache providing ``src`` and ``edge_env``. + + Returns + ------- + Array + Per-edge packed l=1 message with shape (E, 3, channels), exactly + zero for non-magnetic neighbors. + """ + xp = array_api_compat.array_namespace(spin) + device = array_api_compat.device(spin) + dtype = get_xp_precision(xp, self.precision) + spin = xp.astype(spin, dtype) + src = xp.astype(edge_cache.src, xp.int64) + spin_src = xp.take(spin, src, axis=0) # (E, 3) + atype_src = xp.take(xp.astype(atype, xp.int64), src, axis=0) # (E,) + + # Packed l=1 of the neighbor spin; the global-frame vector needs no + # Wigner-D rotation (it rotates with the geometry by construction). + cart_to_l1 = xp.astype( + xp_asarray_nodetach(xp, self.cart_to_l1[...], device=device), dtype + ) + # einsum "dk,ek->ed" as a matmul against the transposed projection. + l1 = xp.matmul(spin_src, xp.permute_dims(cart_to_l1, (1, 0))) # (E, 3) + weight_table = xp_asarray_nodetach( + xp, self.adam_spin_nbr_weight[...], device=device + ) + weight = xp.take(weight_table, atype_src, axis=0) # (E, C) + spin_mask = xp_asarray_nodetach(xp, self.spin_mask[...], device=device) + mask = xp.take(spin_mask, atype_src, axis=0) # (E,) + gate = edge_cache.edge_env * mask[:, None] # (E, 1) + return gate[:, :, None] * l1[:, :, None] * weight[:, None, :] # (E, 3, C) + + def _build_cart_to_l1_matrix(self) -> np.ndarray: + """ + Build the ``(3, 3)`` Cartesian-to-packed-``l=1`` projection. + + The packed ``l = 1`` coefficient of a vector ``v`` is obtained by + projecting the skew-symmetric matrix ``[v]_x`` onto the antisymmetric + ``l = 1`` block of :func:`build_cartesian_basis`. With packed order + ``m = -1, 0, +1``, row ``d`` and Cartesian component ``k`` give + ``M[d, k] = <[e_k]_x, B[1 + d]>_F``, so ``coeff = M @ v`` and + ``M @ (R v) = D^1(R) (M @ v)``. + """ + prec = PRECISION_DICT[self.precision.lower()] + basis_l1 = build_cartesian_basis(1, dtype=prec)[1:4] + # Skew (cross-product) matrices of the Cartesian unit vectors, following + # ``[v]_x w = v x w`` (matching ``build_edge_cartesian_tensors``). + skew_basis = np.zeros((3, 3, 3), dtype=prec) + skew_basis[0, 1, 2], skew_basis[0, 2, 1] = -1.0, 1.0 + skew_basis[1, 0, 2], skew_basis[1, 2, 0] = 1.0, -1.0 + skew_basis[2, 0, 1], skew_basis[2, 1, 0] = -1.0, 1.0 + return np.einsum("kij,dij->dk", skew_basis, basis_l1) + + def _variables(self) -> dict[str, np.ndarray]: + """Variables keyed by the pt ``state_dict`` key names.""" + return { + "mag_layer1.matrix": to_numpy_array(self.mag_layer1.w), + "mag_layer2.matrix": to_numpy_array(self.mag_layer2.w), + "adam_spin_vec_weight": to_numpy_array(self.adam_spin_vec_weight), + "adam_spin_nbr_weight": to_numpy_array(self.adam_spin_nbr_weight), + } + + def _load_variables(self, variables: dict[str, Any]) -> None: + """Load variables keyed by the pt ``state_dict`` key names.""" + prec = PRECISION_DICT[self.precision.lower()] + self.mag_layer1.w = np.asarray(variables["mag_layer1.matrix"], dtype=prec) + self.mag_layer2.w = np.asarray(variables["mag_layer2.matrix"], dtype=prec) + self.adam_spin_vec_weight = np.asarray( + variables["adam_spin_vec_weight"], dtype=prec + ) + self.adam_spin_nbr_weight = np.asarray( + variables["adam_spin_nbr_weight"], dtype=prec + ) + + def serialize(self) -> dict[str, Any]: + """Serialize the SpinEmbedding to a dict.""" + return { + "@class": "SpinEmbedding", + "@version": 1, + "config": { + "ntypes": self.ntypes, + "channels": self.channels, + "use_spin": self.spin_flags, + "activation_function": self.activation_function, + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "trainable": self.trainable, + "seed": None, + }, + "@variables": self._variables(), + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SpinEmbedding: + """Deserialize a SpinEmbedding from a dict.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SpinEmbedding": + raise ValueError(f"Invalid class for SpinEmbedding: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + obj = cls(**config) + obj._load_variables(variables) + return obj diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index 87338fa744..4c63d1a492 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -75,7 +75,7 @@ Callable, ) -GridNetLayout = Literal["ndfc", "nfdc", "flat"] +GridNetLayout = Literal["ndfc", "nfdc", "fndc", "flat"] GridNetMode = Literal["self", "cross"] GridNetOp = Literal["glu", "mlp", "branch"] @@ -655,8 +655,10 @@ def __init__( raise ValueError("`op_type` must be one of 'glu', 'mlp', or 'branch'") self.precision = precision self.layout = str(layout).lower() - if self.layout not in {"ndfc", "nfdc", "flat"}: - raise ValueError("`layout` must be one of 'ndfc', 'nfdc', or 'flat'") + if self.layout not in {"ndfc", "nfdc", "fndc", "flat"}: + raise ValueError( + "`layout` must be one of 'ndfc', 'nfdc', 'fndc', or 'flat'" + ) if self.mode == "self" and self.layout == "flat": raise ValueError("`layout='flat'` is only supported for cross grid nets") self.mlp_bias = bool(mlp_bias) @@ -927,11 +929,17 @@ def _from_grid(self, grid: Any) -> Any: ) def _to_ndfc(self, value: Any) -> tuple[Any, tuple[int, ...]]: + # All grid operations run in the canonical ``(N, D, F, C)`` layout; the + # ``fndc`` re-orientation folds the focus-major SO(2) mixing layout into the + # same transpose the ``nfdc`` path performs, so the grid compute below is + # identical regardless of the caller's layout. xp = array_api_compat.array_namespace(value) if self.layout == "ndfc": return value, tuple(value.shape) if self.layout == "nfdc": return xp.permute_dims(value, (0, 2, 1, 3)), tuple(value.shape) + if self.layout == "fndc": + return xp.permute_dims(value, (1, 2, 0, 3)), tuple(value.shape) n_batch, coeff_dim, _ = value.shape return ( xp.reshape(value, (n_batch, coeff_dim, self.n_focus, -1)), @@ -948,6 +956,8 @@ def _restore_layout( return value if self.layout == "nfdc": return xp.permute_dims(value, (0, 2, 1, 3)) + if self.layout == "fndc": + return xp.permute_dims(value, (2, 0, 1, 3)) n_batch, coeff_dim, _ = shape_info return xp.reshape(value, (n_batch, coeff_dim, -1)) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/norm.py b/deepmd/dpmodel/descriptor/dpa4_nn/norm.py index 2424ac1e1b..f436f8c3dc 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/norm.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/norm.py @@ -390,20 +390,20 @@ def call(self, x: Any) -> Any: Parameters ---------- x - Input array with shape (E, F, D_m_trunc, C). + Input array with shape (F, E, D_m_trunc, C). Returns ------- Array - Normalized array with shape `(E, F, D_m_trunc, C)`, same dtype as + Normalized array with shape `(F, E, D_m_trunc, C)`, same dtype as input. """ xp = array_api_compat.array_namespace(x) device = array_api_compat.device(x) in_dtype = x.dtype x = xp.astype(x, get_xp_precision(xp, self.precision)) - x0 = x[:, :, :1, :] # (E, F, 1, C) - xt = x[:, :, 1:, :] # (E, F, D_m_trunc-1, C) + x0 = x[:, :, :1, :] # (F, E, 1, C) + xt = x[:, :, 1:, :] # (F, E, D_m_trunc-1, C) # === Step 1. Center the scalar slice === x0 = x0 - xp.mean(x0, axis=-1, keepdims=True) @@ -416,7 +416,7 @@ def call(self, x: Any) -> Any: (xt * xt) * balance_weight[1:][None, None, :, None], axis=(2, 3) ) inv_rms = 1.0 / xp.sqrt(mean_variance + self.eps) - inv_rms = inv_rms[:, :, None, None] # (E, F, 1, 1) + inv_rms = inv_rms[:, :, None, None] # (F, E, 1, 1) x0 = x0 * inv_rms if self.degree_index_m.size > 1: @@ -426,7 +426,7 @@ def call(self, x: Any) -> Any: adam_scale = xp_asarray_nodetach(xp, self.adam_scale[...], device=device) degree_index_m = xp_asarray_nodetach(xp, self.degree_index_m, device=device) expanded_scale = xp.take(adam_scale, degree_index_m, axis=1) - expanded_scale = expanded_scale[None, ...] # (1, F, D_m_trunc, C) + expanded_scale = expanded_scale[:, None, ...] # (F, 1, D_m_trunc, C) x0 = x0 * expanded_scale[:, :, :1, :] if self.degree_index_m.size > 1: xt = xt * expanded_scale[:, :, 1:, :] @@ -434,8 +434,8 @@ def call(self, x: Any) -> Any: # === Step 4. Add scalar bias and restore layout === bias0 = xp.reshape( xp_asarray_nodetach(xp, self.bias0[...], device=device), - (1, self.n_focus, 1, -1), - ) # (1, F, 1, C) + (self.n_focus, 1, 1, -1), + ) # (F, 1, 1, C) x0 = x0 + bias0 out = x0 if self.degree_index_m.size == 1 else xp.concat([x0, xt], axis=2) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py index 6173d77e45..c0c4c0b262 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py @@ -134,8 +134,12 @@ class SO2Linear(NativeOP): rotates the output by the same angle. The weight is assembled once per forward (training) or cached (eval) - by ``_build_so2_weight()``, then applied via a single batched matmul - over all focus streams: ``einsum("efi,foi->efo")``. + by ``_build_so2_weight()`` in the ``(D_m*Cin, F, D_m*Cout)`` layout, then + applied as a per-``|m|``-block batched matmul over the focus streams. The + activation is carried in the focus-major layout ``(F, E, D_m, Cf)`` so that + the focus stream is the batch axis of the matmul: the assembled weight is + presented as ``(F, D_m*Cin, D_m*Cout)`` (a transient view, never a stored + parameter) and each block contracts with no transpose of the edge axis. Parameters ---------- @@ -306,21 +310,22 @@ def call(self, x: Array) -> Array: Parameters ---------- x - Input with shape (E, F, D_m_trunc, Cin), where D_m_trunc is the + Input with shape (F, E, D_m_trunc, Cin), where F is the focus stream + (the matmul batch axis), E the edge count, and D_m_trunc the coefficient dimension of the m-major layout truncated by `mmax`. Returns ------- Array - Output with shape (E, F, D_m_trunc, Cout), where Cout is output channels. + Output with shape (F, E, D_m_trunc, Cout), where Cout is output channels. """ xp = array_api_compat.array_namespace(x) device = array_api_compat.device(x) - # === Step 1. Flatten coefficient + channel axes for matmul === - # (E, F, D_m, Cin) -> (E, F, D_m*Cin) - n_edge = x.shape[0] + # === Step 1. Flatten coefficient + channel axes for the matmul === + # (F, E, D_m, Cin) -> (F, E, D_m*Cin); the focus stream stays the batch axis. + n_focus, n_edge = x.shape[0], x.shape[1] in_dim_total = self.reduced_dim * self.in_channels - x_flat = xp.reshape(x, (n_edge, self.n_focus, in_dim_total)) + x_flat = xp.reshape(x, (n_focus, n_edge, in_dim_total)) # === Step 2. Get block-diagonal weight === weight = self._build_so2_weight(xp, device) @@ -328,7 +333,7 @@ def call(self, x: Array) -> Array: # === Step 3. Block-diagonal matmul over focus streams + reshape back === out_flat = self._block_diagonal_matmul(x_flat, weight) out = xp.reshape( - out_flat, (n_edge, self.n_focus, self.reduced_dim, self.out_channels) + out_flat, (n_focus, n_edge, self.reduced_dim, self.out_channels) ) # === Step 4. Bias on l=0 scalar index === @@ -338,7 +343,7 @@ def call(self, x: Array) -> Array: (self.n_focus, self.out_channels), ) out = xp.concat( - [out[:, :, :1, :] + bias0[None, :, None, :], out[:, :, 1:, :]], axis=2 + [out[:, :, :1, :] + bias0[:, None, None, :], out[:, :, 1:, :]], axis=2 ) return out @@ -362,7 +367,7 @@ def _build_block_diag_slices(self) -> list[tuple[int, int, int, int]]: out_off += out_width return slices - def _build_so2_weight(self, xp: Any, device: Any) -> Array: + def _build_so2_weight(self, xp: Any = None, device: Any = None) -> Array: """ Assemble the per-focus block-diagonal SO(2) weight matrix. @@ -370,11 +375,26 @@ def _build_so2_weight(self, xp: Any, device: Any) -> Array: where both axes follow the same m-major coefficient ordering. Off-diagonal blocks (cross-|m|) are zero, enforcing SO(2) equivariance. + Parameters + ---------- + xp : Any, optional + The array namespace. Derived from ``self.weight_m0`` when either + ``xp`` or ``device`` is ``None``. + device : Any, optional + The target device. Derived from ``self.weight_m0`` when either + ``xp`` or ``device`` is ``None``. + Returns ------- Array Weight with shape (D_m*Cin, F, D_m*Cout). """ + # The ``call`` hot path passes both explicitly; the kernel value-path + # factories invoke ``_build_so2_weight()`` with no args, so fall back to + # the stored weight for the namespace/device (numpy or torch buffer). + if xp is None or device is None: + xp = array_api_compat.array_namespace(self.weight_m0) + device = array_api_compat.device(self.weight_m0) in_total = self.reduced_dim * self.in_channels out_total = self.reduced_dim * self.out_channels num_in_m0 = (self.lmax + 1) * self.in_channels @@ -435,32 +455,38 @@ def _block_diagonal_matmul(self, x_flat: Array, weight: Array) -> Array: ``weight`` is block-diagonal over ``|m|`` (cross-``|m|`` blocks are exactly zero), so concatenating the per-group matmuls reproduces the - dense ``einsum`` over the full ``(D_m*Cin, D_m*Cout)`` matrix while + dense contraction over the full ``(D_m*Cin, D_m*Cout)`` matrix while skipping the structural zeros. The result is fp32-equivalent to the dense path up to the matmul reduction order. + The focus stream is the batch axis of the per-block ``bmm``: the input + already carries it as ``(F, E, .)`` and the assembled weight is presented + as ``(F, D_m*Cin, D_m*Cout)``, so no edge-axis transpose is needed on + either operand and each block writes directly into the concatenated + output. The weight view is transient (the stored parameters keep their + assembled ``(D_m*Cin, F, D_m*Cout)`` layout). + Parameters ---------- x_flat : Array - Flattened input with shape ``(E, F, D_m*Cin)``. + Flattened input with shape ``(F, E, D_m*Cin)``. weight : Array Assembled block-diagonal weight with shape ``(D_m*Cin, F, D_m*Cout)``. Returns ------- Array - Flattened output with shape ``(E, F, D_m*Cout)``. + Flattened output with shape ``(F, E, D_m*Cout)``. """ xp = array_api_compat.array_namespace(x_flat) + weight = xp.permute_dims(weight, (1, 0, 2)) # (F, D_m*Cin, D_m*Cout) blocks = [ - # einsum("efi,ifo->efo"): a per-focus matmul batched over the focus - # axis, contracting the input coefficient/channel index i. - xp.permute_dims( - xp.matmul( - xp.permute_dims(x_flat[:, :, in0:in1], (1, 0, 2)), - xp.permute_dims(weight[in0:in1, :, out0:out1], (1, 0, 2)), - ), - (1, 0, 2), + # torch.bmm(x_flat[:, :, in0:in1], weight[:, in0:in1, out0:out1]): a + # per-focus matmul batched over the focus axis (the leading batch axis + # of xp.matmul), contracting the input coefficient/channel index i. + xp.matmul( + x_flat[:, :, in0:in1], + weight[:, in0:in1, out0:out1], ) for in0, in1, out0, out1 in self._block_diag_slices ] @@ -1508,6 +1534,48 @@ def __init__( or self.radial_so2_mode != "none" or self.node_wise_grid_product is not None ) + + # === Step 12. Optional fused flash-attention aggregation seam === + # The fused path folds the entire ``n_atten_head > 0`` value aggregation -- + # block-diagonal rotate-back, inverse-rotation rescale, envelope-gated + # softmax weighting, and the destination scatter -- into a single + # destination-segmented kernel, removing the transient ``x_message`` and + # weighted-value edge tensors and the scatter-add round trip. The pure + # array-API reference has no such kernel, so it never runs the fused flash + # path: ``use_flash_atten`` is fixed to ``False`` and the kernel/row-ptr + # hooks stay ``None``. The ``pt_expt`` backend recomputes + # ``use_flash_atten`` (Triton availability AND the supported attention + # layout) and binds ``_flash_atten_fn`` / ``_build_row_ptr_fn`` plus a + # fused ``_flash_aggregate`` override. + self.use_flash_atten = False + self._flash_atten_fn = None + self._build_row_ptr_fn = None + # Layout-support half of pt's ``use_flash_atten`` predicate -- everything + # except the ``use_triton_infer`` gate. The fused kernel only engages for + # the ``mmax == 1`` attention layout without the optional focus-mix / + # value / output projections (the deployed DPA4 configuration). Stored so + # ``pt_expt`` can re-enable flash by ANDing this with its own + # Triton-availability check, without duplicating the long predicate. + self._flash_atten_layout_ok = ( + self.n_atten_head > 0 + and self.mmax == 1 + and self.needs_local_frame + and not self.edge_cartesian + and not self.atten_f_mix + and self.attn_v_proj is None + and self.attn_o_proj is None + and self.attn_focus_mix is None + ) + + # === Step 13. Optional fused SO(2) value-path seam === + # The fused value path folds the rotate-to-local projection, radial + # mixing, and the full SO(2) mixing stack into a single kernel, emitting + # the pre-rotate-back per-focus local features directly. The pure + # array-API reference has no such kernel, so it never runs the fused + # value path: ``_value_path`` stays ``None`` and ``so2_message`` takes the + # dense branch. The ``pt_expt`` backend binds ``make_triton_value_path`` / + # ``make_cute_value_path`` here. + self._value_path = None self.trainable = bool(trainable) def call( @@ -1545,7 +1613,17 @@ def call( # === Step 2. Edge message: Cartesian product, SO(2) mixing, or the # rotation-free radial message when no local-frame operation is needed === - if self.edge_cartesian: + # In the fused flash-attention path the SO(2) message returns the + # pre-rotate-back per-focus local features; the rotate-back is folded into + # the aggregation kernel (Step 4). + run_flash = self.use_flash_atten and not self.training + x_local_flash: Array | None = None + x_message: Array | None = None + if run_flash: + x_local_flash, rad_feat = self.so2_message( + x, edge_cache, radial_feat, return_local=True + ) + elif self.edge_cartesian: x_message, rad_feat = self.cartesian_message(x, edge_cache, radial_feat) elif self.needs_local_frame: x_message, rad_feat = self.so2_message(x, edge_cache, radial_feat) @@ -1650,91 +1728,116 @@ def call( edge_mask=edge_cache.edge_mask, ) # (E, F, H) - # === Step 4.3. Value projection and head-wise aggregation === - value_focus = xp.astype( - xp.reshape( - x_message, + if run_flash: + # === Step 4.3f. Fused rotate-back + envelope-softmax-weighted + # segment scatter. One destination-segmented kernel folds the + # block-diagonal rotate-back, the inverse-rotation rescale, the + # per-edge ``attn_alpha`` weighting, and the destination reduction + # into a single atomic-free pass, returning the ungated aggregate + # ``(N, D, C_wide)``. The transient rotate-back message and + # weighted value tensors are never materialized. + # === Step 4.4f. Output-side head gate (cheap node-level) === + # The pure array-API reference has no fused kernel; dpmodel folds + # both Step 4.3f and Step 4.4f into the overridable + # ``_flash_aggregate`` seam (default raises ``NotImplementedError``; + # ``pt_expt`` overrides it with the fused Triton kernel). This + # branch is never entered here because ``use_flash_atten`` is + # always ``False`` in the dpmodel reference. + out = self._flash_aggregate( + x_local_flash, + edge_cache, + attn_alpha, + x_l0_node, + n_node, + compute_dtype, + ) # (N, D, C_wide) + else: + # === Step 4.3. Value projection and head-wise aggregation === + value_focus = xp.astype( + xp.reshape( + x_message, + ( + n_edge, + self.ebed_dim_full, + self.attn_n_focus, + self.attn_focus_dim, + ), + ), + compute_dtype, + ) # (E, D, Fa, Ca) + if self.attn_v_proj is not None: + value_focus = self.attn_v_proj(value_focus) + value_heads = xp.reshape( + value_focus, ( n_edge, self.ebed_dim_full, self.attn_n_focus, - self.attn_focus_dim, + self.n_atten_head, + self.head_dim, ), - ), - compute_dtype, - ) # (E, D, Fa, Ca) - if self.attn_v_proj is not None: - value_focus = self.attn_v_proj(value_focus) - value_heads = xp.reshape( - value_focus, - ( - n_edge, - self.ebed_dim_full, - self.attn_n_focus, - self.n_atten_head, - self.head_dim, - ), - ) # (E, D, Fa, H, Ch) - weighted_value = value_heads * xp.reshape( - attn_alpha, (n_edge, 1, self.attn_n_focus, self.n_atten_head, 1) - ) - out_heads = xp.zeros( - ( - n_node, - self.ebed_dim_full, - self.attn_n_focus, - self.n_atten_head, - self.head_dim, - ), - dtype=compute_dtype, - device=device, - ) # (N, D, Fa, H, Ch) - out_heads = xp_add_at(out_heads, dst, weighted_value) - - # === Step 4.4. Output-side head gate === - # "nfi,ifo->nfo": per-focus contraction over the input channel, - # expressed as a batched matmul over the focus axis. - attn_output_gate = xp_sigmoid( - xp.permute_dims( - xp.matmul( - xp.permute_dims( - self.attn_output_gate_norm( - xp.astype(x_l0_node, compute_dtype) + ) # (E, D, Fa, H, Ch) + weighted_value = value_heads * xp.reshape( + attn_alpha, (n_edge, 1, self.attn_n_focus, self.n_atten_head, 1) + ) + out_heads = xp.zeros( + ( + n_node, + self.ebed_dim_full, + self.attn_n_focus, + self.n_atten_head, + self.head_dim, + ), + dtype=compute_dtype, + device=device, + ) # (N, D, Fa, H, Ch) + out_heads = xp_add_at(out_heads, dst, weighted_value) + + # === Step 4.4. Output-side head gate === + # "nfi,ifo->nfo": per-focus contraction over the input channel, + # expressed as a batched matmul over the focus axis. + attn_output_gate = xp_sigmoid( + xp.permute_dims( + xp.matmul( + xp.permute_dims( + self.attn_output_gate_norm( + xp.astype(x_l0_node, compute_dtype) + ), + (1, 0, 2), ), - (1, 0, 2), - ), - xp.permute_dims( - xp_asarray_nodetach( - xp, self.adamw_attn_gate_w[...], device=device + xp.permute_dims( + xp_asarray_nodetach( + xp, self.adamw_attn_gate_w[...], device=device + ), + (1, 0, 2), ), - (1, 0, 2), ), + (1, 0, 2), + ) + ) # (N, F, H) + out_heads = out_heads * xp.reshape( + attn_output_gate, + (n_node, 1, self.attn_n_focus, self.n_atten_head, 1), + ) # (N, D, Fa, H, Ch) + + # === Step 4.5. Output projection and merge heads === + out_focus = xp.reshape( + out_heads, + ( + n_node, + self.ebed_dim_full, + self.attn_n_focus, + self.attn_focus_dim, ), - (1, 0, 2), - ) - ) # (N, F, H) - out_heads = out_heads * xp.reshape( - attn_output_gate, (n_node, 1, self.attn_n_focus, self.n_atten_head, 1) - ) # (N, D, Fa, H, Ch) - - # === Step 4.5. Output projection and merge heads === - out_focus = xp.reshape( - out_heads, - ( - n_node, - self.ebed_dim_full, - self.attn_n_focus, - self.attn_focus_dim, - ), - ) # (N, D, Fa, Ca) - if self.attn_o_proj is not None: - out_focus = self.attn_o_proj(out_focus) - out = xp.astype( - xp.reshape( - out_focus, (n_node, self.ebed_dim_full, self.hidden_channels) - ), - get_xp_precision(xp, self.precision), - ) # (N, D, C_wide) + ) # (N, D, Fa, Ca) + if self.attn_o_proj is not None: + out_focus = self.attn_o_proj(out_focus) + out = xp.astype( + xp.reshape( + out_focus, (n_node, self.ebed_dim_full, self.hidden_channels) + ), + get_xp_precision(xp, self.precision), + ) # (N, D, C_wide) # === Step 5. Optional message-node grid product === if self.message_node_grid_product is not None: @@ -1750,6 +1853,69 @@ def call( out = self.post_focus_mix(out[:, :, None, :])[:, :, 0, :] return out # (N, D, C) + def _flash_aggregate( + self, + x_local_flash: Array, + edge_cache: EdgeCache, + attn_alpha: Array, + x_l0_node: Array, + n_node: int, + compute_dtype: Any, + ) -> Array: + """ + Fused flash-attention value aggregation seam (overridable). + + Folds Step 4.3f and Step 4.4f of ``call`` -- the block-diagonal + rotate-back, the inverse-rotation degree rescale, the per-edge + envelope-gated softmax weighting, the destination reduction, and the + output-side head gate -- into a single destination-segmented pass that + returns the fully gated aggregate ``(N, D, C_wide)``. + + The pure array-API reference has no fused kernel, so it never enters this + path (``use_flash_atten`` is always ``False``) and this default + implementation raises. The ``pt_expt`` backend overrides this method with + the fused Triton flash-attention kernel, drawing on ``self._flash_atten_fn`` + / ``self._build_row_ptr_fn`` (bound when it re-enables ``use_flash_atten``), + ``edge_cache.Dt_full`` for the rotate-back, ``self.rotate_inv_rescale_full`` + for the degree rescale, and ``self.lmax`` / ``self.n_atten_head`` for the + block addressing. + + Parameters + ---------- + x_local_flash : Array + Pre-rotate-back per-focus local features with shape (E, F, D_m, Cf), + as returned by ``so2_message(..., return_local=True)``. + edge_cache : EdgeCache + Precomputed edge cache; supplies ``Dt_full`` (the block-diagonal + inverse rotation) and ``dst`` (the destination scatter index). + attn_alpha : Array + Envelope-gated softmax attention weights with shape (E, F, H). + x_l0_node : Array + Destination-node l=0 scalar features with shape (N, Fa, Ca), consumed + by the output-side head gate. + n_node : int + Number of nodes N. + compute_dtype + Compute-precision dtype for the aggregation. + + Returns + ------- + Array + The gated aggregate message with shape (N, D, C_wide). + + Raises + ------ + NotImplementedError + Always, in the dpmodel reference: the fused flash path is never taken + because ``use_flash_atten`` is ``False``. The ``pt_expt`` backend + provides the fused kernel implementation. + """ + raise NotImplementedError( + "The fused flash-attention aggregation is not implemented in the " + "dpmodel (array-API) reference; the pt_expt backend overrides " + "`_flash_aggregate` with the fused Triton kernel." + ) + def radial_message( self, x: Array, @@ -1823,6 +1989,7 @@ def so2_message( x: Array, edge_cache: EdgeCache, radial_feat: Array, + return_local: bool = False, ) -> tuple[Array, Array]: """ Build edge messages by rotate-to-local, SO(2) mixing, and rotate-back. @@ -1835,161 +2002,221 @@ def so2_message( Precomputed edge cache. radial_feat : Array Per-edge radial features with shape (E, lmax+1, C). + return_local : bool + If True, return the pre-rotate-back per-focus local features + ``(E, F, D_m, Cf)`` instead of the rotated-back message. Used by the + fused flash-attention aggregation, which folds the rotate-back into + its own kernel. Returns ------- tuple[Array, Array] ``(x_message, rad_feat)`` with shapes (E, D, C_wide) and - (E, D_m, C_wide). The ``l=0`` slice of ``rad_feat`` is consumed by - the attention aggregation. + (E, D_m, C_wide) by default, or ``(x_local, rad_feat)`` with + ``x_local`` of shape (E, F, D_m, Cf) when ``return_local`` is True. + The ``l=0`` slice of ``rad_feat`` is consumed by the attention + aggregation. """ xp = array_api_compat.array_namespace(x) device = array_api_compat.device(x) src = edge_cache.src n_edge = src.shape[0] - # === Step 1. Rotate to edge-aligned local frame === - x_local, x_dst_local = self._rotate_to_local(x, edge_cache) - - # === Step 2. Select radial/type features for reduced layout === - rad_feat = xp.take( - radial_feat, - xp_asarray_nodetach(xp, self.degree_index_m[...], device=device), - axis=1, - ) # (E, D_m, C) - if self.radial_hidden_proj is not None: - rad_feat = self.radial_hidden_proj(rad_feat) - if self.radial_degree_mixer is None: - x_local = x_local * rad_feat + # The fused value path (bound only by the ``pt_expt`` backend) folds + # the dense Steps 1-5 into a single kernel, returning the same + # pre-rotate-back ``(E, F, D_m, Cf)`` local features and reduced + # ``rad_feat`` that the dense exit produces, so the shared tail below + # is agnostic to which branch ran. + if self._value_path is not None and not self.training: + x_local, rad_feat = self._value_path(x, edge_cache, radial_feat) else: - x_local = self.radial_degree_mixer(x_local, rad_feat) - if self.node_wise_grid_product is not None: - x_local = x_local + self.node_wise_grid_product( - x_local, - x_dst_local, - ) - rad_feat_l0_focus = xp.reshape( - rad_feat[:, 0, :], (n_edge, self.n_focus, self.so2_focus_dim) - ) # (E, F, Cf) - - # === Step 3. Convert to SO(2) internal focus layout === - focus_gate_src: Array | None = None - x_local = xp.permute_dims( - xp.reshape( - x_local, (n_edge, self.reduced_dim, self.n_focus, self.so2_focus_dim) - ), - (0, 2, 1, 3), - ) # (E, F, D_m, Cf), strided - if self.focus_compete and self.n_focus > 1: - focus_gate_src = x_local[:, :, 0, :] - - # === Step 4. Multi-layer SO(2) mixing (pre-norm + residual) === - - def so2_l0_extractor(v: Array) -> Array: - """Extract scalar features from SO(2) reduced layout.""" - return xp.reshape(v[:, :, 0, :], (v.shape[0], self.hidden_channels)) - - def apply_bias_correction( - x_local: Array, - so2_linear: SO2Linear, - layer_idx: int, - ) -> Array: - if layer_idx != 0 or so2_linear.bias0 is None: - return x_local - bias0 = xp.reshape( - xp_asarray_nodetach(xp, so2_linear.bias0[...], device=device), - (self.n_focus, so2_linear.out_channels), - )[None, ...] - if so2_linear.out_channels == self.so2_focus_dim: - radial_factor = rad_feat_l0_focus - elif so2_linear.out_channels == 2 * self.so2_focus_dim: - radial_factor = xp.concat( - [rad_feat_l0_focus, rad_feat_l0_focus], axis=-1 - ) + # === Step 1. Rotate to edge-aligned local frame === + x_local, x_dst_local = self._rotate_to_local(x, edge_cache) + + # === Step 2. Select radial/type features for reduced layout === + rad_feat = xp.take( + radial_feat, + xp_asarray_nodetach(xp, self.degree_index_m[...], device=device), + axis=1, + ) # (E, D_m, C) + if self.radial_hidden_proj is not None: + rad_feat = self.radial_hidden_proj(rad_feat) + if self.radial_degree_mixer is None: + x_local = x_local * rad_feat else: - raise RuntimeError( - "Unexpected SO2Linear output width in bias correction" + x_local = self.radial_degree_mixer(x_local, rad_feat) + if self.node_wise_grid_product is not None: + x_local = x_local + self.node_wise_grid_product( + x_local, + x_dst_local, ) - bias_correction = bias0 * ( - radial_factor * xp.reshape(edge_cache.edge_env, (-1, 1, 1)) - 1.0 - ) - x_local = xp.concat( - [ - x_local[:, :, :1, :] + bias_correction[:, :, None, :], - x_local[:, :, 1:, :], - ], - axis=2, - ) - return x_local + rad_feat_l0_focus = xp.reshape( + rad_feat[:, 0, :], (n_edge, self.n_focus, self.so2_focus_dim) + ) # (E, F, Cf) - if self.use_so2_attn_res: - so2_depth_sources = [x_local] - for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( - zip( - self.so2_linears, - self.so2_inter_norms, - self.non_linearities, - strict=True, + # === Step 3. Cast to the focus-major SO(2) mixing layout (F, E, D_m, Cf) === + # The mixing stack runs with the focus stream on the batch axis, the native + # layout of the block-diagonal batched matmul: the per-focus linear consumes + # it with no edge-axis transpose and writes each ``|m|`` block with no + # reassembly cost. This is a strided view of the reduced global buffer, + # materialized by the first linear's reshape exactly as any reduced-layout + # view would be. + focus_gate_src: Array | None = None + x_local = xp.permute_dims( + xp.reshape( + x_local, + (n_edge, self.reduced_dim, self.n_focus, self.so2_focus_dim), + ), + (2, 0, 1, 3), + ) # (F, E, D_m, Cf), strided view + if self.focus_compete and self.n_focus > 1: + focus_gate_src = x_local[:, :, 0, :] # (F, E, Cf) + + # === Step 4. Multi-layer SO(2) mixing (pre-norm + residual) === + + def so2_l0_extractor(v: Array) -> Array: + """Extract scalar features from the edge-major layout (E, F, D_m, Cf).""" + return xp.reshape(v[:, :, 0, :], (v.shape[0], self.hidden_channels)) + + def apply_bias_correction( + x_local: Array, + so2_linear: SO2Linear, + layer_idx: int, + ) -> Array: + if layer_idx != 0 or so2_linear.bias0 is None: + return x_local + if so2_linear.out_channels == self.so2_focus_dim: + radial_factor = rad_feat_l0_focus + elif so2_linear.out_channels == 2 * self.so2_focus_dim: + radial_factor = xp.concat( + [rad_feat_l0_focus, rad_feat_l0_focus], axis=-1 + ) + else: + raise RuntimeError( + "Unexpected SO2Linear output width in bias correction" + ) + # Focus-major broadcast: bias0 (F, Cout), the radial l=0 factor + # (E, F, .) transposed to (F, E, .), the per-edge envelope over the + # edge axis, applied to the l=0 scalar slice (F, E, Cout). + bias0 = xp.reshape( + xp_asarray_nodetach(xp, so2_linear.bias0[...], device=device), + (self.n_focus, so2_linear.out_channels), ) - ): - x_local: Array = self.so2_layer_attn_res[layer_idx]( - sources=so2_depth_sources, - scalar_extractor=so2_l0_extractor, - current_x=x_local, + radial_factor = xp.permute_dims(radial_factor, (1, 0, 2)) # (F, E, .) + bias_correction = bias0[:, None, :] * ( + radial_factor * xp.reshape(edge_cache.edge_env, (1, -1, 1)) - 1.0 ) - residual = x_local - x_local = inter_norm(x_local) - x_local = so2_linear(x_local) - x_local = apply_bias_correction(x_local, so2_linear, layer_idx) - - x_local = non_linear(x_local) + x_local = xp.concat( + [ + x_local[:, :, :1, :] + bias_correction[:, :, None, :], + x_local[:, :, 1:, :], + ], + axis=2, + ) + return x_local - if self.layer_scale: - scale: Array = xp.reshape( - xp_asarray_nodetach( - xp, - self.adam_so2_layer_scales[layer_idx][...], - device=device, - ), - (1, self.n_focus, 1, self.so2_focus_dim), + if self.use_so2_attn_res: + # The depth-attention residual is a per-edge reduction over the + # layer history (``DepthAttnRes`` batches on axis 0), so the history + # is kept in the edge-major orientation and each mixing step + # transposes into the focus-major layout for the linear. + so2_depth_sources = [ + xp.permute_dims(x_local, (1, 0, 2, 3)), # (E, F, D_m, Cf) + ] + for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( + zip( + self.so2_linears, + self.so2_inter_norms, + self.non_linearities, + strict=True, ) - x_local = residual + scale * x_local - else: - x_local = residual + x_local - so2_depth_sources.append(x_local - residual) - else: - for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( - zip( - self.so2_linears, - self.so2_inter_norms, - self.non_linearities, - strict=True, - ) - ): - residual = x_local - x_local = inter_norm(x_local) - x_local = so2_linear(x_local) - x_local = apply_bias_correction(x_local, so2_linear, layer_idx) + ): + x_edge: Array = self.so2_layer_attn_res[layer_idx]( + sources=so2_depth_sources, + scalar_extractor=so2_l0_extractor, + current_x=xp.permute_dims(x_local, (1, 0, 2, 3)), + ) + x_local = xp.permute_dims(x_edge, (1, 0, 2, 3)) # (F, E, D_m, Cf) + residual = x_local + x_local = inter_norm(x_local) + x_local = so2_linear(x_local) + x_local = apply_bias_correction(x_local, so2_linear, layer_idx) - x_local = non_linear(x_local) + x_local = non_linear(x_local) - if self.layer_scale: - scale = xp.reshape( - xp_asarray_nodetach( - xp, - self.adam_so2_layer_scales[layer_idx][...], - device=device, - ), - (1, self.n_focus, 1, self.so2_focus_dim), + if self.layer_scale: + scale: Array = xp.reshape( + xp_asarray_nodetach( + xp, + self.adam_so2_layer_scales[layer_idx][...], + device=device, + ), + (self.n_focus, 1, 1, self.so2_focus_dim), + ) + x_local = residual + scale * x_local + else: + x_local = residual + x_local + so2_depth_sources.append( + xp.permute_dims(x_local - residual, (1, 0, 2, 3)) ) - x_local = residual + scale * x_local - else: - x_local = residual + x_local + else: + for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( + zip( + self.so2_linears, + self.so2_inter_norms, + self.non_linearities, + strict=True, + ) + ): + residual = x_local + x_local = inter_norm(x_local) + x_local = so2_linear(x_local) + x_local = apply_bias_correction(x_local, so2_linear, layer_idx) - # === Step 5. Cross-focus softmax competition === - if self.focus_compete and self.n_focus > 1: - alpha = self._focus_alpha(focus_gate_src) - x_local = x_local * xp.astype(alpha, x_local.dtype)[..., None, None] + x_local = non_linear(x_local) + + if self.layer_scale: + scale = xp.reshape( + xp_asarray_nodetach( + xp, + self.adam_so2_layer_scales[layer_idx][...], + device=device, + ), + (self.n_focus, 1, 1, self.so2_focus_dim), + ) + x_local = residual + scale * x_local + else: + x_local = residual + x_local + + # === Step 5. Cross-focus softmax competition === + if self.focus_compete and self.n_focus > 1: + # ``_focus_alpha`` is shared with the rotation-free radial and Cartesian + # messages in the edge-major (E, F) orientation; feed it the transposed + # view of the focus-major scalar and broadcast the weights back over the + # focus-major activation. + alpha = self._focus_alpha( + xp.permute_dims(focus_gate_src, (1, 0, 2)) + ) # (E, F) + x_local = ( + x_local + * xp.astype(xp.permute_dims(alpha, (1, 0)), x_local.dtype)[ + ..., None, None + ] + ) + + # === Exit. Restore the (E, F, D_m, Cf) orientation === + # Both the fused flash-attention aggregation kernel and the rotate-back + # consume this orientation through explicit strides, so the focus-major + # buffer is handed back as a view with no copy. + x_local = xp.permute_dims( + x_local, (1, 0, 2, 3) + ) # (E, F, D_m, Cf), strided view + + # The fused flash-attention aggregation consumes the per-focus + # ``(E, F, D_m, Cf)`` local layout directly and performs the rotate-back + # inside its kernel, so return before the standalone rotate-back. + if return_local: + return x_local, rad_feat # === Step 6. Rotate back to global frame === x_message = self._rotate_back(x_local, edge_cache, n_edge) @@ -2263,6 +2490,11 @@ def _build_so2_mixing( self.so2_inter_norms = inter_norms # === Step 5. Intermediate non-linearity (the last layer stays linear) === + # Both branches run inside the focus-major SO(2) mixing layout, so they are + # built with ``layout="fndc"``: the ``S2GridNet`` activation (an S2 or + # SO(3) grid GLU per its grid configuration) folds the focus-major + # re-orientation into its coefficient/grid transpose, and the coefficient + # ``GatedActivation`` projects its per-focus gate in the same layout. non_linearities: list[NativeOP] = [] for i in range(self.mixing_layers): if i >= self.mixing_layers - 1: @@ -2277,7 +2509,7 @@ def _build_so2_mixing( mode="self", op_type="glu", precision=self.compute_precision, - layout="nfdc", + layout="fndc", grid_resolution_list=self.s2_grid_resolution, coefficient_layout="m_major", grid_method=self.s2_grid_method, @@ -2296,7 +2528,7 @@ def _build_so2_mixing( precision=self.compute_precision, activation_function=self.activation_function, mlp_bias=self.mlp_bias, - layout="nfdc", + layout="fndc", trainable=trainable, seed=child_seed(seed_non_linearities, i), ) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py b/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py index 42f99e2275..ed39cb269e 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py @@ -979,29 +979,35 @@ def _precompute_powers( q: Any, max_power: int, ) -> Any: - """Precompute powers ``q_i^k`` as a dense table with shape ``(4, max_power+1, E)``.""" + """Precompute powers ``q_i^k`` as a dense table with shape ``(4, max_power+1, E)``. + + The table is built by an explicit multiply chain: a ``cumprod`` over + the short power axis lowers to a scan whose forward and leave-one-out + backward cost several milliseconds per model call at typical edge + counts, whereas the unrolled chain stays a fusable pointwise sequence. + """ xp = array_api_compat.array_namespace(q) - device = array_api_compat.device(q) - n_edge = q.shape[0] components = xp.permute_dims(q, (1, 0)) - ones = xp.ones((4, n_edge), dtype=q.dtype, device=device) + ones = xp.ones_like(components) if max_power == 0: - return xp.reshape(ones, (4, 1, n_edge)) - # Cumulative products built by iterated multiplication (``max_power`` is a - # compile-time constant, so the unrolled loop is export-friendly). - levels = [ones] - acc = ones - for _ in range(max_power): - acc = acc * components - levels.append(acc) - return xp.stack(levels, axis=1) + return ones[:, None, :] + powers = [ones, components] + for _ in range(max_power - 1): + powers.append(powers[-1] * components) + return xp.stack(powers, axis=1) @staticmethod def _build_monomial_matrix( powers: Any, monomial_exponents: Any, ) -> Any: - """Assemble the monomial design matrix for one fixed degree by gather/prod.""" + """Assemble the monomial design matrix for one fixed degree. + + The four gathered factor rows are combined by explicit multiplies: + ``prod(dim=0)`` lowers to a ``cumprod`` scan pair (forward plus + leave-one-out backward) on the large ``(4, M, E)`` intermediate, + while two multiply levels keep the chain pointwise and fusable. + """ xp = array_api_compat.array_namespace(powers) n_mono = monomial_exponents.shape[0] n_edge = powers.shape[-1] @@ -1010,7 +1016,25 @@ def _build_monomial_matrix( (4, n_mono, n_edge), ) selected = xp_take_along_axis(powers, gather_idx, axis=1) - return xp.permute_dims(xp.prod(selected, axis=0), (1, 0)) + product = (selected[0] * selected[1]) * (selected[2] * selected[3]) + return xp.permute_dims(product, (1, 0)) + + def _monomial_matrix( + self, + edge_quaternion: Any, + exp_name: str, + max_power: int, + ) -> Any: + """Evaluate one degree kernel's monomial basis via the dense power-table chain.""" + xp = array_api_compat.array_namespace(edge_quaternion) + device = array_api_compat.device(edge_quaternion) + powers = self._precompute_powers(edge_quaternion, max_power) + return self._build_monomial_matrix( + powers, + xp_asarray_nodetach( + xp, getattr(self.small_order_kernels, exp_name), device=device + ), + ) def _compute_l1_block(self, edge_quaternion: Any) -> Any: """Compute the vector block directly from the Cartesian rotation matrix.""" @@ -1048,11 +1072,7 @@ def _compute_l3_block(self, edge_quaternion: Any) -> Any: xp = array_api_compat.array_namespace(edge_quaternion) device = array_api_compat.device(edge_quaternion) n_edge = edge_quaternion.shape[0] - powers = self._precompute_powers(edge_quaternion, 6) - monomials = self._build_monomial_matrix( - powers, - xp_asarray_nodetach(xp, self.small_order_kernels.exp_l3, device=device), - ) + monomials = self._monomial_matrix(edge_quaternion, "exp_l3", 6) c = xp_asarray_nodetach( xp, self.small_order_kernels.C_l3, @@ -1070,11 +1090,7 @@ def _compute_l3l4_blocks( xp = array_api_compat.array_namespace(edge_quaternion) device = array_api_compat.device(edge_quaternion) n_edge = edge_quaternion.shape[0] - powers = self._precompute_powers(edge_quaternion, 8) - monomials = self._build_monomial_matrix( - powers, - xp_asarray_nodetach(xp, self.small_order_kernels.exp_l4, device=device), - ) + monomials = self._monomial_matrix(edge_quaternion, "exp_l4", 8) c = xp_asarray_nodetach( xp, self.small_order_kernels.C_combined_l3l4, @@ -1091,11 +1107,7 @@ def _compute_l5_block(self, edge_quaternion: Any) -> Any: xp = array_api_compat.array_namespace(edge_quaternion) device = array_api_compat.device(edge_quaternion) n_edge = edge_quaternion.shape[0] - powers = self._precompute_powers(edge_quaternion, 10) - monomials = self._build_monomial_matrix( - powers, - xp_asarray_nodetach(xp, self.small_order_kernels.exp_l5, device=device), - ) + monomials = self._monomial_matrix(edge_quaternion, "exp_l5", 10) c = xp_asarray_nodetach( xp, self.small_order_kernels.C_l5, @@ -1113,11 +1125,7 @@ def _compute_l5l6_blocks( xp = array_api_compat.array_namespace(edge_quaternion) device = array_api_compat.device(edge_quaternion) n_edge = edge_quaternion.shape[0] - powers = self._precompute_powers(edge_quaternion, 12) - monomials = self._build_monomial_matrix( - powers, - xp_asarray_nodetach(xp, self.small_order_kernels.exp_l6, device=device), - ) + monomials = self._monomial_matrix(edge_quaternion, "exp_l6", 12) c = xp_asarray_nodetach( xp, self.small_order_kernels.C_combined_l5l6, diff --git a/deepmd/dpmodel/train/validation.py b/deepmd/dpmodel/train/validation.py index 1f567362f0..d4a36a216a 100644 --- a/deepmd/dpmodel/train/validation.py +++ b/deepmd/dpmodel/train/validation.py @@ -33,20 +33,15 @@ resolve_full_validation_start_step, ) from deepmd.utils.eval_metrics import ( - FULL_VALIDATION_METRIC_FAMILY_BY_KEY, - FULL_VALIDATION_METRIC_KEY_MAP, + ENERGY_FULL_VALIDATION_PROFILE, ) log = logging.getLogger(__name__) -LOG_COLUMN_ORDER = [ - ("E_MAE", "mae_e_per_atom"), - ("E_RMSE", "rmse_e_per_atom"), - ("F_MAE", "mae_f"), - ("F_RMSE", "rmse_f"), - ("V_MAE", "mae_v_per_atom"), - ("V_RMSE", "rmse_v_per_atom"), -] +# The backend-independent validator drives only energy-type models: the JAX +# backend has no spin model, so it is bound to the energy full-validation +# profile. Spin support is decided by the per-backend profile selection. +FULL_VALIDATION_PROFILE = ENERGY_FULL_VALIDATION_PROFILE TOPK_RECORDS_INFO_KEY = "full_validation_topk_records" BEST_METRIC_NAME_INFO_KEY = "full_validation_metric" @@ -61,11 +56,6 @@ VAL_LOG_COLUMN_GAP = " " VAL_LOG_HEADER_PREFIX = "# " VAL_LOG_DATA_PREFIX = " " -METRIC_LOG_UNIT_MAP = { - "e": ("meV/atom", 1000.0), - "f": ("meV/Å", 1000.0), - "v": ("meV/atom", 1000.0), -} @dataclass(frozen=True) @@ -120,15 +110,14 @@ def resolve_best_checkpoint_dir( def parse_validation_metric(metric: str) -> tuple[str, str]: """Parse the configured full validation metric.""" normalized_metric = normalize_full_validation_metric(metric) - if normalized_metric not in FULL_VALIDATION_METRIC_KEY_MAP: - supported_metrics = ", ".join( - item.upper() for item in FULL_VALIDATION_METRIC_KEY_MAP - ) + metric_key_map = FULL_VALIDATION_PROFILE.metric_key_map + if normalized_metric not in metric_key_map: + supported_metrics = ", ".join(item.upper() for item in metric_key_map) raise ValueError( "validating.validation_metric must be one of " f"{supported_metrics}, got {metric!r}." ) - return normalized_metric, FULL_VALIDATION_METRIC_KEY_MAP[normalized_metric] + return normalized_metric, metric_key_map[normalized_metric] def format_metric_for_log( @@ -136,7 +125,7 @@ def format_metric_for_log( ) -> tuple[str, float, str]: """Format a full validation metric for user-facing logging.""" metric_family, metric_kind = metric_name.split(":") - metric_unit, metric_scale = METRIC_LOG_UNIT_MAP[metric_family] + metric_unit, metric_scale = FULL_VALIDATION_PROFILE.unit_by_family[metric_family] metric_label = f"{metric_family.upper()}:{metric_kind.upper()}" return metric_label, metric_value * metric_scale, metric_unit @@ -145,10 +134,10 @@ def format_metric_value_for_table( metric_key: str, metric_value: float ) -> tuple[float, str]: """Format one table metric value and its unit for `val.log`.""" - metric_family = FULL_VALIDATION_METRIC_FAMILY_BY_KEY.get(metric_key) + metric_family = FULL_VALIDATION_PROFILE.metric_family_by_key.get(metric_key) if metric_family is None: raise ValueError(f"Unknown full validation metric key: {metric_key}") - metric_unit, metric_scale = METRIC_LOG_UNIT_MAP[metric_family] + metric_unit, metric_scale = FULL_VALIDATION_PROFILE.unit_by_family[metric_family] return metric_value * metric_scale, metric_unit @@ -196,6 +185,7 @@ def __init__( stale_state_keys: tuple[str, ...] = STALE_FULL_VALIDATION_INFO_KEYS, emit_best_save_log: bool = True, ) -> None: + self.profile = FULL_VALIDATION_PROFILE self.state_store = state_store self.rank = rank self.checkpoint_dir = ( @@ -242,7 +232,7 @@ def __init__( restart_training and self.full_val_file.exists() ) self.table_column_specs = [] - for column_name, metric_key in LOG_COLUMN_ORDER: + for column_name, metric_key in self.profile.column_order: _, metric_unit = format_metric_value_for_table(metric_key, 1.0) header_label = f"{column_name}({metric_unit})" self.table_column_specs.append( @@ -540,10 +530,7 @@ def _write_log_file(self, result: FullValidationResult) -> None: for _, header_label, column_width in self.table_column_specs: header += VAL_LOG_COLUMN_GAP + f"{header_label:^{column_width}s}" header += "\n" - header += ( - "# E uses per-atom energy, F uses component-wise force errors, " - "and V uses virial normalized by natoms.\n" - ) + header += self.profile.log_header_note fout.write(header) self._should_write_header = False self._write_mode = "a" diff --git a/deepmd/jax/train/validation.py b/deepmd/jax/train/validation.py index 62cb21563c..9796b19a69 100644 --- a/deepmd/jax/train/validation.py +++ b/deepmd/jax/train/validation.py @@ -13,7 +13,6 @@ import numpy as np from deepmd.dpmodel.train.validation import ( - LOG_COLUMN_ORDER, FullValidatorBase, ) from deepmd.jax.env import ( @@ -74,7 +73,7 @@ def evaluate_all_systems(self) -> dict[str, float]: aggregated = weighted_average([metric for metric in system_metrics if metric]) return { metric_key: float(aggregated[metric_key]) - for _, metric_key in LOG_COLUMN_ORDER + for _, metric_key in self.profile.column_order if metric_key in aggregated } diff --git a/deepmd/kernels/__init__.py b/deepmd/kernels/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/kernels/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/kernels/cute/__init__.py b/deepmd/kernels/cute/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/kernels/cute/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/kernels/cute/sezm/__init__.py b/deepmd/kernels/cute/sezm/__init__.py new file mode 100644 index 0000000000..b514a7c1d1 --- /dev/null +++ b/deepmd/kernels/cute/sezm/__init__.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +CuTe-DSL fused SO(2) value-path operator for SeZM / DPA4. + +This package hosts a single bucketed CuTe operator that folds the entire per-edge +value path of :class:`~deepmd.pt.model.descriptor.sezm_nn.so2.SO2Convolution` +(``rotate_to_local`` -> radial degree mix -> the three-layer gated SO(2) mixing +stack -> focus competition) into a fused forward kernel and a matching +recompute backward kernel, keeping the per-edge intermediates on chip. It is an +opt-in inference path enabled by ``DP_CUTE_INFER``; the final local features are +handed to the committed flash-attention aggregation for rotate-back and scatter. +Kernel entry points are internal implementation details of the SeZM descriptor; +the package-level API only exposes availability and the value-path factory. + +Current limitations +------------------- +Performance + On H20 / fp32 the operator is about 2.8x slower than the compiled Triton + + flash-attention path (roughly 489 / 724 ms versus 174 / 262 ms per force step + at 2000 / 4000 atoms). Peak memory is at parity with, or marginally below, + the compiled path (about 0.5 / 0.8 GB lower) and roughly 1.68x below the + eager path. The bottleneck is the recompute backward, which dominates the + kernel time: its occupancy is capped by the block-diagonal weight held + resident in shared memory, and both the forward and backward GEMMs run at the + hand-written plateau of about 21% of fp32 peak (versus about 52% for cuBLAS). + +Deployment + This is a Python-inference-only path. The ``cutlass.cute`` kernels are + nvcc / NVRTC JIT-compiled at runtime and do not bake into the AOTInductor + ``.pt2`` artifact, so the operator is unavailable to the LAMMPS / GPUMD C++ + inference path. ``DP_CUTE_INFER`` and ``DP_TRITON_INFER`` both claim the + fused SO(2) value path and are mutually exclusive; enabling both is + rejected at construction. + +Correctness + The force is bit-exact against the eager reference (energy relative error + about 1e-9, force relative error about 5e-7 in fp32). +""" + +from __future__ import ( + annotations, +) + +from .forward import ( + SEZM_CUTE_AVAILABLE, +) +from .operator import ( + make_cute_value_path, +) + +__all__ = [ + "SEZM_CUTE_AVAILABLE", + "make_cute_value_path", +] diff --git a/deepmd/kernels/cute/sezm/backward.py b/deepmd/kernels/cute/sezm/backward.py new file mode 100644 index 0000000000..9afa940f8b --- /dev/null +++ b/deepmd/kernels/cute/sezm/backward.py @@ -0,0 +1,692 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN201, ANN202, ANN204, ANN205 +""" +CuTe-DSL fused recompute-backward kernel for the SeZM SO(2) value path. + +Given the upstream gradients ``g_out`` (w.r.t. the pre-focus-compete local +features ``x_local``) and ``g_fgate`` (w.r.t. the pre-mixing ``l = 0`` scalar), +one bucketed kernel recomputes the forward value path from the small saved +inputs (``x``, ``D_to_m``, ``Kc``) entirely on chip and backpropagates it, +emitting the position-path gradients that carry the ``edge_vec`` -> force +dependence:: + + grad_x node-feature gradient, scattered to source nodes + grad_D_to_m Wigner-rotation gradient, atomically summed over focus + grad_Kc radial degree-kernel gradient, atomically summed over focus + +The weights are frozen on the inference force path. No ``E x D_m x C`` +intermediate reaches DRAM: the kernel recomputes the forward storing only the two +gated-layer pre-activations in shared memory, then backpropagates the residual +stack (gated-activation backward in registers/smem), the radial degree mix, and +the rotation. This kernel is specialized to the three-layer ``[gated, gated, +identity]`` mixing stack of the deployed configuration. + +Buffers per CTA (one focus of a bucket of ``B`` edges): four ``B x (D_m*Cf)`` +scratch tensors plus one ``max_block^2`` weight/gate scratch, inside the sm_90 +shared-memory limit at ``B = 16``. All accumulation is fp32 IEEE. +""" + +from __future__ import ( + annotations, +) + +import torch + +from .forward import ( + SEZM_CUTE_AVAILABLE, +) + +if SEZM_CUTE_AVAILABLE: + import cuda.bindings.driver as cuda + import cutlass + import cutlass.cute as cute + import cutlass.cute.math as cmath + from cutlass.cute.runtime import ( + from_dlpack, + ) + + from .forward import ( + ForwardRunner, + ) + + class BackwardProgram: + """Bucketed fused recompute-backward program for the SO(2) value path. + + Parameters + ---------- + lmax, mmax, cf, n_focus, n_layers, bucket, threads, rb, rn + Kernel configuration (see :class:`.forward.ForwardProgram`). + """ + + def __init__( + self, + *, + lmax: int, + mmax: int, + cf: int, + n_focus: int, + n_layers: int, + bucket: int, + threads: int, + rb: int, + rn: int, + ) -> None: + self.lmax, self.mmax, self.cf, self.nf, self.nl = ( + lmax, + mmax, + cf, + n_focus, + n_layers, + ) + self._B, self._T, self._RB, self._RN = bucket, threads, rb, rn + self.D = (lmax + 1) ** 2 + self.Dm = (lmax + 1) + sum(2 * (lmax - m + 1) for m in range(1, mmax + 1)) + self.Cw = n_focus * cf + self.gate_out = lmax * cf + self.ngroup = 1 + 2 * mmax + self.FLAT = self.Dm * cf + groups = [lmax + 1] + [2 * (lmax - m + 1) for m in range(1, mmax + 1)] + self._blocks: list[tuple[int, int]] = [] + off = 0 + for g in groups: + self._blocks.append((off, g * cf)) + off += g * cf + self._max_sb = max(sb for _, sb in self._blocks) + self._scr = max( + self._max_sb * self._max_sb, + cf * self.gate_out + 2 * bucket * self.gate_out, + ) + assert self._B % rb == 0 + for _, sb in self._blocks: + assert sb % rn == 0 + + @cute.jit + def __call__( + self, + mGout, + mGfg, + mX, + mSrc, + mDtoM, + mKc, + mCB, + mW, + mGW, + mExpand, + mGx, + mGD, + mGKc, + n_edge: cutlass.Int32, + n_bucket: cutlass.Int32, + stream: cuda.CUstream, + ): + self.kernel( + mGout, + mGfg, + mX, + mSrc, + mDtoM, + mKc, + mCB, + mW, + mGW, + mExpand, + mGx, + mGD, + mGKc, + n_edge, + ).launch(grid=[n_bucket, self.nf, 1], block=[self._T, 1, 1], stream=stream) + + @cute.kernel + def kernel( + self, + mGout, + mGfg, + mX, + mSrc, + mDtoM, + mKc, + mCB, + mW, + mGW, + mExpand, + mGx, + mGD, + mGKc, + n_edge: cutlass.Int32, + ): + D = cutlass.const_expr(self.D) + Dm = cutlass.const_expr(self.Dm) + CF = cutlass.const_expr(self.cf) + FLAT = cutlass.const_expr(self.FLAT) + B = cutlass.const_expr(self._B) + T = cutlass.const_expr(self._T) + + tidx, _, _ = cute.arch.thread_idx() + bucket, focus, _ = cute.arch.block_idx() + e0 = bucket * B + + smem = cutlass.utils.SmemAllocator() + b0 = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((B, FLAT), stride=(FLAT, 1)), 16 + ) + b1 = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((B, FLAT), stride=(FLAT, 1)), 16 + ) + b2 = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((B, FLAT), stride=(FLAT, 1)), 16 + ) + b3 = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((B, FLAT), stride=(FLAT, 1)), 16 + ) + s_w = smem.allocate_tensor( + cutlass.Float32, + cute.make_layout((cutlass.const_expr(self._scr),), stride=(1,)), + 16, + ) + + # === Step 1. Forward recompute: store z_0 -> b1, z_1 -> b2 === + # b0 carries the running layer input h_l; b3 is the activation temp. + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + dm = rem // CF + c = rem % CF + e = e0 + b + eq = e if e < n_edge else n_edge - 1 + src = mSrc[eq] + acc = cutlass.Float32(0.0) + for k in cutlass.range_constexpr(D): + acc += mDtoM[eq, dm, k] * mX[src, k, focus * CF + c] + b1[b, rem] = acc + i += T + cute.arch.sync_threads() + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + o = rem // CF + c = rem % CF + e = e0 + b + eq = e if e < n_edge else n_edge - 1 + acc = cutlass.Float32(0.0) + for ii in cutlass.range_constexpr(Dm): + acc += mKc[eq, o, ii] * b1[b, ii * CF + c] + b0[b, rem] = acc * mCB[focus * CF + c] + i += T + cute.arch.sync_threads() + for lyr in cutlass.range_constexpr(self.nl - 1): + zbuf = b1 if lyr == 0 else b2 + self._gemm_fwd(b0, zbuf, s_w, mW, lyr, focus, tidx) + self._gated_fwd(zbuf, b3, s_w, mGW, mExpand, lyr, focus, tidx) + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + b0[b, rem] = b0[b, rem] + b3[b, rem] + i += T + cute.arch.sync_threads() + + # === Step 2. Reverse the residual stack; b0 accumulates grad_h === + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + o = rem // CF + c = rem % CF + b0[b, rem] = mGout[e0 + b, focus, o, c] + i += T + cute.arch.sync_threads() + # layer 2 (identity): grad_z = grad_out; grad_h += W_2^T @ grad_z + self._gemm_bwd(b0, b3, s_w, mW, self.nl - 1, focus, tidx) + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + b0[b, rem] = b0[b, rem] + b3[b, rem] + i += T + cute.arch.sync_threads() + # layer 1 (gated): grad_z_1 -> b2 in place; grad_h += W_1^T @ grad_z_1 + self._gated_bwd(b2, b0, s_w, mGW, mExpand, 1, focus, tidx) + self._gemm_bwd(b2, b3, s_w, mW, 1, focus, tidx) + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + b0[b, rem] = b0[b, rem] + b3[b, rem] + i += T + cute.arch.sync_threads() + # layer 0 (gated): grad_z_0 -> b1 in place; grad_h += W_0^T @ grad_z_0 + self._gated_bwd(b1, b0, s_w, mGW, mExpand, 0, focus, tidx) + self._gemm_bwd(b1, b3, s_w, mW, 0, focus, tidx) + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + b0[b, rem] = b0[b, rem] + b3[b, rem] + i += T + cute.arch.sync_threads() + # add the focus-competition gradient into the l=0 row + i = tidx + while i < B * CF: + b = i // CF + c = i % CF + b0[b, c] = b0[b, c] + mGfg[e0 + b, focus, c] + i += T + cute.arch.sync_threads() + + # === Step 3. Radial + rotation backward; b0 holds grad_h0 === + # recompute x_rot -> b1 + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + dm = rem // CF + c = rem % CF + e = e0 + b + eq = e if e < n_edge else n_edge - 1 + src = mSrc[eq] + acc = cutlass.Float32(0.0) + for k in cutlass.range_constexpr(D): + acc += mDtoM[eq, dm, k] * mX[src, k, focus * CF + c] + b1[b, rem] = acc + i += T + cute.arch.sync_threads() + # grad_Kc[o, ii] = sum_c channel_basis[c] * grad_h0[o, c] * x_rot[ii, c] + i = tidx + while i < B * Dm * Dm: + b = i // (Dm * Dm) + rem = i % (Dm * Dm) + o = rem // Dm + ii = rem % Dm + e = e0 + b + acc = cutlass.Float32(0.0) + for c in cutlass.range_constexpr(CF): + acc += mCB[focus * CF + c] * b0[b, o * CF + c] * b1[b, ii * CF + c] + cute.arch.atomic_add(mGKc.iterator + mGKc.layout((e, o, ii)), acc) + i += T + # grad_x_rot[ii, c] = channel_basis[c] * sum_o Kc[o, ii] * grad_h0[o, c] -> b2 + i = tidx + while i < B * Dm * CF: + b = i // (Dm * CF) + rem = i % (Dm * CF) + ii = rem // CF + c = rem % CF + e = e0 + b + eq = e if e < n_edge else n_edge - 1 + acc = cutlass.Float32(0.0) + for o in cutlass.range_constexpr(Dm): + acc += mKc[eq, o, ii] * b0[b, o * CF + c] + b2[b, ii * CF + c] = acc * mCB[focus * CF + c] + i += T + cute.arch.sync_threads() + # grad_x[src, k, c] += sum_ii D_to_m[ii, k] * grad_x_rot[ii, c] + i = tidx + while i < B * D * CF: + b = i // (D * CF) + rem = i % (D * CF) + k = rem // CF + c = rem % CF + e = e0 + b + eq = e if e < n_edge else n_edge - 1 + src = mSrc[eq] + acc = cutlass.Float32(0.0) + for ii in cutlass.range_constexpr(Dm): + acc += mDtoM[eq, ii, k] * b2[b, ii * CF + c] + cute.arch.atomic_add( + mGx.iterator + mGx.layout((src, k, focus * CF + c)), acc + ) + i += T + # grad_D_to_m[ii, k] = sum_c grad_x_rot[ii, c] * x_src[k, c] + i = tidx + while i < B * Dm * D: + b = i // (Dm * D) + rem = i % (Dm * D) + ii = rem // D + k = rem % D + e = e0 + b + eq = e if e < n_edge else n_edge - 1 + src = mSrc[eq] + acc = cutlass.Float32(0.0) + for c in cutlass.range_constexpr(CF): + acc += b2[b, ii * CF + c] * mX[src, k, focus * CF + c] + cute.arch.atomic_add(mGD.iterator + mGD.layout((e, ii, k)), acc) + i += T + + @cute.jit + def _gemm_fwd(self, hbuf, zbuf, s_w, mW, lyr, focus, tidx): + """Recompute ``zbuf = hbuf @ W[lyr, focus]`` (block-diagonal).""" + T = cutlass.const_expr(self._T) + RB = cutlass.const_expr(self._RB) + RN = cutlass.const_expr(self._RN) + B = cutlass.const_expr(self._B) + for ob, sb in cutlass.const_expr(self._blocks): + j = tidx + while j < sb * sb: + s_w[(j // sb) * sb + (j % sb)] = mW[ + lyr, focus, ob + (j // sb), ob + (j % sb) + ] + j += T + cute.arch.sync_threads() + n_mt = cutlass.const_expr((B // RB) * (sb // RN)) + racc = cute.make_rmem_tensor( + cute.make_layout((RB, RN)), cutlass.Float32 + ) + mt = tidx + while mt < n_mt: + bi = (mt // (sb // RN)) * RB + nj = (mt % (sb // RN)) * RN + for r in cutlass.range_constexpr(RB): + for s in cutlass.range_constexpr(RN): + racc[r, s] = cutlass.Float32(0.0) + for k in cutlass.range(sb): + for r in cutlass.range_constexpr(RB): + a_rk = hbuf[bi + r, ob + k] + for s in cutlass.range_constexpr(RN): + racc[r, s] += a_rk * s_w[k * sb + nj + s] + for r in cutlass.range_constexpr(RB): + for s in cutlass.range_constexpr(RN): + zbuf[bi + r, ob + nj + s] = racc[r, s] + mt += T + cute.arch.sync_threads() + + @cute.jit + def _gemm_bwd(self, gzbuf, ghbuf, s_w, mW, lyr, focus, tidx): + """Compute ``ghbuf = W[lyr, focus]^T @ gzbuf`` (block-diagonal).""" + T = cutlass.const_expr(self._T) + RB = cutlass.const_expr(self._RB) + RN = cutlass.const_expr(self._RN) + B = cutlass.const_expr(self._B) + for ob, sb in cutlass.const_expr(self._blocks): + j = tidx + while j < sb * sb: + s_w[(j // sb) * sb + (j % sb)] = mW[ + lyr, focus, ob + (j // sb), ob + (j % sb) + ] + j += T + cute.arch.sync_threads() + n_mt = cutlass.const_expr((B // RB) * (sb // RN)) + racc = cute.make_rmem_tensor( + cute.make_layout((RB, RN)), cutlass.Float32 + ) + mt = tidx + while mt < n_mt: + bi = (mt // (sb // RN)) * RB + kj = (mt % (sb // RN)) * RN # W in-index (grad_h output column) + for r in cutlass.range_constexpr(RB): + for s in cutlass.range_constexpr(RN): + racc[r, s] = cutlass.Float32(0.0) + for n in cutlass.range(sb): # sum over the W out-index + for r in cutlass.range_constexpr(RB): + gz_rn = gzbuf[bi + r, ob + n] + for s in cutlass.range_constexpr(RN): + racc[r, s] += gz_rn * s_w[(kj + s) * sb + n] + for r in cutlass.range_constexpr(RB): + for s in cutlass.range_constexpr(RN): + ghbuf[bi + r, ob + kj + s] = racc[r, s] + mt += T + cute.arch.sync_threads() + + @cute.jit + def _gated_fwd(self, zbuf, abuf, s_w, mGW, mExpand, lyr, focus, tidx): + """Recompute ``abuf = GatedActivation(zbuf)`` (silu l=0, gate l>0).""" + CF = cutlass.const_expr(self.cf) + GO = cutlass.const_expr(self.gate_out) + Dm = cutlass.const_expr(self.Dm) + B = cutlass.const_expr(self._B) + T = cutlass.const_expr(self._T) + SIG_OFF = cutlass.const_expr(self.cf * self.gate_out) + j = tidx + while j < CF * GO: + s_w[(j // GO) * GO + (j % GO)] = mGW[lyr, focus, j // GO, j % GO] + j += T + cute.arch.sync_threads() + j = tidx + while j < B * GO: + b = j // GO + o = j % GO + acc = cutlass.Float32(0.0) + for ii in cutlass.range_constexpr(CF): + acc += zbuf[b, ii] * s_w[ii * GO + o] + s_w[SIG_OFF + b * GO + o] = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cmath.exp(-acc) + ) + j += T + cute.arch.sync_threads() + j = tidx + while j < B * CF: + b = j // CF + c = j % CF + z = zbuf[b, c] + abuf[b, c] = z / (cutlass.Float32(1.0) + cmath.exp(-z)) + j += T + j = tidx + while j < B * (Dm - 1) * CF: + b = j // ((Dm - 1) * CF) + rem = j % ((Dm - 1) * CF) + d1 = rem // CF + c = rem % CF + lidx = mExpand[d1] + abuf[b, (d1 + 1) * CF + c] = ( + zbuf[b, (d1 + 1) * CF + c] * s_w[SIG_OFF + b * GO + lidx * CF + c] + ) + j += T + cute.arch.sync_threads() + + @cute.jit + def _gated_bwd(self, zbuf, gabuf, s_w, mGW, mExpand, lyr, focus, tidx): + """Backprop the gated activation in place: ``zbuf`` (z_l) -> grad_z_l. + + With ``g_a = gabuf`` the incoming gradient and the recomputed gate + sigmoids ``sig``:: + + grad_z[dm, c] = g_a[dm, c] * sig[expand[dm-1], c] (dm > 0) + g_sig[L, c] = sum_{dm: expand[dm-1]=L} g_a[dm, c] * z[dm, c] + grad_z[0, i] = g_a[0, i] * silu'(z[0, i]) + + sum_{o'} Wg[i, o'] * g_sig * sig*(1-sig) + """ + CF = cutlass.const_expr(self.cf) + GO = cutlass.const_expr(self.gate_out) + Dm = cutlass.const_expr(self.Dm) + LMAX = cutlass.const_expr(self.lmax) + NG = cutlass.const_expr(self.ngroup) + B = cutlass.const_expr(self._B) + T = cutlass.const_expr(self._T) + SIG_OFF = cutlass.const_expr(self.cf * self.gate_out) + GGL_OFF = cutlass.const_expr(self.cf * self.gate_out + B * self.gate_out) + j = tidx + while j < CF * GO: + s_w[(j // GO) * GO + (j % GO)] = mGW[lyr, focus, j // GO, j % GO] + j += T + cute.arch.sync_threads() + j = tidx + while j < B * GO: + b = j // GO + o = j % GO + acc = cutlass.Float32(0.0) + for ii in cutlass.range_constexpr(CF): + acc += zbuf[b, ii] * s_w[ii * GO + o] + s_w[SIG_OFF + b * GO + o] = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cmath.exp(-acc) + ) + j += T + cute.arch.sync_threads() + # g_gl[L, c] = (sum over the (1 + 2*mmax) |m| groups) * sigmoid'(gate). + # For degree l = L + 1 the contributing coefficients are dm = 1 + L + k*lmax. + j = tidx + while j < B * GO: + b = j // GO + o = j % GO + gate_l = o // CF + c = o % CF + gsig = cutlass.Float32(0.0) + for kk in cutlass.range_constexpr(NG): + dm = 1 + gate_l + kk * LMAX + gsig += gabuf[b, dm * CF + c] * zbuf[b, dm * CF + c] + s = s_w[SIG_OFF + b * GO + o] + s_w[GGL_OFF + b * GO + o] = gsig * s * (cutlass.Float32(1.0) - s) + j += T + cute.arch.sync_threads() + # grad_z[dm>0]: reads sig, writes disjoint from the l=0 slice. + j = tidx + while j < B * (Dm - 1) * CF: + b = j // ((Dm - 1) * CF) + rem = j % ((Dm - 1) * CF) + d1 = rem // CF + c = rem % CF + lidx = mExpand[d1] + zbuf[b, (d1 + 1) * CF + c] = ( + gabuf[b, (d1 + 1) * CF + c] * s_w[SIG_OFF + b * GO + lidx * CF + c] + ) + j += T + # grad_z[0]: reads z[0, i] before overwriting it. + j = tidx + while j < B * CF: + b = j // CF + ii = j % CF + z0 = zbuf[b, ii] + sg = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cmath.exp(-z0)) + silup = sg + z0 * sg * (cutlass.Float32(1.0) - sg) + acc = gabuf[b, ii] * silup + for o in cutlass.range_constexpr(GO): + acc += s_w[ii * GO + o] * s_w[GGL_OFF + b * GO + o] + zbuf[b, ii] = acc + j += T + cute.arch.sync_threads() + + class BackwardRunner: + """Compile-once driver for :class:`BackwardProgram` (force-path gradients). + + Parameters + ---------- + weights + Packed weights (see :class:`.forward.ForwardRunner`). + lmax, mmax, cf, n_focus, n_layers, bucket, threads, rb, rn + Kernel configuration. + """ + + def __init__( + self, + weights, + *, + lmax: int, + mmax: int, + cf: int, + n_focus: int, + n_layers: int, + bucket: int, + threads: int, + rb: int, + rn: int, + ) -> None: + self.op = BackwardProgram( + lmax=lmax, + mmax=mmax, + cf=cf, + n_focus=n_focus, + n_layers=n_layers, + bucket=bucket, + threads=threads, + rb=rb, + rn=rn, + ) + self._B = bucket + self.nf, self.cf, self.Dm, self.D = n_focus, cf, self.op.Dm, self.op.D + self._compiled = None + self._stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + fr = ForwardRunner( + weights, + lmax=lmax, + mmax=mmax, + cf=cf, + n_focus=n_focus, + n_layers=n_layers, + bucket=bucket, + threads=threads, + rb=rb, + rn=rn, + ) + self.m_w, self.m_gw, self.m_cb, self.m_expand = ( + fr.m_w, + fr.m_gw, + fr.m_cb, + fr.m_expand, + ) + + @staticmethod + def _dyn(t: torch.Tensor, leading: int): + return from_dlpack(t, assumed_align=16).mark_layout_dynamic( + leading_dim=leading + ) + + def __call__( + self, + x: torch.Tensor, + src: torch.Tensor, + d_to_m: torch.Tensor, + kc: torch.Tensor, + g_out: torch.Tensor, + g_fgate: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Emit the value-path gradients. + + Parameters + ---------- + x, src, d_to_m, kc + Forward inputs (see :meth:`.forward.ForwardRunner.__call__`). + g_out : torch.Tensor + Gradient w.r.t. ``x_local`` with shape (E, F, D_m, Cf). + g_fgate : torch.Tensor + Gradient w.r.t. ``focus_gate`` with shape (E, F, Cf). + + Returns + ------- + grad_x : torch.Tensor + Node-feature gradient with shape (N, D, C_wide). + grad_d_to_m : torch.Tensor + Wigner-rotation gradient with shape (E, D_m, D). + grad_kc : torch.Tensor + Radial degree-kernel gradient with shape (E, D_m, D_m). + """ + n_edge = src.shape[0] + n_bucket = (n_edge + self._B - 1) // self._B + n_pad = n_bucket * self._B + s32 = src.to(torch.int32) + g_out_p, g_fg_p = g_out, g_fgate + if n_pad > n_edge: + s32 = torch.cat([s32, s32.new_zeros(n_pad - n_edge)]) + g_out_p = torch.cat( + [g_out, g_out.new_zeros(n_pad - n_edge, *g_out.shape[1:])] + ) + g_fg_p = torch.cat( + [g_fgate, g_fgate.new_zeros(n_pad - n_edge, *g_fgate.shape[1:])] + ) + grad_x = torch.zeros_like(x) + grad_d = torch.zeros( + n_pad, self.Dm, self.D, device=x.device, dtype=torch.float32 + ) + grad_kc = torch.zeros( + n_pad, self.Dm, self.Dm, device=x.device, dtype=torch.float32 + ) + views = ( + self._dyn(g_out_p, 3), + self._dyn(g_fg_p, 2), + self._dyn(x, 2), + self._dyn(s32, 0), + self._dyn(d_to_m, 2), + self._dyn(kc, 2), + from_dlpack(self.m_cb, assumed_align=16), + from_dlpack(self.m_w, assumed_align=16), + from_dlpack(self.m_gw, assumed_align=16), + from_dlpack(self.m_expand, assumed_align=16), + self._dyn(grad_x, 2), + self._dyn(grad_d, 2), + self._dyn(grad_kc, 2), + ) + args = (*views, cutlass.Int32(n_edge), cutlass.Int32(n_bucket)) + if self._compiled is None: + self._compiled = cute.compile(self.op, *args, stream=self._stream) + self._compiled(*args, stream=self._stream) + return grad_x, grad_d[:n_edge], grad_kc[:n_edge] diff --git a/deepmd/kernels/cute/sezm/forward.py b/deepmd/kernels/cute/sezm/forward.py new file mode 100644 index 0000000000..4490b3fe1b --- /dev/null +++ b/deepmd/kernels/cute/sezm/forward.py @@ -0,0 +1,444 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN201, ANN204, ANN205 +""" +CuTe-DSL fused forward kernel for the SeZM SO(2) value path. + +One bucketed kernel folds the entire per-edge value path of ``SO2Convolution`` +into a single launch:: + + gather x[src] -> rotate_to_local (D_to_m) [prologue, in smem] + -> radial degree mix (Kc, channel_basis) [prologue, in smem] + -> 3x (block-diagonal SO2Linear + GatedActivation + residual) [in smem] + -> x_local (E, F, D_m, Cf) [+ pre-mixing l=0 scalar] + +The block-diagonal ``SO2Linear`` weight is staged in shared memory once per +bucket and reused across the bucket's ``B`` edges (register-blocked ``RB x RN`` +FMA micro-tile), so the ``E x D_m x C`` intermediates of all three mixing layers +stay resident on chip and never reach DRAM. The focus competition (a per-edge +softmax of the pre-mixing ``l = 0`` feature) is applied outside the kernel from +the returned ``focus_gate`` scalar. + +Grid ``(n_bucket, n_focus)``; one CTA owns ``B`` edges of one focus stream. +All accumulation is fp32 IEEE (no TF32) to keep the potential-energy surface +smooth. +""" + +from __future__ import ( + annotations, +) + +import torch + +try: + import cuda.bindings.driver as cuda + import cutlass + import cutlass.cute as cute + import cutlass.cute.math as cmath + from cutlass.cute.runtime import ( + from_dlpack, + ) + + SEZM_CUTE_AVAILABLE = True +except Exception: # pragma: no cover - import guard for non-CuTe environments + SEZM_CUTE_AVAILABLE = False + + +if SEZM_CUTE_AVAILABLE: + + class ForwardProgram: + """Bucketed fused forward program for the SO(2) value path. + + Parameters + ---------- + lmax : int + Maximum spherical harmonic degree. + mmax : int + Maximum SO(2) order retained in the reduced layout. + cf : int + Per-focus channel width ``Cf``. + n_focus : int + Number of focus streams ``F``. + n_layers : int + Number of SO(2) mixing layers. + bucket : int + Edges processed per CTA ``B``. + threads : int + Threads per CTA. + rb, rn : int + Register micro-tile dimensions (``RB`` bucket rows, ``RN`` output + columns per thread) of the block-diagonal GEMM. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int, + cf: int, + n_focus: int, + n_layers: int, + bucket: int, + threads: int, + rb: int, + rn: int, + ) -> None: + self.lmax, self.mmax, self.cf, self.nf, self.nl = ( + lmax, + mmax, + cf, + n_focus, + n_layers, + ) + self._B, self._T, self._RB, self._RN = bucket, threads, rb, rn + self.D = (lmax + 1) ** 2 + self.Dm = (lmax + 1) + sum(2 * (lmax - m + 1) for m in range(1, mmax + 1)) + self.Cw = n_focus * cf + self.gate_out = lmax * cf + self.FLAT = self.Dm * cf + # Block-diagonal |m| block widths in the flattened coeff*channel axis: + # m = 0 spans (lmax + 1) coefficients, each |m| > 0 spans 2*(lmax-m+1). + groups = [lmax + 1] + [2 * (lmax - m + 1) for m in range(1, mmax + 1)] + self._blocks: list[tuple[int, int]] = [] + off = 0 + for g in groups: + self._blocks.append((off, g * cf)) + off += g * cf + self._max_sb = max(sb for _, sb in self._blocks) + # Shared scratch reused for the resident weight block and, during the + # gated activation, the gate weight plus per-edge sigmoid buffer. + self._scr = max( + self._max_sb * self._max_sb, cf * self.gate_out + bucket * self.gate_out + ) + assert self._B % rb == 0 + for _, sb in self._blocks: + assert sb % rn == 0 + + @cute.jit + def __call__( + self, + mX, + mSrc, + mDtoM, + mKc, + mCB, + mW, + mGW, + mExpand, + mOut, + mFocusGate, + n_edge: cutlass.Int32, + n_bucket: cutlass.Int32, + stream: cuda.CUstream, + ): + self.kernel( + mX, mSrc, mDtoM, mKc, mCB, mW, mGW, mExpand, mOut, mFocusGate, n_edge + ).launch(grid=[n_bucket, self.nf, 1], block=[self._T, 1, 1], stream=stream) + + @cute.kernel + def kernel( + self, + mX, + mSrc, + mDtoM, + mKc, + mCB, + mW, + mGW, + mExpand, + mOut, + mFocusGate, + n_edge: cutlass.Int32, + ): + D = cutlass.const_expr(self.D) + Dm = cutlass.const_expr(self.Dm) + CF = cutlass.const_expr(self.cf) + FLAT = cutlass.const_expr(self.FLAT) + GO = cutlass.const_expr(self.gate_out) + B = cutlass.const_expr(self._B) + T = cutlass.const_expr(self._T) + RB = cutlass.const_expr(self._RB) + RN = cutlass.const_expr(self._RN) + SCR = cutlass.const_expr(self._scr) + GATE_OFF = cutlass.const_expr(self.cf * self.gate_out) + NGATE = cutlass.const_expr(self.nl - 1) + + tidx, _, _ = cute.arch.thread_idx() + bucket, focus, _ = cute.arch.block_idx() + e0 = bucket * B + + smem = cutlass.utils.SmemAllocator() + s_cur = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((B, FLAT), stride=(FLAT, 1)), 16 + ) + s_tmp = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((B, FLAT), stride=(FLAT, 1)), 16 + ) + s_scr = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((SCR,), stride=(1,)), 16 + ) + + # === Step 1. rotate_to_local: s_tmp[b, dm*Cf+c] = sum_k D_to_m x[src] === + # Padding edges (e >= n_edge) clamp their read index; their output rows + # are sliced off by the caller. + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + dm = rem // CF + c = rem % CF + e = e0 + b + eq = e if e < n_edge else n_edge - 1 + src = mSrc[eq] + acc = cutlass.Float32(0.0) + for k in cutlass.range_constexpr(D): + acc += mDtoM[eq, dm, k] * mX[src, k, focus * CF + c] + s_tmp[b, rem] = acc + i += T + cute.arch.sync_threads() + + # === Step 2. radial degree mix: s_cur = channel_basis * (Kc @ x_rot) === + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + o = rem // CF + c = rem % CF + e = e0 + b + eq = e if e < n_edge else n_edge - 1 + acc = cutlass.Float32(0.0) + for ii in cutlass.range_constexpr(Dm): + acc += mKc[eq, o, ii] * s_tmp[b, ii * CF + c] + s_cur[b, rem] = acc * mCB[focus * CF + c] + i += T + cute.arch.sync_threads() + + # === Step 3. Emit the pre-mixing l=0 scalar for the focus competition === + i = tidx + while i < B * CF: + b = i // CF + c = i % CF + mFocusGate[e0 + b, focus, c] = s_cur[b, c] + i += T + cute.arch.sync_threads() + + # === Step 4. Multi-layer gated SO(2) mixing (block-diagonal, residual) === + for lyr in cutlass.range_constexpr(self.nl): + # --- SO2Linear: s_tmp = s_cur @ W[lyr, focus] (per |m| block) --- + for ob, sb in cutlass.const_expr(self._blocks): + j = tidx + while j < sb * sb: + k = j // sb + n = j % sb + s_scr[k * sb + n] = mW[lyr, focus, ob + k, ob + n] + j += T + cute.arch.sync_threads() + n_mt = cutlass.const_expr((B // RB) * (sb // RN)) + racc = cute.make_rmem_tensor( + cute.make_layout((RB, RN)), cutlass.Float32 + ) + mt = tidx + while mt < n_mt: + bi = (mt // (sb // RN)) * RB + nj = (mt % (sb // RN)) * RN + for r in cutlass.range_constexpr(RB): + for s in cutlass.range_constexpr(RN): + racc[r, s] = cutlass.Float32(0.0) + for k in cutlass.range(sb): + for r in cutlass.range_constexpr(RB): + a_rk = s_cur[bi + r, ob + k] + for s in cutlass.range_constexpr(RN): + racc[r, s] += a_rk * s_scr[k * sb + nj + s] + for r in cutlass.range_constexpr(RB): + for s in cutlass.range_constexpr(RN): + s_tmp[bi + r, ob + nj + s] = racc[r, s] + mt += T + cute.arch.sync_threads() + + # --- GatedActivation (gated layers) or identity (last layer) --- + if lyr < NGATE: + # gate FocusLinear weight -> s_scr[0 : Cf*GO] + j = tidx + while j < CF * GO: + ii = j // GO + o = j % GO + s_scr[ii * GO + o] = mGW[lyr, focus, ii, o] + j += T + cute.arch.sync_threads() + # gate sigmoids from the l=0 scalar -> s_scr[GATE_OFF + b*GO + o] + j = tidx + while j < B * GO: + b = j // GO + o = j % GO + acc = cutlass.Float32(0.0) + for ii in cutlass.range_constexpr(CF): + acc += s_tmp[b, ii] * s_scr[ii * GO + o] + s_scr[GATE_OFF + b * GO + o] = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cmath.exp(-acc) + ) + j += T + cute.arch.sync_threads() + # l=0: silu(z) = z / (1 + exp(-z)) + j = tidx + while j < B * CF: + b = j // CF + c = j % CF + z = s_tmp[b, c] + s_tmp[b, c] = z / (cutlass.Float32(1.0) + cmath.exp(-z)) + j += T + # l>0: z * sigmoid(gate[expand[dm-1]]) + j = tidx + while j < B * (Dm - 1) * CF: + b = j // ((Dm - 1) * CF) + rem = j % ((Dm - 1) * CF) + d1 = rem // CF + c = rem % CF + lidx = mExpand[d1] + z = s_tmp[b, (d1 + 1) * CF + c] + s_tmp[b, (d1 + 1) * CF + c] = ( + z * s_scr[GATE_OFF + b * GO + lidx * CF + c] + ) + j += T + cute.arch.sync_threads() + + # --- residual add: s_cur += activation(s_tmp) --- + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + s_cur[b, rem] = s_cur[b, rem] + s_tmp[b, rem] + i += T + cute.arch.sync_threads() + + # === Step 5. Write the pre-focus-compete local features (E, F, D_m, Cf) === + i = tidx + while i < B * FLAT: + b = i // FLAT + rem = i % FLAT + o = rem // CF + c = rem % CF + mOut[e0 + b, focus, o, c] = s_cur[b, rem] + i += T + + class ForwardRunner: + """Compile-once driver for :class:`ForwardProgram`. + + Prepares the static packed weights on construction and dispatches the + compiled kernel over dynamic edge counts. + + Parameters + ---------- + weights + Packed weights exposing ``so2_w`` (L, F, D_m*Cf, D_m*Cf), + ``gate_w`` (L, Cf, F, lmax*Cf), ``has_gate`` (L,), and + ``channel_basis`` (C_wide,). + lmax, mmax, cf, n_focus, n_layers, bucket, threads, rb, rn + Kernel configuration (see :class:`ForwardProgram`). + """ + + def __init__( + self, + weights, + *, + lmax: int, + mmax: int, + cf: int, + n_focus: int, + n_layers: int, + bucket: int, + threads: int, + rb: int, + rn: int, + ) -> None: + self.op = ForwardProgram( + lmax=lmax, + mmax=mmax, + cf=cf, + n_focus=n_focus, + n_layers=n_layers, + bucket=bucket, + threads=threads, + rb=rb, + rn=rn, + ) + self._B = bucket + self.nf, self.cf, self.Dm = n_focus, cf, self.op.Dm + self._compiled = None + self._stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + self._pack(weights, lmax, cf, n_focus, n_layers) + + def _pack(self, w, lmax: int, cf: int, nf: int, nl: int) -> None: + dev = w.so2_w.device + self.m_w = w.so2_w.detach().contiguous() + gate = torch.zeros(nl, nf, cf, lmax * cf, device=dev, dtype=torch.float32) + for layer in range(nl): + if bool(w.has_gate[layer]): + gate[layer] = w.gate_w[layer].detach().permute(1, 0, 2).contiguous() + self.m_gw = gate.contiguous() + self.m_cb = w.channel_basis.detach().contiguous() + # m-major degree index l(dm); the gate expand maps dm>0 -> (l-1). + l_index = list(range(lmax + 1)) + for m in range(1, self.op.mmax + 1): + l_index += list(range(m, lmax + 1)) * 2 + self.m_expand = torch.tensor( + [li - 1 for li in l_index[1:]], device=dev, dtype=torch.int32 + ).contiguous() + + @staticmethod + def _dyn(t: torch.Tensor, leading: int): + return from_dlpack(t, assumed_align=16).mark_layout_dynamic( + leading_dim=leading + ) + + def __call__( + self, + x: torch.Tensor, + src: torch.Tensor, + d_to_m: torch.Tensor, + kc: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Run the fused forward. + + Parameters + ---------- + x : torch.Tensor + Node features with shape (N, D, C_wide). + src : torch.Tensor + Per-edge source-node indices with shape (E,). + d_to_m : torch.Tensor + Row-projected Wigner-D with shape (E, D_m, D). + kc : torch.Tensor + Radial degree kernel with shape (E, D_m, D_m). + + Returns + ------- + x_local : torch.Tensor + Pre-focus-compete local features with shape (E, F, D_m, Cf). + focus_gate : torch.Tensor + Pre-mixing l=0 scalar with shape (E, F, Cf). + """ + n_edge = src.shape[0] + n_bucket = (n_edge + self._B - 1) // self._B + n_pad = n_bucket * self._B + s32 = src.to(torch.int32) + if n_pad > n_edge: + s32 = torch.cat([s32, s32.new_zeros(n_pad - n_edge)]) + out = x.new_empty(n_pad, self.nf, self.Dm, self.cf) + fgate = x.new_empty(n_pad, self.nf, self.cf) + views = ( + self._dyn(x, 2), + self._dyn(s32, 0), + self._dyn(d_to_m, 2), + self._dyn(kc, 2), + from_dlpack(self.m_cb, assumed_align=16), + from_dlpack(self.m_w, assumed_align=16), + from_dlpack(self.m_gw, assumed_align=16), + from_dlpack(self.m_expand, assumed_align=16), + self._dyn(out, 3), + self._dyn(fgate, 2), + ) + args = (*views, cutlass.Int32(n_edge), cutlass.Int32(n_bucket)) + if self._compiled is None: + self._compiled = cute.compile(self.op, *args, stream=self._stream) + self._compiled(*args, stream=self._stream) + return out[:n_edge], fgate[:n_edge] diff --git a/deepmd/kernels/cute/sezm/operator.py b/deepmd/kernels/cute/sezm/operator.py new file mode 100644 index 0000000000..d3efd30781 --- /dev/null +++ b/deepmd/kernels/cute/sezm/operator.py @@ -0,0 +1,330 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Autograd operator and entry point wiring the fused CuTe kernels into the SeZM +SO(2) convolution value path. + +:class:`_SO2ValuePathFunction` runs the fused forward kernel (saving only the +small inputs ``x``, ``D_to_m``, ``Kc``) and, on the force path, recomputes the +value path in the fused backward kernel -- so the per-edge ``E x D_m x C`` +intermediates stay off DRAM across the whole autograd graph. :func:`make_cute_value_path` +builds the per-convolution entry :class:`_CuteSO2ValuePath`, which computes the +radial-/scalar-only tensors (``D_to_m``, ``Kc``, ``rad_feat``) in ordinary +autograd, invokes the operator to produce the pre-focus-compete local features +and the pre-mixing ``l = 0`` scalar, and applies the focus competition -- exactly +as the reference path does. The packed weights are extracted lazily on the first +call so they reflect the loaded checkpoint. +""" + +from __future__ import ( + annotations, +) + +from dataclasses import ( + dataclass, +) +from typing import ( + TYPE_CHECKING, +) + +import torch + +from deepmd.pt.model.descriptor.sezm_nn.indexing import ( + project_D_to_m, +) + +from .forward import ( + SEZM_CUTE_AVAILABLE, +) + +if TYPE_CHECKING: + from deepmd.pt.model.descriptor.sezm_nn.edge_cache import ( + EdgeFeatureCache, + ) + from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + SO2Convolution, + ) + +if SEZM_CUTE_AVAILABLE: + from .backward import ( + BackwardRunner, + ) + from .forward import ( + ForwardRunner, + ) + +# Validated configuration of the fused operator: the deployed three-layer +# ``[gated, gated, identity]`` mixing stack in the ``lmax = 3, mmax = 1`` layout. +_SUPPORTED_LMAX = 3 +_SUPPORTED_MMAX = 1 +_SUPPORTED_LAYERS = 3 + + +@dataclass +class _PackedWeights: + """Static weights of the SO(2) value path, packed for the fused kernels. + + Attributes + ---------- + so2_w : torch.Tensor + Assembled block-diagonal SO2Linear weight per layer with shape + (L, F, D_m*Cf, D_m*Cf), in ``(in, out)`` convention. + gate_w : torch.Tensor + GatedActivation ``FocusLinear`` weight per layer with shape + (L, Cf, F, lmax*Cf); zero for non-gated layers. + has_gate : torch.Tensor + Boolean per-layer flag with shape (L,). + channel_basis : torch.Tensor + Radial degree-mixer channel basis with shape (C_wide,). + """ + + so2_w: torch.Tensor + gate_w: torch.Tensor + has_gate: torch.Tensor + channel_basis: torch.Tensor + + +def _pack_weights(conv: SO2Convolution) -> _PackedWeights: + """Extract and pack the SO(2) value-path weights from a convolution block.""" + n_layers = conv.mixing_layers + so2_w = torch.stack( + [ + conv.so2_linears[layer]._build_so2_weight().permute(1, 0, 2).contiguous() + for layer in range(n_layers) + ] + ) + gate_w, has_gate = [], [] + for layer in range(n_layers): + non_linear = conv.non_linearities[layer] + if type(non_linear).__name__ == "GatedActivation" and non_linear.lmax > 0: + gate_w.append( + non_linear.gate_linear.weight.view( + conv.so2_focus_dim, conv.n_focus, conv.lmax * conv.so2_focus_dim + ).contiguous() + ) + has_gate.append(True) + else: + gate_w.append( + torch.zeros_like(gate_w[0]) + if gate_w + else torch.zeros( + conv.so2_focus_dim, + conv.n_focus, + conv.lmax * conv.so2_focus_dim, + device=so2_w.device, + dtype=so2_w.dtype, + ) + ) + has_gate.append(False) + return _PackedWeights( + so2_w=so2_w, + gate_w=torch.stack(gate_w), + has_gate=torch.tensor(has_gate, device=so2_w.device, dtype=torch.bool), + channel_basis=conv.radial_degree_mixer.channel_basis.reshape(-1).contiguous(), + ) + + +class _SO2ValuePathFunction(torch.autograd.Function): + """Fused CuTe forward with a recompute backward for the force path.""" + + @staticmethod + def forward(ctx, x, d_to_m, kc, src, fwd_runner, bwd_runner): # noqa: ANN001, ANN205 + with torch.no_grad(): + x_local, focus_gate = fwd_runner( + x.detach(), src, d_to_m.detach(), kc.detach() + ) + ctx.save_for_backward(x, d_to_m, kc) + ctx.src = src + ctx.bwd_runner = bwd_runner + return x_local, focus_gate + + @staticmethod + def backward(ctx, grad_local, grad_focus_gate): # noqa: ANN001, ANN205 + x, d_to_m, kc = ctx.saved_tensors + need = ctx.needs_input_grad + grad_x, grad_d_to_m, grad_kc = ctx.bwd_runner( + x.detach(), + ctx.src, + d_to_m.detach(), + kc.detach(), + grad_local.detach().contiguous(), + grad_focus_gate.detach().contiguous(), + ) + return ( + grad_x if need[0] else None, + grad_d_to_m if need[1] else None, + grad_kc if need[2] else None, + None, + None, + None, + ) + + +class _CuteSO2ValuePath: + """Per-convolution entry that runs the value path through the fused kernels. + + The convolution is held by reference so the packed weights are extracted + lazily on the first call (after the checkpoint is loaded) and the kernels are + compiled on first use. + + Parameters + ---------- + conv : SO2Convolution + The owning convolution block. + bucket_fwd, threads_fwd, bucket_bwd, threads_bwd, rb, rn : int + Fused-kernel launch configuration. + """ + + def __init__( + self, + conv: SO2Convolution, + *, + bucket_fwd: int = 32, + threads_fwd: int = 1024, + bucket_bwd: int = 16, + threads_bwd: int = 512, + rb: int = 4, + rn: int = 4, + ) -> None: + self._conv = conv + self._cfg = { + "lmax": conv.lmax, + "mmax": conv.mmax, + "cf": conv.so2_focus_dim, + "n_focus": conv.n_focus, + "n_layers": conv.mixing_layers, + "rb": rb, + "rn": rn, + } + self._launch = { + "bucket_fwd": bucket_fwd, + "threads_fwd": threads_fwd, + "bucket_bwd": bucket_bwd, + "threads_bwd": threads_bwd, + } + self._fwd_runner = None + self._bwd_runner = None + + def _build(self) -> None: + weights = _pack_weights(self._conv) + self._fwd_runner = ForwardRunner( + weights, + bucket=self._launch["bucket_fwd"], + threads=self._launch["threads_fwd"], + **self._cfg, + ) + self._bwd_runner = BackwardRunner( + weights, + bucket=self._launch["bucket_bwd"], + threads=self._launch["threads_bwd"], + **self._cfg, + ) + + def __call__( + self, + x: torch.Tensor, + edge_cache: EdgeFeatureCache, + radial_feat: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute the SO(2) local features and radial features via the fused op. + + Parameters + ---------- + x : torch.Tensor + Node features with shape (N, D, C_wide). + edge_cache : EdgeFeatureCache + Precomputed edge cache (provides ``src`` and the Wigner ``D_full``). + radial_feat : torch.Tensor + Per-edge radial features with shape (E, lmax+1, C). + + Returns + ------- + x_local : torch.Tensor + Post-focus-compete local features with shape (E, F, D_m, Cf). + rad_feat : torch.Tensor + Projected radial features with shape (E, D_m, C_wide); its ``l = 0`` + slice is consumed by the attention aggregation. + """ + if self._fwd_runner is None: + self._build() + conv = self._conv + src = edge_cache.src + + # === Step 1. Radial-/scalar-only tensors (kept in ordinary autograd) === + d_to_m = project_D_to_m( + edge_cache.D_full, + conv.coeff_index_m, + conv.ebed_dim_full, + None, + conv.lmax, + conv.mmax, + ) + rad_feat = radial_feat[:, conv.degree_index_m, :] + rad_feat = conv.radial_hidden_proj(rad_feat) + mixer = conv.radial_degree_mixer + kernel_flat = mixer._project_radial(rad_feat) + compact = kernel_flat.view(src.shape[0], mixer.degree_kernel_size, mixer.rank) + kc = mixer._scatter_rank_kernel(compact).squeeze(-1) + + # === Step 2. Fused value path -> pre-focus-compete local + l=0 scalar === + x_local, focus_gate = _SO2ValuePathFunction.apply( + x, d_to_m, kc, src, self._fwd_runner, self._bwd_runner + ) + + # === Step 3. Cross-focus softmax competition (rotation-free scalars) === + if conv.focus_compete and conv.n_focus > 1: + alpha = conv._focus_alpha(focus_gate) + x_local = x_local * alpha.to(dtype=x_local.dtype).unsqueeze(-1).unsqueeze( + -1 + ) + return x_local, rad_feat + + +def _is_supported(conv: SO2Convolution) -> bool: + """Return whether ``conv`` matches the validated fused-operator configuration.""" + if ( + conv.lmax != _SUPPORTED_LMAX + or conv.mmax != _SUPPORTED_MMAX + or conv.mixing_layers != _SUPPORTED_LAYERS + or conv.node_wise_grid_product is not None + or conv.use_so2_attn_res + or conv.layer_scale + or conv.radial_degree_mixer is None + or conv.radial_hidden_proj is None + ): + return False + if any(type(norm).__name__ != "Identity" for norm in conv.so2_inter_norms): + return False + if any(linear.bias0 is not None for linear in conv.so2_linears): + return False + non_linears = conv.non_linearities + if any( + type(non_linears[layer]).__name__ != "GatedActivation" + or ( + getattr(non_linears[layer].scalar_act, "activation", None) + or getattr(non_linears[layer], "activation_function", None) + ) + != "silu" + for layer in range(_SUPPORTED_LAYERS - 1) + ): + return False + return type(non_linears[-1]).__name__ == "Identity" + + +def make_cute_value_path(conv: SO2Convolution) -> _CuteSO2ValuePath | None: + """Build the fused CuTe value-path entry for a convolution block. + + Parameters + ---------- + conv : SO2Convolution + The convolution block to accelerate. + + Returns + ------- + _CuteSO2ValuePath or None + The entry callable when the CuTe backend is available and ``conv`` matches + the validated configuration; otherwise ``None`` (the caller falls back to + the reference path). + """ + if not SEZM_CUTE_AVAILABLE or not _is_supported(conv): + return None + return _CuteSO2ValuePath(conv) diff --git a/deepmd/kernels/triton/__init__.py b/deepmd/kernels/triton/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/kernels/triton/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/kernels/triton/sezm/__init__.py b/deepmd/kernels/triton/sezm/__init__.py new file mode 100644 index 0000000000..a0c3b2c7a1 --- /dev/null +++ b/deepmd/kernels/triton/sezm/__init__.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Hardware-accelerated SeZM/DPA4 operators. + +This package hosts ``make_fx``-composable Triton implementations of SeZM hot +paths. Kernel entry points are internal implementation details of the SeZM +descriptor; the package-level API only exposes availability. +""" + +from .force_assembly import ( + FORCE_ASSEMBLY_TRITON_AVAILABLE, +) +from .radial_mix import ( + RADIAL_MIX_TRITON_AVAILABLE, +) +from .so2_block_gemm import ( + SO2_BLOCK_GEMM_TRITON_AVAILABLE, +) +from .so2_rotation import ( + TRITON_ROTATION_AVAILABLE, +) +from .so2_stack_fp16x3 import ( + STACK_FP16X3_TRITON_AVAILABLE, +) +from .so2_value_path import ( + SO2_VALUE_PATH_TRITON_AVAILABLE, +) +from .wigner_monomials import ( + WIGNER_MONOMIALS_TRITON_AVAILABLE, +) + +# Every kernel module guards its ``@triton.jit`` definitions behind a ``triton`` +# import, so the module-level checks are equivalent. Expose a single +# package-level availability flag. +TRITON_AVAILABLE = ( + TRITON_ROTATION_AVAILABLE + and RADIAL_MIX_TRITON_AVAILABLE + and SO2_BLOCK_GEMM_TRITON_AVAILABLE + and SO2_VALUE_PATH_TRITON_AVAILABLE + and STACK_FP16X3_TRITON_AVAILABLE + and WIGNER_MONOMIALS_TRITON_AVAILABLE + and FORCE_ASSEMBLY_TRITON_AVAILABLE +) + +__all__ = [ + "TRITON_AVAILABLE", +] diff --git a/deepmd/kernels/triton/sezm/flash_atten.py b/deepmd/kernels/triton/sezm/flash_atten.py new file mode 100644 index 0000000000..13385cf860 --- /dev/null +++ b/deepmd/kernels/triton/sezm/flash_atten.py @@ -0,0 +1,1069 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN202, RUF005 +"""Fused flash-attention edge->node aggregation for the SeZM/DPA4 SO(2) attention. + +This module fuses the *entire* value-aggregation of the ``n_atten_head > 0`` +branch of :class:`SO2Convolution` into a single destination-segmented Triton +kernel, for the shipped ``mmax == 1`` block-diagonal Wigner-D layout with +``atten_f_mix == atten_v_proj == atten_o_proj == False`` (the deployed DPA4 +checkpoint). + +Operation +--------- +The eager attention aggregation is (per destination atom ``n``, degree row +``d`` of the packed ``(l, m)`` layout, hidden channel ``c = f * Cf + cf``):: + + out[n, d, c] = gate[n, f, h] * sum_{e: dst[e]=n} + alpha[e, f, h] * rescale[d] * RotBack_e(x_local)[d, c] + +with ``h = cf // head_dim`` the attention head of channel ``c`` and + + RotBack_e(x_local)[d, c] = sum_m Dt[e, d, l^2+l+m] * x_local[e, f, (l,m), cf] + +the block-diagonal ``local -> global`` rotation (block ``l`` of the transposed +Wigner-D). ``alpha`` is the destination-wise envelope-gated softmax weight (a +scalar per ``(edge, focus, head)``) and ``gate`` is the output-side head gate (a +scalar per ``(node, focus, head)``). + +Because the softmax weight ``alpha`` and the output gate ``gate`` are cheap, +memory-negligible ``(E|N, F, H)`` scalars, only the *heavy* tensor work is fused +into the kernel: the block-diagonal ``rotate_back`` of the value, the per-edge +softmax weighting, the inverse-rotation rescale, and the destination-segmented +reduction. This kernel therefore computes the ungated aggregate + + pre_gate[n, d, c] = rescale[d] * sum_{e: dst[e]=n} alpha[e, f, h] * RotBack_e[d, c] + +and the caller applies the node-level ``out = pre_gate * gate`` afterwards (its +backward is handled by autograd). This "two-step" split (a scalar segmented +softmax outside, the weighted rotate-back segment reduction fused inside) is +chosen over folding the online softmax into this kernel because the softmax +operates on scalar logits with no bandwidth cost, while the backward of the +fully-flashed variant would have to recompute or save the per-destination +softmax statistics -- for zero memory benefit. The chosen split keeps the heavy +kernel *linear* in ``(x_local, Dt, alpha)``, so the hand-written backward is +exact, saves no forward activation, and stays cheap. + +The single fused forward removes the two largest transient edge tensors of the +eager path -- the rotate-back message ``x_message`` (E, D, C_wide) and the +``alpha``-weighted value ``weighted_value`` (E, D, C_wide) -- and the +``index_add`` round trip, which is the source of the end-to-end peak-memory +reduction. + +Forward layout +-------------- +One Triton program per destination node, reducing its edge segment through a +destination-sorted CSR topology (``argsort`` + ``searchsorted`` built inside +the op; the traced edge list carries masked padding edges in arbitrary +destination order, so no sortedness invariant exists at this level): each +edge's block-diagonal ``rotate_back`` is assembled from the three retained +orders using per-degree register vectors (every reduced order is read exactly +once, no redundant gather), weighted by ``alpha``, and accumulated into a +``DIM``-row register tuple; the rescale is applied once per row at the final +store. The contention-free CSR reduction is both faster than a per-edge +atomic scatter (which serializes on the colliding edges of each atom at +typical neighbor counts) and deterministic. + +Backward layout +--------------- +One Triton program per edge (no cross-edge accumulation, hence no atomics): it +reloads ``grad_pre_gate`` at the edge's destination, recomputes ``rotate_back`` +from ``x_local`` and ``Dt``, and emits the exact per-edge gradients w.r.t. +``x_local`` (E, F, D_m, Cf), ``Dt`` (E, D, D; structural block-diagonal +non-zeros only, matching the shipped rotation kernels) and ``alpha`` (E, F, H). + +Registration +------------ +Forward and backward are functional ``torch.library.triton_op`` instances +(``mutates_args=()``) with registered fake kernels and an autograd formula whose +backward is itself a ``triton_op``. ``triton_op`` + ``wrap_triton`` (rather than +an opaque ``custom_op``) lets Inductor see through to the Triton kernels and bake +the cubins into the SeZM ``.pt2``, so the force graph +(``autograd.grad(energy, edge_vec)``) traces under ``make_fx`` and runs inside +the LAMMPS C++ runtime with no Python op registration. ``row_ptr`` / ``dst`` are +integer topology derived from the neighbor list (never the coordinates), so they +carry no gradient; ``rescale`` is a constant buffer and is likewise not +differentiated. +""" + +from __future__ import ( + annotations, +) + +import torch +from torch import ( + Tensor, +) +from torch.library import ( + wrap_triton, +) + +from deepmd.pt.model.descriptor.sezm_nn.indexing import ( + build_m_major_index, +) + +from .tile_configs import ( + flash_bwd_block_config, +) + +__all__ = [ + "FLASH_ATTEN_TRITON_AVAILABLE", + "build_row_ptr", + "flash_atten_aggregate", + "flash_atten_aggregate_reference", +] + +try: + import triton + import triton.language as tl + + FLASH_ATTEN_TRITON_AVAILABLE = True +except ImportError: # pragma: no cover - exercised only without triton + FLASH_ATTEN_TRITON_AVAILABLE = False + + +# ====================================================================== +# CSR row-pointer + per-row degree-map construction (integer, gradient-free) +# ====================================================================== +def build_row_ptr(dst_sorted: Tensor, n_nodes) -> Tensor: + """Build CSR row offsets ``(N + 1,)`` from an ascending destination index. + + ``searchsorted`` on the sorted destinations is the traceable, allocation-light + way to obtain segment boundaries; it lowers cleanly under ``make_fx`` and + needs no data-dependent control flow. ``n_nodes`` may be a ``SymInt``. + """ + boundaries = torch.arange( + n_nodes + 1, device=dst_sorted.device, dtype=dst_sorted.dtype + ) + return torch.searchsorted(dst_sorted, boundaries).to(torch.int64) + + +# ====================================================================== +# Eager reference / fallback implementation +# ====================================================================== +def flash_atten_aggregate_reference( + x_local: Tensor, + wigner_dt: Tensor, + rescale: Tensor, + alpha: Tensor, + dst: Tensor, + n_nodes: int, + lmax: int, + n_head: int, +) -> Tensor: + """Eager ground truth for :func:`flash_atten_aggregate` (block-diagonal). + + Parameters + ---------- + x_local : Tensor + Per-focus SO(2) features with shape ``(E, F, D_m, Cf)`` in the m-major + ``mmax == 1`` reduced layout, ``D_m = 3 * lmax + 1``. + wigner_dt : Tensor + Transposed block-diagonal Wigner-D with shape ``(E, D, D)``, + ``D = (lmax + 1) ** 2``. + rescale : Tensor + Inverse-rotation degree rescale aligned with the packed layout, ``(D,)``. + alpha : Tensor + Envelope-gated softmax weight with shape ``(E, F, H)``. + dst : Tensor + Destination node indices with shape ``(E,)``. + n_nodes : int + Number of destination nodes ``N``. + lmax : int + Maximum degree. + n_head : int + Number of attention heads ``H``. + + Returns + ------- + Tensor + Ungated aggregate ``pre_gate`` with shape ``(N, D, C_wide)``, + ``C_wide = F * Cf``. + """ + n_edge, n_focus, reduced_dim, focus_dim = x_local.shape + dim = (int(lmax) + 1) ** 2 + c_wide = n_focus * focus_dim + head_dim = focus_dim // int(n_head) + coeff = build_m_major_index(int(lmax), 1, device=x_local.device) + + xl_std = x_local.transpose(1, 2).reshape(n_edge, reduced_dim, c_wide) + dt_from_m = wigner_dt[:, :dim, :dim].index_select(2, coeff) # (E, D, D_m) + # Cast the constant fp64 ``rescale`` to the feature dtype so the reduction + # stays in the caller's compute precision (no-op for fp64). + resc = rescale.view(1, dim, 1).to(x_local.dtype) + rb = torch.bmm(dt_from_m, xl_std) * resc # (E, D, C_wide) + # alpha (E, F, H) -> per-channel weight (E, C_wide) with c = f*Cf + h*head_dim + ch + alpha_full = alpha.repeat_interleave(head_dim, dim=2).reshape(n_edge, c_wide) + weighted = rb * alpha_full[:, None, :] + out = x_local.new_zeros(n_nodes, dim, c_wide) + out.index_add_(0, dst, weighted) + return out + + +def _flash_atten_backward_reference( + grad_pre_gate: Tensor, + x_local: Tensor, + wigner_dt: Tensor, + rescale: Tensor, + alpha: Tensor, + dst: Tensor, + lmax: int, + n_head: int, +) -> tuple[Tensor, Tensor, Tensor]: + """Closed-form eager backward of :func:`flash_atten_aggregate_reference`. + + A closed form (not a nested ``autograd.grad``) is required because the + backward operator carries no autograd formula and is dispatched under + ``_AutoDispatchBelowAutograd`` when the SeZM ``.pt2`` force graph is replayed + under :func:`torch.no_grad`, matching the discipline in ``radial_mix.py``. + + Returns ``(grad_x_local, grad_wigner, grad_alpha)``. + """ + n_edge, n_focus, reduced_dim, focus_dim = x_local.shape + dim = (int(lmax) + 1) ** 2 + c_wide = n_focus * focus_dim + head_dim = focus_dim // int(n_head) + coeff = build_m_major_index(int(lmax), 1, device=x_local.device) + + xl_std = x_local.transpose(1, 2).reshape(n_edge, reduced_dim, c_wide) + dt_from_m = wigner_dt[:, :dim, :dim].index_select(2, coeff) # (E, D, D_m) + rb_pre = torch.bmm(dt_from_m, xl_std) # (E, D, C_wide) + # Cast ``rescale`` to the feature dtype (see the forward reference). + resc = rescale.view(1, dim, 1).to(x_local.dtype) + rb = rb_pre * resc + alpha_full = alpha.repeat_interleave(head_dim, dim=2).reshape(n_edge, c_wide) + + grad_weighted = grad_pre_gate.index_select(0, dst) # (E, D, C_wide) + + # grad w.r.t. alpha: sum over degree rows and head channels of grad*rb. + grad_alpha_full = (grad_weighted * rb).sum(dim=1) # (E, C_wide) + grad_alpha = grad_alpha_full.reshape(n_edge, n_focus, int(n_head), head_dim).sum( + dim=3 + ) # (E, F, H) + + # grad w.r.t. the rotate-back message, then split into x_local and Dt grads. + grad_rb_pre = grad_weighted * alpha_full[:, None, :] * resc # (E, D, C_wide) + grad_xl_std = torch.bmm(dt_from_m.transpose(1, 2), grad_rb_pre) # (E, D_m, C_wide) + grad_x_local = grad_xl_std.reshape( + n_edge, reduced_dim, n_focus, focus_dim + ).transpose(1, 2) # (E, F, D_m, Cf) + + grad_dt_from_m = torch.bmm(grad_rb_pre, xl_std.transpose(1, 2)) # (E, D, D_m) + grad_block = wigner_dt.new_zeros(n_edge, dim, dim) + grad_block.index_copy_(2, coeff, grad_dt_from_m) + grad_wigner = torch.zeros_like(wigner_dt) + grad_wigner[:, :dim, :dim] = grad_block + return grad_x_local.contiguous(), grad_wigner, grad_alpha + + +# ====================================================================== +# Triton kernels (mmax == 1; LMAX / layout are constexpr; channels vectorized) +# ====================================================================== +if FLASH_ATTEN_TRITON_AVAILABLE: + # The segmented forward carries a DIM-row register accumulator per + # program, so low warp counts dominate; higher counts only pay off for + # wide channel tiles. + _FWD_CONFIGS = [ + triton.Config({}, num_warps=1, num_stages=1), + triton.Config({}, num_warps=1, num_stages=2), + triton.Config({}, num_warps=2, num_stages=1), + triton.Config({}, num_warps=2, num_stages=2), + triton.Config({}, num_warps=4, num_stages=2), + ] + _BWD_CONFIGS = [ + triton.Config({}, num_warps=1, num_stages=1), + triton.Config({}, num_warps=2, num_stages=1), + triton.Config({}, num_warps=4, num_stages=1), + triton.Config({}, num_warps=2, num_stages=2), + triton.Config({}, num_warps=4, num_stages=2), + ] + + @triton.autotune(configs=_FWD_CONFIGS, key=["C_wide"]) + @triton.jit + def _flash_fwd_kernel( + xl_ptr, + dt_ptr, + resc_ptr, + w_ptr, + order_ptr, + row_ptr_ptr, + out_ptr, + n_node, + C_wide, + xl_se, + xl_sf, + xl_sr, + xl_sc, + dt_se, + dt_sr, + dt_sk, + w_se, + w_sf, + w_sh, + o_sn, + o_sd, + o_sc, + LMAX: tl.constexpr, + CF: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + """One program per node: indirect CSR segment reduction of the rotate-back. + + ``order`` lists edge ids sorted by destination and ``row_ptr`` holds + the segment offsets, so program ``n`` reduces the edges + ``order[row_ptr[n]..row_ptr[n+1]]`` -- an indirect, atomic-free and + deterministic destination reduction that replaces the per-edge atomic + scatter and accepts any edge order (the compiled SeZM graph keeps + masked padding edges, so no sortedness invariant exists). Per edge, + each retained reduced order is read exactly once and neither the + rotate-back message nor the weighted value is materialized to DRAM. + Channels are the vectorized axis; ``c = f * Cf + cf`` decodes the + per-focus ``x_local`` layout in place and the attention head is + ``h = cf // head_dim``. The ``DIM``-row accumulator lives in a + loop-carried register tuple, so the rescale is applied once per row + at the final store. + """ + DIM: tl.constexpr = (LMAX + 1) * (LMAX + 1) + + node = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < C_wide + beg = tl.load(row_ptr_ptr + node).to(tl.int64) + end = tl.load(row_ptr_ptr + node + 1).to(tl.int64) + + # Channel decode c = f * Cf + cf, head h = cf // head_dim. Masked + # lanes clamp the focus index so pointer arithmetic stays in range. + fv = tl.where(cmask, chan // CF, 0) + cfv = chan % CF + hv = cfv // HEAD_DIM + xl_co = fv * xl_sf + cfv * xl_sc # per-channel focus offset into x_local + w_col = fv * w_sf + hv * w_sh # per-channel (focus, head) offset into alpha + + acc = () + for _ in tl.static_range(DIM): + acc = acc + (tl.zeros((BLOCK_C,), dtype=tl.float32),) + + for i in range(beg, end): + edge = tl.load(order_ptr + i).to(tl.int64) + wv = tl.load(w_ptr + edge * w_se + w_col, mask=cmask, other=0.0).to( + tl.float32 + ) + new_acc = () + for l in tl.static_range(0, LMAX + 1): + base = l * l + r0 = base + l # packed column of order m=0 + xl0 = tl.load( + xl_ptr + edge * xl_se + l * xl_sr + xl_co, mask=cmask, other=0.0 + ).to(tl.float32) + if l >= 1: + xlm = tl.load( + xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + xl_co, + mask=cmask, + other=0.0, + ).to(tl.float32) + xlp = tl.load( + xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + xl_co, + mask=cmask, + other=0.0, + ).to(tl.float32) + for j in tl.static_range(0, 2 * l + 1): + d = base + j # full packed output row + rb = ( + tl.load(dt_ptr + edge * dt_se + d * dt_sr + r0 * dt_sk).to( + tl.float32 + ) + * xl0 + ) + if l >= 1: + rb += ( + tl.load( + dt_ptr + edge * dt_se + d * dt_sr + (r0 - 1) * dt_sk + ).to(tl.float32) + * xlm + ) + rb += ( + tl.load( + dt_ptr + edge * dt_se + d * dt_sr + (r0 + 1) * dt_sk + ).to(tl.float32) + * xlp + ) + # Loop-carried tuples require inline constexpr subscripts + # (the Triton frontend rejects composite index variables). + new_acc = new_acc + (acc[l * l + j] + rb * wv,) + acc = new_acc + + for d in tl.static_range(DIM): + resc = tl.load(resc_ptr + d).to(tl.float32) + tl.store( + out_ptr + node * o_sn + d * o_sd + chan * o_sc, + acc[d] * resc, + mask=cmask, + ) + + @triton.autotune(configs=_BWD_CONFIGS, key=["C_wide"]) + @triton.jit + def _flash_bwd_kernel( + gp_ptr, + xl_ptr, + dt_ptr, + resc_ptr, + w_ptr, + dst_ptr, + gxl_ptr, + gdt_ptr, + gw_ptr, + n_edge, + C_wide, + gp_sn, + gp_sd, + gp_sc, + xl_se, + xl_sf, + xl_sr, + xl_sc, + dt_se, + dt_sr, + dt_sk, + w_se, + w_sf, + w_sh, + gxl_se, + gxl_sf, + gxl_sr, + gxl_sc, + gdt_se, + gdt_sr, + gdt_sk, + gw_se, + gw_sf, + gw_sh, + LMAX: tl.constexpr, + CF: tl.constexpr, + HEAD_DIM: tl.constexpr, + NFOCUS: tl.constexpr, + NHEAD: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + """One program per edge: exact per-edge gradients of the fused forward. + + Reloads ``grad_pre_gate`` at the edge's destination, recomputes the + block-diagonal ``rotate_back`` from ``x_local`` / ``Dt``, and stores + ``grad_x_local``, ``grad_Dt`` (structural non-zeros) and ``grad_alpha`` + (reduced over each (focus, head) channel group). No cross-edge + accumulation, hence no atomics. + """ + edge = tl.program_id(0).to(tl.int64) + n = tl.load(dst_ptr + edge).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < C_wide + fv = chan // CF + cfv = chan % CF + hv = cfv // HEAD_DIM + xl_co = fv * xl_sf + cfv * xl_sc + gxl_co = fv * gxl_sf + cfv * gxl_sc + w_col = fv * w_sf + hv * w_sh + grp = fv * NHEAD + hv # (BLOCK_C,) flat (focus, head) group id + + wv = tl.load(w_ptr + edge * w_se + w_col, mask=cmask, other=0.0).to(tl.float32) + gw_chan = tl.zeros((BLOCK_C,), dtype=tl.float32) + + for l in tl.static_range(0, LMAX + 1): + base = l * l + r0 = base + l + xl0 = tl.load( + xl_ptr + edge * xl_se + l * xl_sr + xl_co, mask=cmask, other=0.0 + ).to(tl.float32) + gxl0 = tl.zeros((BLOCK_C,), dtype=tl.float32) + if l >= 1: + xlm = tl.load( + xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + xl_co, + mask=cmask, + other=0.0, + ).to(tl.float32) + xlp = tl.load( + xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + xl_co, + mask=cmask, + other=0.0, + ).to(tl.float32) + gxlm = tl.zeros((BLOCK_C,), dtype=tl.float32) + gxlp = tl.zeros((BLOCK_C,), dtype=tl.float32) + for j in tl.static_range(0, 2 * l + 1): + d = base + j + resc = tl.load(resc_ptr + d).to(tl.float32) + gpr = ( + tl.load( + gp_ptr + n * gp_sn + d * gp_sd + chan * gp_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + * resc + ) + grad_rb = gpr * wv + w0 = tl.load(dt_ptr + edge * dt_se + d * dt_sr + r0 * dt_sk).to( + tl.float32 + ) + rb = w0 * xl0 + gxl0 += w0 * grad_rb + tl.store( + gdt_ptr + edge * gdt_se + d * gdt_sr + r0 * gdt_sk, + tl.sum(grad_rb * xl0).to(gdt_ptr.dtype.element_ty), + ) + if l >= 1: + wm = tl.load( + dt_ptr + edge * dt_se + d * dt_sr + (r0 - 1) * dt_sk + ).to(tl.float32) + wp = tl.load( + dt_ptr + edge * dt_se + d * dt_sr + (r0 + 1) * dt_sk + ).to(tl.float32) + rb += wm * xlm + wp * xlp + gxlm += wm * grad_rb + gxlp += wp * grad_rb + tl.store( + gdt_ptr + edge * gdt_se + d * gdt_sr + (r0 - 1) * gdt_sk, + tl.sum(grad_rb * xlm).to(gdt_ptr.dtype.element_ty), + ) + tl.store( + gdt_ptr + edge * gdt_se + d * gdt_sr + (r0 + 1) * gdt_sk, + tl.sum(grad_rb * xlp).to(gdt_ptr.dtype.element_ty), + ) + gw_chan += gpr * rb + tl.store( + gxl_ptr + edge * gxl_se + l * gxl_sr + gxl_co, + gxl0.to(gxl_ptr.dtype.element_ty), + mask=cmask, + ) + if l >= 1: + tl.store( + gxl_ptr + edge * gxl_se + (LMAX + l) * gxl_sr + gxl_co, + gxlm.to(gxl_ptr.dtype.element_ty), + mask=cmask, + ) + tl.store( + gxl_ptr + edge * gxl_se + (2 * LMAX + l) * gxl_sr + gxl_co, + gxlp.to(gxl_ptr.dtype.element_ty), + mask=cmask, + ) + + for g in tl.static_range(0, NFOCUS * NHEAD): + f = g // NHEAD + h = g % NHEAD + val = tl.sum(tl.where((grp == g) & cmask, gw_chan, 0.0)) + tl.store( + gw_ptr + edge * gw_se + f * gw_sf + h * gw_sh, + val.to(gw_ptr.dtype.element_ty), + ) + + @triton.jit + def _flash_bwd_block_kernel( + gp_ptr, # (N, D, C) upstream gradient of the ungated aggregate + xl_ptr, # (E, F, D_m, Cf) local features + dt_ptr, # (E, D, D) transposed block-diagonal Wigner-D, contiguous + resc_ptr, # (D,) inverse-rotation rescale + w_ptr, # (E, F, H) attention weights, contiguous + dst_ptr, # (E,) + gxl_ptr, # (E, F, D_m, Cf) out + gdt_ptr, # (E, D, D) out (pre-zeroed, structural non-zeros written) + gw_ptr, # (E, F, H) out, contiguous + n_edge, + gp_sn, + gp_sd, + xl_se, + xl_sf, + xl_sr, + xl_sc, + gxl_se, + gxl_sf, + gxl_sr, + gxl_sc, + L: tl.constexpr, + CF: tl.constexpr, + CW: tl.constexpr, # C_wide = F * Cf + CP: tl.constexpr, # next power of two >= CW (vector lane count) + HEAD_DIM: tl.constexpr, + NHEAD: tl.constexpr, + BLOCK_E: tl.constexpr, + ): + """Edge-block variant of the flash-attention backward. + + The per-edge kernel closes one cross-lane ``tl.sum`` per structural + Wigner non-zero -- serialized warp shuffle-reduction chains that + dominate its runtime on narrow hidden widths. This variant processes + ``BLOCK_E`` edges per program with the channel axis kept as the + vector axis: every ``grad_Dt`` entry becomes one batched axis-1 + reduction of a ``(BLOCK_E, CP)`` tile, every ``grad_x_local`` term is + a rank-1 vector FMA with the per-edge Wigner scalar broadcast over + channels, and the per-edge scalars are loaded as coalesced + ``(BLOCK_E,)`` vectors. Channels are padded to the power-of-two lane + count ``CP`` with masked lanes (no memory traffic, only register + pressure, which the launch table absorbs with a smaller ``BLOCK_E``). + + The schedule wins only where the reduction overhead of the per-edge + kernel dominates; :func:`~.tile_configs.flash_bwd_block_config` acts + as the win list. + """ + DIM: tl.constexpr = (L + 1) * (L + 1) + NG: tl.constexpr = (CW // CF) * NHEAD # flat (focus, head) group count + PADDED: tl.constexpr = CP != CW + + pid = tl.program_id(0) + offs_e = (pid * BLOCK_E + tl.arange(0, BLOCK_E)).to(tl.int64) + e_mask = offs_e < n_edge + eq = tl.where(e_mask, offs_e, 0) + chan = tl.arange(0, CP) + if PADDED: + c_mask = chan < CW + em = e_mask[:, None] & c_mask[None, :] + # Masked lanes clamp their decode so pointer arithmetic stays valid. + fv = tl.where(c_mask, chan // CF, 0) + cfv = tl.where(c_mask, chan % CF, 0) + else: + em = e_mask[:, None] + fv = chan // CF + cfv = chan % CF + hv = cfv // HEAD_DIM + grp = fv * NHEAD + hv # (CP,) flat (focus, head) group id + + dst = tl.load(dst_ptr + eq, mask=e_mask, other=0).to(tl.int64) + # Attention weight broadcast to channels: w[e, f(c), h(c)]. + wv = tl.load(w_ptr + (eq * NG)[:, None] + grp[None, :], mask=em, other=0.0) + + xl_row = xl_ptr + (eq * xl_se)[:, None] + (fv * xl_sf + cfv * xl_sc)[None, :] + gxl_row = ( + gxl_ptr + (eq * gxl_se)[:, None] + (fv * gxl_sf + cfv * gxl_sc)[None, :] + ) + dt_base = dt_ptr + eq * DIM * DIM + gdt_base = gdt_ptr + eq * DIM * DIM + # The launcher passes a contiguous upstream gradient (channel stride 1). + gp_row = gp_ptr + (dst * gp_sn)[:, None] + chan[None, :] + + gw_acc = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + + for l in tl.static_range(0, L + 1): + base = l * l + r0 = base + l # packed reduced column of order m = 0 + xl0 = tl.load(xl_row + l * xl_sr, mask=em, other=0.0) + gxl0 = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + if l >= 1: + xlm = tl.load(xl_row + (L + l) * xl_sr, mask=em, other=0.0) + xlp = tl.load(xl_row + (2 * L + l) * xl_sr, mask=em, other=0.0) + gxlm = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + gxlp = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + for j in tl.static_range(0, 2 * l + 1): + d = base + j + resc = tl.load(resc_ptr + d) + gpr = tl.load(gp_row + d * gp_sd, mask=em, other=0.0) * resc + grad_rb = gpr * wv + dt0 = tl.load(dt_base + d * DIM + r0, mask=e_mask, other=0.0) + gxl0 += dt0[:, None] * grad_rb + tl.store( + gdt_base + d * DIM + r0, + tl.sum(grad_rb * xl0, axis=1), + mask=e_mask, + ) + rb = dt0[:, None] * xl0 + if l >= 1: + dtm = tl.load(dt_base + d * DIM + (r0 - 1), mask=e_mask, other=0.0) + dtp = tl.load(dt_base + d * DIM + (r0 + 1), mask=e_mask, other=0.0) + gxlm += dtm[:, None] * grad_rb + gxlp += dtp[:, None] * grad_rb + tl.store( + gdt_base + d * DIM + (r0 - 1), + tl.sum(grad_rb * xlm, axis=1), + mask=e_mask, + ) + tl.store( + gdt_base + d * DIM + (r0 + 1), + tl.sum(grad_rb * xlp, axis=1), + mask=e_mask, + ) + rb += dtm[:, None] * xlm + dtp[:, None] * xlp + gw_acc += gpr * rb + tl.store(gxl_row + l * gxl_sr, gxl0, mask=em) + if l >= 1: + tl.store(gxl_row + (L + l) * gxl_sr, gxlm, mask=em) + tl.store(gxl_row + (2 * L + l) * gxl_sr, gxlp, mask=em) + + # grad_alpha: reduce gw_acc over each (focus, head) channel group. + for g in tl.static_range(NG): + val = tl.sum(tl.where((grp == g)[None, :] & em, gw_acc, 0.0), axis=1) + tl.store(gw_ptr + eq * NG + g, val, mask=e_mask) + + +# ====================================================================== +# Tile helper + zero-edge guard +# ====================================================================== +def _tile_channels(channels: int) -> int: + """Smallest power-of-two channel tile of at least 16 covering ``channels``.""" + tile = 16 + while tile < int(channels): + tile *= 2 + return tile + + +def _has_no_edges(n_edge) -> bool: + """Return true only for a concrete zero-edge call (SymInt-safe guard).""" + return type(n_edge) is int and n_edge == 0 + + +# ====================================================================== +# Triton launch wrappers +# ====================================================================== +def _launch_forward( + x_local: Tensor, + wigner_dt: Tensor, + rescale: Tensor, + alpha: Tensor, + dst: Tensor, + n_nodes, + lmax: int, + n_head: int, +) -> Tensor: + n_edge, n_focus, _reduced_dim, focus_dim = x_local.shape + dim = (int(lmax) + 1) ** 2 + c_wide = n_focus * focus_dim + # The segment reduction accumulates in float32 registers regardless of + # the input precision and writes each output row exactly once. + out = torch.empty(n_nodes, dim, c_wide, dtype=torch.float32, device=x_local.device) + if _has_no_edges(n_edge): + return out.zero_().to(x_local.dtype) + # Destination CSR topology built inside the op: the graph-level edge list + # carries masked padding edges in arbitrary destination order, so the + # segment reduction needs its own sorted order (integer ops, no gradient). + order = torch.argsort(dst) + boundaries = torch.arange(n_nodes + 1, device=dst.device, dtype=dst.dtype) + row_ptr = torch.searchsorted(dst.index_select(0, order), boundaries) + wrap_triton(_flash_fwd_kernel)[(n_nodes,)]( + x_local, + wigner_dt, + rescale, + alpha, + order, + row_ptr, + out, + n_nodes, + c_wide, + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + x_local.stride(3), + wigner_dt.stride(0), + wigner_dt.stride(1), + wigner_dt.stride(2), + alpha.stride(0), + alpha.stride(1), + alpha.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + LMAX=int(lmax), + CF=focus_dim, + HEAD_DIM=focus_dim // int(n_head), + BLOCK_C=_tile_channels(c_wide), + ) + return out.to(x_local.dtype) + + +def _launch_backward( + grad_pre_gate: Tensor, + x_local: Tensor, + wigner_dt: Tensor, + rescale: Tensor, + alpha: Tensor, + dst: Tensor, + lmax: int, + n_head: int, +) -> tuple[Tensor, Tensor, Tensor]: + n_edge, n_focus, _reduced_dim, focus_dim = x_local.shape + c_wide = n_focus * focus_dim + grad_x_local = torch.empty_like(x_local) + grad_wigner = torch.zeros_like(wigner_dt, memory_format=torch.contiguous_format) + grad_alpha = torch.empty_like(alpha) + if _has_no_edges(n_edge): + return grad_x_local, grad_wigner, grad_alpha + # The edge-block schedule engages on swept-and-winning (C_wide, lmax) + # keys; every other shape keeps the per-edge kernel. The branch resolves + # at trace time, so exactly one kernel reaches the compiled graph. + block_cfg = flash_bwd_block_config(int(c_wide), int(lmax)) + if block_cfg is not None: + block_e, warps, stages = block_cfg + wrap_triton(_flash_bwd_block_kernel)[(triton.cdiv(n_edge, block_e),)]( + grad_pre_gate, + x_local, + wigner_dt.contiguous(), + rescale, + alpha, + dst, + grad_x_local, + grad_wigner, + grad_alpha, + n_edge, + grad_pre_gate.stride(0), + grad_pre_gate.stride(1), + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + x_local.stride(3), + grad_x_local.stride(0), + grad_x_local.stride(1), + grad_x_local.stride(2), + grad_x_local.stride(3), + L=int(lmax), + CF=focus_dim, + CW=c_wide, + CP=triton.next_power_of_2(c_wide), + HEAD_DIM=focus_dim // int(n_head), + NHEAD=int(n_head), + BLOCK_E=block_e, + num_warps=warps, + num_stages=stages, + ) + return grad_x_local, grad_wigner, grad_alpha + wrap_triton(_flash_bwd_kernel)[(n_edge,)]( + grad_pre_gate, + x_local, + wigner_dt, + rescale, + alpha, + dst, + grad_x_local, + grad_wigner, + grad_alpha, + n_edge, + c_wide, + grad_pre_gate.stride(0), + grad_pre_gate.stride(1), + grad_pre_gate.stride(2), + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + x_local.stride(3), + wigner_dt.stride(0), + wigner_dt.stride(1), + wigner_dt.stride(2), + alpha.stride(0), + alpha.stride(1), + alpha.stride(2), + grad_x_local.stride(0), + grad_x_local.stride(1), + grad_x_local.stride(2), + grad_x_local.stride(3), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + grad_alpha.stride(0), + grad_alpha.stride(1), + grad_alpha.stride(2), + LMAX=int(lmax), + CF=focus_dim, + HEAD_DIM=focus_dim // int(n_head), + NFOCUS=n_focus, + NHEAD=int(n_head), + BLOCK_C=_tile_channels(c_wide), + ) + return grad_x_local, grad_wigner, grad_alpha + + +# ====================================================================== +# Dispatch helpers (triton on CUDA float, eager otherwise) +# ====================================================================== +def _use_triton(tensor: Tensor) -> bool: + return ( + FLASH_ATTEN_TRITON_AVAILABLE + and tensor.is_cuda + and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32) + ) + + +def _forward_impl( + x_local: Tensor, + wigner_dt: Tensor, + rescale: Tensor, + alpha: Tensor, + row_ptr: Tensor, + dst: Tensor, + lmax: int, + n_head: int, +) -> Tensor: + if not _use_triton(x_local): + return flash_atten_aggregate_reference( + x_local, + wigner_dt, + rescale, + alpha, + dst, + int(row_ptr.shape[0] - 1), + int(lmax), + int(n_head), + ) + # ``x_local`` is passed with its native (possibly transposed) strides -- the + # kernel addresses it through the stride arguments, and preserving the layout + # keeps the backward's ``grad_x_local`` stride-compatible with the stock + # ``rotate_back_block_so2`` path so the downstream SO(2) backward reshapes + # stay viewable under symbolic (make_fx / AOT) restride. ``N`` is taken from + # ``row_ptr.shape`` (a SymInt) so the ``natoms`` axis is never specialized. + return _launch_forward( + x_local, + wigner_dt, + rescale.contiguous(), + alpha.contiguous(), + dst.contiguous(), + row_ptr.shape[0] - 1, + int(lmax), + int(n_head), + ) + + +def _backward_impl( + grad_pre_gate: Tensor, + x_local: Tensor, + wigner_dt: Tensor, + rescale: Tensor, + alpha: Tensor, + dst: Tensor, + lmax: int, + n_head: int, +) -> tuple[Tensor, Tensor, Tensor]: + if not _use_triton(x_local): + return _flash_atten_backward_reference( + grad_pre_gate, + x_local, + wigner_dt, + rescale, + alpha, + dst, + int(lmax), + int(n_head), + ) + # Keep ``x_local``'s native strides so ``grad_x_local = empty_like(x_local)`` + # matches the stock ``rotate_back_block_so2`` backward layout (see the + # forward note); only ``grad_pre_gate`` is made contiguous for coalesced + # gather reads. + return _launch_backward( + grad_pre_gate.contiguous(), + x_local, + wigner_dt, + rescale.contiguous(), + alpha.contiguous(), + dst.contiguous(), + int(lmax), + int(n_head), + ) + + +# ====================================================================== +# Functional triton_op + fake + autograd registration +# ====================================================================== +_flash_op = torch.library.triton_op( + "sezm_triton::flash_atten_aggregate", mutates_args=() +)(_forward_impl) + +_flash_bwd_op = torch.library.triton_op( + "sezm_triton::flash_atten_aggregate_bwd", mutates_args=() +)(_backward_impl) + + +@_flash_op.register_fake +def _(x_local, wigner_dt, rescale, alpha, row_ptr, dst, lmax, n_head): + n_focus = x_local.shape[1] + focus_dim = x_local.shape[3] + dim = (int(lmax) + 1) ** 2 + # ``N`` is derived from ``row_ptr`` (not an int arg) so the dynamic ``natoms`` + # axis survives ``torch.export`` without specialization. + return x_local.new_empty(row_ptr.shape[0] - 1, dim, n_focus * focus_dim) + + +@_flash_bwd_op.register_fake +def _(grad_pre_gate, x_local, wigner_dt, rescale, alpha, dst, lmax, n_head): + return ( + torch.empty_like(x_local), + torch.empty_like(wigner_dt), + torch.empty_like(alpha), + ) + + +def _setup_context(ctx, inputs, output): + x_local, wigner_dt, rescale, alpha, row_ptr, dst, lmax, n_head = inputs + ctx.save_for_backward(x_local, wigner_dt, rescale, alpha, dst) + ctx.lmax = lmax + ctx.n_head = n_head + + +def _backward(ctx, grad_out): + x_local, wigner_dt, rescale, alpha, dst = ctx.saved_tensors + grad_x_local, grad_wigner, grad_alpha = _flash_bwd_op( + grad_out.contiguous(), + x_local, + wigner_dt, + rescale, + alpha, + dst, + ctx.lmax, + ctx.n_head, + ) + # inputs: x_local, wigner_dt, rescale, alpha, row_ptr, dst, lmax, n_head. + # rescale is a constant buffer; row_ptr/dst are integer topology. + return grad_x_local, grad_wigner, None, grad_alpha, None, None, None, None + + +_flash_op.register_autograd(_backward, setup_context=_setup_context) + + +# ====================================================================== +# Public API +# ====================================================================== +def flash_atten_aggregate( + x_local: Tensor, + wigner_dt: Tensor, + rescale: Tensor, + alpha: Tensor, + row_ptr: Tensor, + dst: Tensor, + lmax: int, + n_head: int, +) -> Tensor: + """Fused block-diagonal rotate-back + envelope-softmax weighting + edge scatter. + + Computes the ungated attention aggregate + + ``pre_gate[n, d, c] = rescale[d] * + sum_{e: dst[e]=n} alpha[e, f, h] * RotBack_e(x_local)[d, c]`` + + for the ``mmax == 1`` block-diagonal layout, equivalent to the eager + ``rotate_back -> rescale -> value-reshape -> alpha-weight -> index_add`` chain + of :class:`SO2Convolution` (the caller applies the node-level output gate + ``out = pre_gate * gate`` afterwards). + + Parameters + ---------- + x_local : Tensor + Per-focus SO(2) features with shape ``(E, F, D_m, Cf)``. + wigner_dt : Tensor + Transposed block-diagonal Wigner-D with shape ``(E, D, D)``. + rescale : Tensor + Inverse-rotation degree rescale with shape ``(D,)``. + alpha : Tensor + Envelope-gated softmax weight with shape ``(E, F, H)``. + row_ptr : Tensor + Row offsets with shape ``(N + 1,)`` from :func:`build_row_ptr`; only + its length carries the (SymInt) node count ``N`` for the output + allocation and the fake kernel, so the ``natoms`` axis is never + specialized. The forward builds its own destination-sorted CSR + topology from ``dst`` (the traced edge list carries masked padding + edges in arbitrary order), so no sortedness invariant is required. + dst : Tensor + Destination node indices with shape ``(E,)`` (the forward segment key + and the backward gather index). + lmax : int + Maximum degree. + n_head : int + Number of attention heads ``H``. + + Returns + ------- + Tensor + Ungated aggregate with shape ``(N, D, C_wide)``, ``C_wide = F * Cf``. + """ + return _flash_op( + x_local, wigner_dt, rescale, alpha, row_ptr, dst, int(lmax), int(n_head) + ) diff --git a/deepmd/kernels/triton/sezm/force_assembly.py b/deepmd/kernels/triton/sezm/force_assembly.py new file mode 100644 index 0000000000..7c49788fc3 --- /dev/null +++ b/deepmd/kernels/triton/sezm/force_assembly.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN202 +"""Segmented force / virial assembly from per-edge energy gradients. + +Given the per-edge gradient ``g_e = dE / d(edge_vec_e)`` of an edge-based +energy, the extended force and per-atom virial are + + ``F_k = sum_{dst(e)=k} g_e - sum_{src(e)=k} g_e`` + ``W_k = 0.5 * sum_{e: k in {src(e), dst(e)}} ( -g_e (x) edge_vec_e )``. + +The reference assembly issues four ``index_add`` scatters (force to both +endpoints, half virial to both endpoints) plus a materialized ``(E, 9)`` +outer product. Row-atomic scatters serialize on the colliding edges of each +atom, so this operator performs two CSR segment-reduction launches instead +(one over the destination order, one over the source order), each +recomputing the per-edge outer product on the fly. One program owns one +extended atom; the 12 output scalars (3 force + 9 virial) accumulate in +float64 registers over the segment, which both removes the atomic +serialization and tightens the summation error over the reference fp32 +atomics. + +The operator is inference-only in practice: the caller keeps the reference +path whenever the force graph must remain differentiable (``create_graph``), +so no autograd formula is registered. +""" + +from __future__ import ( + annotations, +) + +import torch +from torch import ( + Tensor, +) +from torch.library import ( + wrap_triton, +) + +__all__ = [ + "FORCE_ASSEMBLY_TRITON_AVAILABLE", + "edge_force_assembly", +] + +try: + import triton + import triton.language as tl + + FORCE_ASSEMBLY_TRITON_AVAILABLE = True +except ImportError: # pragma: no cover - exercised only without triton + FORCE_ASSEMBLY_TRITON_AVAILABLE = False + + +# ====================================================================== +# Eager reference / fallback implementation +# ====================================================================== +def _force_assembly_reference( + g: Tensor, + edge_vec: Tensor, + dst_order: Tensor, + dst_row_ptr: Tensor, + src_order: Tensor, + src_row_ptr: Tensor, +) -> tuple[Tensor, Tensor]: + """Eager ground truth built from the CSR topology via ``index_add``.""" + n_ext = dst_row_ptr.shape[0] - 1 + ar = torch.arange(n_ext, device=g.device, dtype=dst_order.dtype) + dst = torch.repeat_interleave(ar, dst_row_ptr[1:] - dst_row_ptr[:-1]) + src = torch.repeat_interleave(ar, src_row_ptr[1:] - src_row_ptr[:-1]) + g_dst = g.index_select(0, dst_order) + g_src = g.index_select(0, src_order) + force = g.new_zeros((n_ext, 3)) + force.index_add_(0, dst, g_dst) + force.index_add_(0, src, -g_src) + half_w_dst = -0.5 * torch.einsum( + "ek,ej->ekj", g_dst, edge_vec.index_select(0, dst_order) + ).reshape(-1, 9) + half_w_src = -0.5 * torch.einsum( + "ek,ej->ekj", g_src, edge_vec.index_select(0, src_order) + ).reshape(-1, 9) + virial = g.new_zeros((n_ext, 9)) + virial.index_add_(0, dst, half_w_dst) + virial.index_add_(0, src, half_w_src) + return force, virial + + +# ====================================================================== +# Triton kernels +# ====================================================================== +if FORCE_ASSEMBLY_TRITON_AVAILABLE: + + @triton.jit + def _force_segment_kernel( + g_ptr, # (E, 3) per-edge energy gradient + ev_ptr, # (E, 3) per-edge displacement + order_ptr, # (E,) edge ids sorted by the segment key + row_ptr_ptr, # (N_ext + 1,) CSR offsets into ``order`` + f_ptr, # (N_ext, 3) + w_ptr, # (N_ext, 9) + FORCE_SIGN: tl.constexpr, # +1 for the dst pass, -1 for the src pass + ACCUMULATE: tl.constexpr, # add into the outputs instead of overwriting + ): + """One endpoint pass of the force / virial segment reduction. + + The virial lanes address the ``(3, 3)`` outer product through a + padded 16-lane index ``(k, j) = (lane // 4, lane % 4)`` so both the + force and virial rows stay vectorized; the outer product + ``-0.5 * g_k * v_j`` is recomputed per edge in registers and never + materialized. Accumulation runs in float64. + """ + node = tl.program_id(0).to(tl.int64) + beg = tl.load(row_ptr_ptr + node).to(tl.int64) + end = tl.load(row_ptr_ptr + node + 1).to(tl.int64) + kf = tl.arange(0, 4) # force lanes (3 used) + kw = tl.arange(0, 16) # virial lanes (9 used) + f_mask = kf < 3 + w_mask = ((kw // 4) < 3) & ((kw % 4) < 3) + acc_f = tl.zeros((4,), dtype=tl.float64) + acc_w = tl.zeros((16,), dtype=tl.float64) + for i in range(beg, end): + e = tl.load(order_ptr + i).to(tl.int64) + g_vec = tl.load(g_ptr + e * 3 + kf, mask=f_mask, other=0.0).to(tl.float64) + v_j = tl.load(ev_ptr + e * 3 + kw % 4, mask=(kw % 4) < 3, other=0.0).to( + tl.float64 + ) + g_k = tl.load(g_ptr + e * 3 + kw // 4, mask=(kw // 4) < 3, other=0.0).to( + tl.float64 + ) + acc_f += g_vec + acc_w -= 0.5 * g_k * v_j + acc_f = acc_f * FORCE_SIGN + w_col = (kw // 4) * 3 + (kw % 4) + if ACCUMULATE: + f_prev = tl.load(f_ptr + node * 3 + kf, mask=f_mask, other=0.0).to( + tl.float64 + ) + acc_f += f_prev + w_prev = tl.load(w_ptr + node * 9 + w_col, mask=w_mask, other=0.0).to( + tl.float64 + ) + acc_w += w_prev + tl.store(f_ptr + node * 3 + kf, acc_f.to(f_ptr.dtype.element_ty), mask=f_mask) + tl.store( + w_ptr + node * 9 + w_col, acc_w.to(w_ptr.dtype.element_ty), mask=w_mask + ) + + +# ====================================================================== +# Dispatch, operator registration and public API +# ====================================================================== +def _use_triton(tensor: Tensor) -> bool: + return ( + FORCE_ASSEMBLY_TRITON_AVAILABLE + and tensor.is_cuda + and tensor.dtype in (torch.float32, torch.float64) + ) + + +def _force_assembly_impl( + g: Tensor, + edge_vec: Tensor, + dst_order: Tensor, + dst_row_ptr: Tensor, + src_order: Tensor, + src_row_ptr: Tensor, +) -> tuple[Tensor, Tensor]: + if not _use_triton(g): + return _force_assembly_reference( + g, edge_vec, dst_order, dst_row_ptr, src_order, src_row_ptr + ) + n_ext = dst_row_ptr.shape[0] - 1 + force = torch.empty((n_ext, 3), dtype=g.dtype, device=g.device) + virial = torch.empty((n_ext, 9), dtype=g.dtype, device=g.device) + wrap_triton(_force_segment_kernel)[(n_ext,)]( + g, + edge_vec, + dst_order, + dst_row_ptr, + force, + virial, + FORCE_SIGN=1, + ACCUMULATE=False, + num_warps=1, + num_stages=2, + ) + wrap_triton(_force_segment_kernel)[(n_ext,)]( + g, + edge_vec, + src_order, + src_row_ptr, + force, + virial, + FORCE_SIGN=-1, + ACCUMULATE=True, + num_warps=1, + num_stages=2, + ) + return force, virial + + +_force_assembly_op = torch.library.triton_op( + "sezm_triton::edge_force_assembly", mutates_args=() +)(_force_assembly_impl) + + +@_force_assembly_op.register_fake +def _(g, edge_vec, dst_order, dst_row_ptr, src_order, src_row_ptr): + n_ext = dst_row_ptr.shape[0] - 1 + return g.new_empty((n_ext, 3)), g.new_empty((n_ext, 9)) + + +def edge_force_assembly( + g: Tensor, + edge_vec: Tensor, + dst_order: Tensor, + dst_row_ptr: Tensor, + src_order: Tensor, + src_row_ptr: Tensor, +) -> tuple[Tensor, Tensor]: + """Assemble the extended force and per-atom virial from edge gradients. + + Parameters + ---------- + g : Tensor + Per-edge energy gradient ``dE / d(edge_vec)`` with shape (E, 3). + edge_vec : Tensor + Per-edge displacement vectors with shape (E, 3). + dst_order, src_order : Tensor + Edge ids sorted by destination / source extended index, each with + shape (E,) (from ``torch.argsort``). + dst_row_ptr, src_row_ptr : Tensor + CSR offsets into the respective order with shape (N_ext + 1,) + (from ``torch.searchsorted`` on the sorted keys); the length carries + the extended-atom count, so the atom axis is never specialized. + + Returns + ------- + force : Tensor + Extended force with shape (N_ext, 3). + virial : Tensor + Extended per-atom virial with shape (N_ext, 9), split symmetrically + between the two endpoints of each edge. + """ + return _force_assembly_op( + g, edge_vec, dst_order, dst_row_ptr, src_order, src_row_ptr + ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py b/deepmd/kernels/triton/sezm/radial_mix.py similarity index 99% rename from deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py rename to deepmd/kernels/triton/sezm/radial_mix.py index 187d1154de..6bf8fd4feb 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py +++ b/deepmd/kernels/triton/sezm/radial_mix.py @@ -39,7 +39,7 @@ Inference-only contract ----------------------- -The operator is opt-in through ``DP_TRITON_INFER`` and is only used in +The operator is opt-in through ``DP_TRITON_INFER >= 1`` and is only used in evaluation, where the force is obtained from ``autograd.grad(energy, coord)``. The backward therefore returns gradients with respect to ``compact`` and ``x_local`` (both of which carry a path to the coordinates) and ``None`` for diff --git a/deepmd/kernels/triton/sezm/so2_block_gemm.py b/deepmd/kernels/triton/sezm/so2_block_gemm.py new file mode 100644 index 0000000000..0f08ab3dcf --- /dev/null +++ b/deepmd/kernels/triton/sezm/so2_block_gemm.py @@ -0,0 +1,528 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN202 +"""Triton ``BN=64`` block-diagonal fp32 GEMM for the SeZM/DPA4 ``SO2Linear``. + +The ``SO2Linear`` mixing contracts an activation ``x`` of shape ``(F, E, K)`` +with a per-focus block-diagonal weight ``W`` of shape ``(F, K, N)``. The eager +path (``_block_diagonal_matmul``) slices the diagonal ``|m|`` blocks and issues a +``torch.bmm`` per block, concatenating the outputs. Two structural costs remain: + +* the assembled weight is presented as a *strided* view + (``permute(1, 0, 2)`` of the stored ``(K, F, N)`` parameter), and the block + concatenation materializes a fresh output; and +* the pow2 column tiling that cuBLAS/Triton default to wastes ~25% on the + ``N = 192`` block (192 rounds up to 256). + +This module drives the block-diagonal contraction with one Triton launch per +diagonal ``|m|`` block (the block dims are ``constexpr`` so the contraction loop +is statically sized and fully pipelined). Each launch + +1. consumes the strided ``(F, K, N)`` weight and the ``(F, E, K)`` activation + *without any contiguity copy* (all access is via strides), streaming only its + block's contraction range and never touching the structural off-``|m|`` zeros + or concatenating a fresh output, and +2. tiles the output ``N`` axis at exactly ``BN = 64`` -- a divisor of both 128 + and 192 -- so no column is padded. + +Every ``tl.dot`` runs with ``input_precision="ieee"`` (true IEEE fp32, no TF32), +matching the smooth potential-energy-surface contract of the descriptor. + +Composability +------------- +The forward and backward are functional ``torch.library.triton_op`` instances +(``mutates_args=()``) with registered fake kernels and an autograd formula, so +``make_fx(tracing_mode="symbolic") -> aot_module_simplified -> Inductor`` captures +the energy path together with the force autograd graph. ``triton_op`` + +``wrap_triton`` (vs ``custom_op``) lets Inductor see through to the Triton kernel +and bake the cubin into the AOTInductor ``.pt2``, exactly as +``so2_rotation.py`` / ``radial_mix.py`` do. + +Inference-only contract +----------------------- +The operator is opt-in and only used in evaluation, where the force is obtained +from ``autograd.grad(energy, coord)``. The block-diagonal weight is a parameter +(never a function of the coordinates), so the backward returns the gradient +w.r.t. the activation ``x`` -- which carries the coordinate path -- and ``None`` +for the weight, mirroring the parameter handling in ``radial_mix.py``. +""" + +from __future__ import ( + annotations, +) + +import torch +from torch import ( + Tensor, +) +from torch.library import ( + wrap_triton, +) + +# Activation-gradient backend for the force path. +# +# ``False`` (default): the backward uses the eager per-block ``bmm``, whose +# ``grad_out`` slices the Inductor memory planner reads through +# ``reinterpret_tensor`` (no copy), so the compiled force graph reuses the +# edge-sized gradient buffers and peak memory tracks the committed baseline. The +# Triton backward is faster in isolation (~1.4x over cuBLAS) but issues one +# launch per diagonal block; Inductor materializes a separate ``grad_out`` copy +# for each such consumer, inflating the compiled force-graph peak by ~13% for a +# ~1-2% end-to-end gain. The trade is unfavorable, so it remains opt-in. +_TRITON_BACKWARD = False + +__all__ = [ + "SO2_BLOCK_GEMM_TRITON_AVAILABLE", + "block_diag_gemm", + "block_diag_gemm_reference", + "slices_supported", +] + +try: + import triton + import triton.language as tl + + SO2_BLOCK_GEMM_TRITON_AVAILABLE = True +except ImportError: # pragma: no cover - exercised only without triton + SO2_BLOCK_GEMM_TRITON_AVAILABLE = False + + +# The BM=128, BN=64, BK=32 / 8-warp / 3-stage tile is the cuBLAS-beating config +# for these block shapes. BN is *fixed* at 64 -- the divisor of 128 and 192 that +# removes the pow2 column padding -- because the N-tile -> block mapping assumes +# each tile lies wholly inside one diagonal block. +_BM = 128 +_BN = 64 +_BK = 32 +_NUM_WARPS = 8 +_NUM_STAGES = 3 + + +# ====================================================================== +# Eager reference / fallback implementation +# ====================================================================== +def block_diag_gemm_reference( + x_flat: Tensor, weight: Tensor, slices: list[tuple[int, int, int, int]] +) -> Tensor: + """Eager ground truth: per-block ``bmm`` on the strided operands + concat. + + Parameters + ---------- + x_flat : Tensor + Activation with shape ``(F, E, K)``. + weight : Tensor + Block-diagonal weight presented as ``(F, K, N)`` (a strided view). + slices : list of (int, int, int, int) + The ``(in0, in1, out0, out1)`` diagonal blocks in m-major order. + + Returns + ------- + Tensor + Output with shape ``(F, E, N)``. + """ + blocks = [ + torch.bmm(x_flat[:, :, in0:in1], weight[:, in0:in1, out0:out1]) + for in0, in1, out0, out1 in slices + ] + return torch.cat(blocks, dim=-1) + + +def _block_diag_gemm_bwd_reference( + grad_out: Tensor, weight: Tensor, slices: list[tuple[int, int, int, int]] +) -> Tensor: + """Eager backward: ``grad_x = grad_out @ W^T`` per diagonal block. + + Only the activation gradient is produced; the weight is a parameter and is + never differentiated on the inference force path. + + Parameters + ---------- + grad_out : Tensor + Upstream gradient with shape ``(F, E, N)``. + weight : Tensor + Block-diagonal weight presented as ``(F, K, N)`` (a strided view). + slices : list of (int, int, int, int) + The ``(in0, in1, out0, out1)`` diagonal blocks in m-major order. + + Returns + ------- + Tensor + Activation gradient with shape ``(F, E, K)``. + """ + n_focus, n_edge = grad_out.shape[0], grad_out.shape[1] + k_total = weight.shape[1] + grad_x = grad_out.new_zeros(n_focus, n_edge, k_total) + for in0, in1, out0, out1 in slices: + grad_x[:, :, in0:in1] = torch.bmm( + grad_out[:, :, out0:out1], + weight[:, in0:in1, out0:out1].transpose(1, 2), + ) + return grad_x + + +# ====================================================================== +# Triton kernels (one launch per diagonal block; the block dims KLEN / NLEN are +# constexpr so the contraction loop is statically sized and fully pipelined) +# ====================================================================== +if SO2_BLOCK_GEMM_TRITON_AVAILABLE: + + @triton.jit + def _block_gemm_fwd_kernel( + a_ptr, + w_ptr, + c_ptr, + n_edge, + k0, + n0, + sab, + sae, + sak, + swb, + swk, + swn, + scb, + sce, + scn, + BM: tl.constexpr, + BN: tl.constexpr, + BK: tl.constexpr, + KLEN: tl.constexpr, + NLEN: tl.constexpr, + ): + """One diagonal block: ``C[:, :, n0:n0+NLEN] = X[:, :, k0:k0+KLEN] @ W``. + + The strided operands (``x``, the permuted ``weight``, and the whole + output) are passed as full buffers with the block offsets ``k0 / n0`` + applied inside the kernel; this keeps every access on the parent buffers + (no ``reinterpret_tensor`` slice, which Inductor would otherwise clone on + the strided focus-major layout). ``N`` is tiled at ``BN`` (a divisor of + ``NLEN``) so no column is padded; the ``KLEN`` contraction is statically + unrolled and pipelined. + """ + pid = tl.program_id(0) + bid = tl.program_id(1).to(tl.int64) + num_pid_n: tl.constexpr = NLEN // BN + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_m = (pid_m * BM + tl.arange(0, BM)).to(tl.int64) + offs_n = (n0 + pid_n * BN + tl.arange(0, BN)).to(tl.int64) + offs_k = tl.arange(0, BK) + + a_ptrs = ( + a_ptr + bid * sab + (offs_m[:, None] * sae + (k0 + offs_k[None, :]) * sak) + ) + w_ptrs = ( + w_ptr + bid * swb + ((k0 + offs_k[:, None]) * swk + offs_n[None, :] * swn) + ) + + acc = tl.zeros((BM, BN), dtype=tl.float32) + m_mask = offs_m[:, None] < n_edge + for _ in range(0, KLEN, BK): + a = tl.load(a_ptrs, mask=m_mask, other=0.0) + w = tl.load(w_ptrs) + acc = tl.dot(a, w, acc, input_precision="ieee") + a_ptrs += BK * sak + w_ptrs += BK * swk + + c_ptrs = c_ptr + bid * scb + (offs_m[:, None] * sce + offs_n[None, :] * scn) + tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=m_mask) + + @triton.jit + def _block_gemm_dx_kernel( + go_ptr, + wt_ptr, + gx_ptr, + n_edge, + k0, + n0, + sgb, + sge, + sgn, + swb, + swn, + swk, + sxb, + sxe, + sxk, + BM: tl.constexpr, + BN: tl.constexpr, + BK: tl.constexpr, + KLEN: tl.constexpr, + NLEN: tl.constexpr, + ): + """One diagonal block of the activation gradient ``GX = GO @ W^T``. + + Contraction over the block's ``NLEN`` output axis: ``grad_x[e, k] = + sum_n grad_out[e, n] W[k, n]``. The full ``grad_out`` / ``grad_x`` buffers + are addressed with the block offsets ``n0 / k0`` applied inside the kernel + (no slice views, which Inductor clones on the strided focus-major grad). + ``weight_t`` is ``Wt[n, k]`` with a contiguous ``k`` axis so both operands + load coalesced with no register transpose. + """ + pid = tl.program_id(0) + bid = tl.program_id(1).to(tl.int64) + num_pid_k: tl.constexpr = KLEN // BN + pid_m = pid // num_pid_k + pid_k = pid % num_pid_k + + offs_m = (pid_m * BM + tl.arange(0, BM)).to(tl.int64) + offs_k = (k0 + pid_k * BN + tl.arange(0, BN)).to(tl.int64) + offs_c = tl.arange(0, BK) # contraction over the block's N axis + + go_ptrs = ( + go_ptr + bid * sgb + (offs_m[:, None] * sge + (n0 + offs_c[None, :]) * sgn) + ) + wt_ptrs = ( + wt_ptr + bid * swb + ((n0 + offs_c[:, None]) * swn + offs_k[None, :] * swk) + ) + + acc = tl.zeros((BM, BN), dtype=tl.float32) + m_mask = offs_m[:, None] < n_edge + for _ in range(0, NLEN, BK): + go = tl.load(go_ptrs, mask=m_mask, other=0.0) + wt = tl.load(wt_ptrs) + acc = tl.dot(go, wt, acc, input_precision="ieee") + go_ptrs += BK * sgn + wt_ptrs += BK * swn + + gx_ptrs = gx_ptr + bid * sxb + (offs_m[:, None] * sxe + offs_k[None, :] * sxk) + tl.store(gx_ptrs, acc.to(gx_ptr.dtype.element_ty), mask=m_mask) + + +# ====================================================================== +# Launch wrappers +# ====================================================================== +def _has_no_edges(n_edge) -> bool: + """Return true only for eager zero-edge calls; never guards a SymInt.""" + return type(n_edge) is int and n_edge == 0 + + +def _launch_forward( + x_flat: Tensor, weight: Tensor, slices: list[tuple[int, int, int, int]], n_out: int +) -> Tensor: + n_focus, n_edge, _ = x_flat.shape + out = torch.empty( + (n_focus, n_edge, n_out), dtype=x_flat.dtype, device=x_flat.device + ) + if _has_no_edges(n_edge): + return out + m_tiles = triton.cdiv(n_edge, _BM) + for in0, in1, out0, out1 in slices: + klen, nlen = in1 - in0, out1 - out0 + wrap_triton(_block_gemm_fwd_kernel)[(m_tiles * (nlen // _BN), n_focus)]( + x_flat, + weight, + out, + n_edge, + in0, + out0, + x_flat.stride(0), + x_flat.stride(1), + x_flat.stride(2), + weight.stride(0), + weight.stride(1), + weight.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BM=_BM, + BN=_BN, + BK=_BK, + KLEN=klen, + NLEN=nlen, + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, + ) + return out + + +def _launch_backward_dx( + grad_out: Tensor, + weight: Tensor, + slices: list[tuple[int, int, int, int]], + k_out: int, +) -> Tensor: + n_focus, n_edge, _ = grad_out.shape + grad_x = torch.empty( + (n_focus, n_edge, k_out), dtype=grad_out.dtype, device=grad_out.device + ) + if _has_no_edges(n_edge): + return grad_x + # Pre-transpose the (small, constant) weight to (F, N, K) with a contiguous + # K axis so the N-contraction kernel loads it coalesced (see kernel doc). + # The weight is a parameter, so Inductor constant-folds this in the frozen + # graph; eagerly it is a sub-megabyte copy. + weight_t = weight.transpose(1, 2).contiguous() + m_tiles = triton.cdiv(n_edge, _BM) + for in0, in1, out0, out1 in slices: + klen, nlen = in1 - in0, out1 - out0 + wrap_triton(_block_gemm_dx_kernel)[(m_tiles * (klen // _BN), n_focus)]( + grad_out, + weight_t, + grad_x, + n_edge, + in0, + out0, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + weight_t.stride(0), + weight_t.stride(1), + weight_t.stride(2), + grad_x.stride(0), + grad_x.stride(1), + grad_x.stride(2), + BM=_BM, + BN=_BN, + BK=_BK, + KLEN=klen, + NLEN=nlen, + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, + ) + return grad_x + + +# ====================================================================== +# Dispatch helpers (triton on CUDA float, eager otherwise) +# ====================================================================== +def _use_triton(tensor: Tensor) -> bool: + return ( + SO2_BLOCK_GEMM_TRITON_AVAILABLE + and tensor.is_cuda + and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32) + ) + + +def _unflatten_slices(slices_flat: list[int]) -> list[tuple[int, int, int, int]]: + """Rebuild ``(in0, in1, out0, out1)`` blocks from the flat ``list[int]``. + + ``triton_op`` schema inference accepts ``list[int]`` but not + ``list[list[int]]``, so the block table is carried as groups of four. + """ + return [ + (slices_flat[i], slices_flat[i + 1], slices_flat[i + 2], slices_flat[i + 3]) + for i in range(0, len(slices_flat), 4) + ] + + +def _forward_impl(x_flat: Tensor, weight: Tensor, slices_flat: list[int]) -> Tensor: + slices = _unflatten_slices(slices_flat) + if not _use_triton(x_flat): + return block_diag_gemm_reference(x_flat, weight, slices) + n_out = max(out1 for _, _, _, out1 in slices) + return _launch_forward(x_flat, weight, slices, n_out) + + +def _backward_impl(grad_out: Tensor, weight: Tensor, slices_flat: list[int]) -> Tensor: + slices = _unflatten_slices(slices_flat) + if not _TRITON_BACKWARD or not _use_triton(grad_out): + return _block_diag_gemm_bwd_reference(grad_out, weight, slices) + k_out = max(in1 for _, in1, _, _ in slices) + return _launch_backward_dx(grad_out, weight, slices, k_out) + + +# ====================================================================== +# Functional triton_op + fake + autograd registration +# ====================================================================== +_bd_gemm_op = torch.library.triton_op( + "sezm_triton::so2_block_diag_gemm", mutates_args=() +)(_forward_impl) + +_bd_gemm_bwd_op = torch.library.triton_op( + "sezm_triton::so2_block_diag_gemm_bwd", mutates_args=() +)(_backward_impl) + + +@_bd_gemm_op.register_fake +def _(x_flat, weight, slices_flat): + n_out = max(slices_flat[3::4]) + return x_flat.new_empty((x_flat.shape[0], x_flat.shape[1], n_out)) + + +@_bd_gemm_bwd_op.register_fake +def _(grad_out, weight, slices_flat): + k_total = max(slices_flat[1::4]) + return grad_out.new_empty((grad_out.shape[0], grad_out.shape[1], k_total)) + + +def _bd_gemm_setup_context(ctx, inputs, output): + x_flat, weight, slices_flat = inputs + ctx.save_for_backward(weight) + ctx.slices_flat = slices_flat + + +def _bd_gemm_backward(ctx, grad_out): + (weight,) = ctx.saved_tensors + grad_x = _bd_gemm_bwd_op(grad_out, weight, ctx.slices_flat) + # weight is a parameter (never a function of the coordinates); the inference + # force differentiates only w.r.t. the activation, so its gradient is not + # produced. ``slices_flat`` is a static block table. + return grad_x, None, None + + +_bd_gemm_op.register_autograd(_bd_gemm_backward, setup_context=_bd_gemm_setup_context) + + +# ====================================================================== +# Public API +# ====================================================================== +def slices_supported( + slices: list[tuple[int, int, int, int]], block_n: int = _BN +) -> bool: + """Return whether every block boundary/width aligns to ``block_n``. + + The BN-tiled kernel maps each ``BN``-wide output (input) tile to a single + diagonal block, which requires every block edge and width to be a multiple of + ``block_n`` so no tile straddles two blocks. Callers gate the Triton path on + this (e.g. an even ``lmax`` makes the ``m=0`` block width ``(lmax+1)*C`` an + odd count and may break alignment); unsupported layouts fall back to eager. + + Parameters + ---------- + slices : list of (int, int, int, int) + The ``(in0, in1, out0, out1)`` diagonal blocks in m-major order. + block_n : int + Column tile width; every block edge and width must be a multiple of it. + + Returns + ------- + bool + ``True`` when every block boundary and width is a multiple of ``block_n``. + """ + return all( + edge % block_n == 0 + for in0, in1, out0, out1 in slices + for edge in (in0, in1, out0, out1) + ) + + +def block_diag_gemm( + x_flat: Tensor, + weight: Tensor, + slices: list[tuple[int, int, int, int]], +) -> Tensor: + """Apply the ``BN=64`` block-diagonal GEMM ``(F, E, K) -> (F, E, N)``. + + Computes the same result as :func:`block_diag_gemm_reference` while avoiding + both the block concatenation and any contiguity copy of the strided weight: + one Triton launch per diagonal block streams only that block's contraction + range from the strided operands, with the output ``N`` axis tiled at 64. + + Parameters + ---------- + x_flat : Tensor + Activation with shape ``(F, E, K)``. + weight : Tensor + Block-diagonal weight presented as ``(F, K, N)`` (may be strided). + slices : list of (int, int, int, int) + The ``(in0, in1, out0, out1)`` diagonal blocks in m-major order. + + Returns + ------- + Tensor + Output with shape ``(F, E, N)``. + """ + slices_flat = [int(v) for s in slices for v in s] + return _bd_gemm_op(x_flat, weight, slices_flat) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py b/deepmd/kernels/triton/sezm/so2_rotation.py similarity index 99% rename from deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py rename to deepmd/kernels/triton/sezm/so2_rotation.py index 1f524d9ca5..87b7792121 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py +++ b/deepmd/kernels/triton/sezm/so2_rotation.py @@ -78,7 +78,7 @@ wrap_triton, ) -from ..indexing import ( +from deepmd.pt.model.descriptor.sezm_nn.indexing import ( build_m_major_index, ) diff --git a/deepmd/kernels/triton/sezm/so2_stack_fp16x3.py b/deepmd/kernels/triton/sezm/so2_stack_fp16x3.py new file mode 100644 index 0000000000..bf8e733d8e --- /dev/null +++ b/deepmd/kernels/triton/sezm/so2_stack_fp16x3.py @@ -0,0 +1,904 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# ruff: noqa: ANN001, ANN202 +"""fp16x3 split-compensated SO(2) mixing-stack operators. + +This module provides a tensor-core implementation of the SO(2) mixing stack +that is numerically interchangeable with the fp32 ``sezm_triton:: +so2_mixing_stack`` operator while running the block GEMMs on fp16 tensor +cores. It is selected at ``DP_TRITON_INFER >= 3`` for ``(focus_dim, lmax)`` +keys resolved by :func:`~.tile_configs.stack_fp16x3_configs`; all other +shapes keep the fp32 stack. + +Numerical scheme +---------------- +Each fp32 GEMM ``C = A @ B`` is evaluated as three fp16 tensor-core products +with fp32 accumulation (a two-term Ootomo split):: + + A = A_hi + A_lo, B = B_hi + B_lo (fp16 head + fp16 tail) + C ~= A_hi B_hi + A_hi B_lo + A_lo B_hi (the A_lo B_lo term, ~2^-22 + relative, is dropped) + +An fp16 multiply feeding an fp32 accumulator is exact (11 x 11 -> 22-bit +products), so the only error sources are the two split truncations. The +head product and the two tail corrections accumulate in *separate* fp32 +accumulators merged once per tile: chaining all three into one accumulator +absorbs the small tail terms against the large head partial sums each +k-step and doubles the error. Measured against fp64 on production shapes, +the per-GEMM maximum relative error is indistinguishable from the fp32 FFMA +reference (~5e-7), at roughly 1.6x the FFMA GEMM throughput on H20. + +Dynamic-range handling +---------------------- +fp16 spans ~[6e-5, 65504] against fp32's ~[1e-38, 3e38]; both ends are +protected with exact power-of-two scalings: + +- *Tail underflow.* The tail of an element below ~1.2e-4 falls out of the + fp16 subnormal range and the correction silently vanishes (local + degradation to bare fp16 accuracy). Tails are therefore stored pre-scaled + by ``2^11`` (the fp16 mantissa width) and the accumulated correction is + scaled back in the epilogue; the scaled tail never overflows where the + head itself does not (``|x_lo * 2^11| <= |x|``). +- *Head overflow.* The stack input rides the unnormalized residual stream + (the default SeZM layout applies the equivariant norm after the SO(2) + update, not before), so the activation operand is pre-scaled by ``2^-4`` + before the split and the merged accumulator is scaled back by ``2^4``. + Layer inputs measured on production checkpoints peak near 13 with roughly + 6x per-block growth, so the prescale keeps four orders of magnitude of + headroom below the fp16 maximum for realistic depths. Weights are static + and of order one after training and stay unscaled; a checkpoint whose + stack weights or activations exceed the fp16 head range surfaces loudly + as NaN on the first evaluation rather than as silent error. + +Accuracy contract +----------------- +The scheme perturbs a trained model's outputs at the level of the fp32 +rounding itself per GEMM; through a full force evaluation the accumulated +deviation against the fp32 stack is of order 1e-6 on forces (measured +~4e-6 eV/A maximum on a 4096-atom system) and ~1e-7 eV per atom on the +energy. The rounding step of the scheme is 2^-22 relative -- three orders +of magnitude finer than TF32 -- so the smoothness character of the +potential-energy surface matches fp32. The level-3 gate exists so this +trade is always an explicit opt-in. + +Launch-configuration discipline +------------------------------- +Some ``(num_warps, num_stages)`` combinations of the three-``tl.dot`` k-loop +are miscompiled by the Triton software pipeliner into silent NaN rows at +production edge counts, and the affected set shifts with any change to the +kernel body. Launch configurations therefore come exclusively from +:func:`~.tile_configs.stack_fp16x3_configs`, whose entries are regenerated +by the fp64-validated sweep (``sweep_tile_configs.py --kernels fp16x3``). +Any edit to a kernel body in this module invalidates every table entry. + +Layout and semantics are identical to ``sezm_triton::so2_mixing_stack`` +(m-major focus-major rows, raw pre-activations saved for the backward, the +competition weight folded into the final store); the gate, recompute and +pointwise-backward kernels are shared with the fp32 operator. +""" + +from __future__ import ( + annotations, +) + +import torch +from torch import ( + Tensor, +) +from torch.library import ( + wrap_triton, +) + +from .so2_value_path import ( + SO2_VALUE_PATH_TRITON_AVAILABLE, + _has_no_edges, + _mixing_stack_backward_reference, + _mixing_stack_reference, + _use_triton, +) +from .tile_configs import ( + GATE_BMM_MIN_FOCUS_DIM, + gate_config, + point_config, + recompute_config, + stack_fp16x3_configs, +) + +__all__ = [ + "STACK_FP16X3_TRITON_AVAILABLE", + "mixing_stack_fp16x3", +] + +STACK_FP16X3_TRITON_AVAILABLE = SO2_VALUE_PATH_TRITON_AVAILABLE + +if STACK_FP16X3_TRITON_AVAILABLE: + import triton + import triton.language as tl + + from .so2_value_path import ( + _stack_gate_kernel, + _stack_grad_alpha_kernel, + _stack_point_bwd_kernel, + _stack_recompute_kernel, + ) + + @triton.jit + def _split_fp16_kernel( + w_ptr, # (numel,) fp32 weights, contiguous + hi_ptr, # (numel,) fp16 head out + lo_ptr, # (numel,) fp16 tail out (pre-scaled by 2^11) + numel, + BLOCK: tl.constexpr, + ): + """Two-term fp16 split evaluated inside Triton. + + The split must not be expressed as aten operations: Inductor's + codegen keeps pointwise intermediates in fp32 and elides the + ``fp32 -> fp16 -> fp32`` rounding round-trip, which turns the tail + into exact zero and silently disables the compensation on the + compiled path. A Triton kernel is an opaque leaf for Inductor and + its ``.to`` conversions round as written. + """ + offs = tl.program_id(0).to(tl.int64) * BLOCK + tl.arange(0, BLOCK) + mask = offs < numel + w = tl.load(w_ptr + offs, mask=mask, other=0.0) + hi = w.to(tl.float16) + lo = ((w - hi.to(tl.float32)) * 2048.0).to(tl.float16) + tl.store(hi_ptr + offs, hi, mask=mask) + tl.store(lo_ptr + offs, lo, mask=mask) + + @triton.jit + def _dot_fp16x3(a, bh, bl, acc, acc2): + """fp16x3 compensated dot with prescaled head / scaled-tail terms. + + The activation tile is scaled by ``2^-4`` before the split and its + tail by a further ``2^11``; the epilogue applies the matching inverse + factors (``2^4`` on the head accumulator, ``2^4 * 2^-11`` on the tail + accumulator). + """ + a_s = a * 0.0625 + a_hi = a_s.to(tl.float16) + a_lo = ((a_s - a_hi.to(tl.float32)) * 2048.0).to(tl.float16) + acc = tl.dot(a_hi, bh, acc) + acc2 = tl.dot(a_hi, bl, acc2) + acc2 = tl.dot(a_lo, bh, acc2) + return acc, acc2 + + @triton.jit + def _stack_fp16x3_m0_kernel( + u_ptr, # (F, E, ROW) layer input + wh_ptr, # (NL, F, M0, M0) fp16 weight head + wl_ptr, # (NL, F, M0, M0) fp16 weight tail (pre-scaled by 2^11) + alpha_ptr, # (E, F) competition weight (identity epilogue only) + v_ptr, # z_all stack (EPILOGUE 0) or the final output (EPILOGUE 1) + n_edge, + layer, + L: tl.constexpr, + CF: tl.constexpr, + EPILOGUE: tl.constexpr, # 0: store raw z; 1: residual (+ alpha) output + V_EDGE_MAJOR: tl.constexpr, # v is (E, F, ROW); else focus-major + APPLY_ALPHA: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + """``m = 0`` block GEMM ``z = u[:, :M0] @ W0`` (fp16x3 inner product). + + Output strides are derived in-kernel from the layout flag on int64 + offsets: a host-side ``n_edge * ROW`` scalar argument would be + specialized to int32 by the first (small) compilation and overflow + on systems beyond ~2^31 / ROW edges. + """ + M0: tl.constexpr = (L + 1) * CF + ROW: tl.constexpr = (3 * L + 1) * CF + NT: tl.constexpr = (M0 + BLOCK_N - 1) // BLOCK_N + + pid = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + pid_m = pid // NT + pid_n = pid % NT + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + u_row = u_ptr + fid * n_edge * ROW + offs_m * ROW + offs_k = tl.arange(0, BLOCK_K) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < M0 + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + a_ptrs = u_row[:, None] + offs_k[None, :] + w_off = ( + (layer * n_focus + fid) * M0 * M0 + offs_k[:, None] * M0 + offs_n[None, :] + ) + for _ in range(0, M0, BLOCK_K): + a = tl.load(a_ptrs, mask=mm, other=0.0) + bh = tl.load(wh_ptr + w_off, mask=n_mask[None, :], other=0.0) + bl = tl.load(wl_ptr + w_off, mask=n_mask[None, :], other=0.0) + acc, acc2 = _dot_fp16x3(a, bh, bl, acc, acc2) + a_ptrs += BLOCK_K + w_off += BLOCK_K * M0 + acc = acc * 16.0 + acc2 * 0.0078125 # 2^4 head, 2^4 * 2^-11 tail unscale + + if EPILOGUE == 1: + u_t = tl.load( + u_row[:, None] + offs_n[None, :], mask=mm & n_mask[None, :], other=0.0 + ) + acc = acc + u_t + if APPLY_ALPHA: + alpha = tl.load( + alpha_ptr + offs_m * n_focus + fid, mask=m_mask, other=0.0 + ) + acc = acc * alpha[:, None] + if V_EDGE_MAJOR: + v_row = v_ptr + fid * ROW + offs_m * (n_focus * ROW) + else: + v_row = v_ptr + fid * n_edge * ROW + offs_m * ROW + tl.store(v_row[:, None] + offs_n[None, :], acc, mask=mm & n_mask[None, :]) + else: + z_row = v_ptr + (layer * n_focus + fid) * n_edge * ROW + offs_m * ROW + tl.store(z_row[:, None] + offs_n[None, :], acc, mask=mm & n_mask[None, :]) + + @triton.jit + def _stack_fp16x3_m1_kernel( + u_ptr, + wh_ptr, # (NL, F, M1, M1) fp16 weight head + wl_ptr, # (NL, F, M1, M1) fp16 weight tail (pre-scaled by 2^11) + sig_ptr, # (F, E, L*CF) gate sigmoids (HAS_GATE) + alpha_ptr, + v_ptr, + z_ptr, + n_edge, + layer, + L: tl.constexpr, + CF: tl.constexpr, + HAS_GATE: tl.constexpr, + V_EDGE_MAJOR: tl.constexpr, # v is (E, F, ROW); else focus-major + APPLY_ALPHA: tl.constexpr, + SAVE_Z: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + """``|m| = 1`` block GEMM with the gate / residual / alpha epilogue.""" + M0: tl.constexpr = (L + 1) * CF + M1: tl.constexpr = 2 * L * CF + ROW: tl.constexpr = (3 * L + 1) * CF + LG: tl.constexpr = L * CF + NT: tl.constexpr = (M1 + BLOCK_N - 1) // BLOCK_N + + pid = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + pid_m = pid // NT + pid_n = pid % NT + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + u_row = u_ptr + fid * n_edge * ROW + offs_m * ROW + offs_k = tl.arange(0, BLOCK_K) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < M1 + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + a_ptrs = u_row[:, None] + (M0 + offs_k)[None, :] + w_off = ( + (layer * n_focus + fid) * M1 * M1 + offs_k[:, None] * M1 + offs_n[None, :] + ) + for _ in range(0, M1, BLOCK_K): + a = tl.load(a_ptrs, mask=mm, other=0.0) + bh = tl.load(wh_ptr + w_off, mask=n_mask[None, :], other=0.0) + bl = tl.load(wl_ptr + w_off, mask=n_mask[None, :], other=0.0) + acc, acc2 = _dot_fp16x3(a, bh, bl, acc, acc2) + a_ptrs += BLOCK_K + w_off += BLOCK_K * M1 + acc = acc * 16.0 + acc2 * 0.0078125 # 2^4 head, 2^4 * 2^-11 tail unscale + + if SAVE_Z: + z_row = z_ptr + (layer * n_focus + fid) * n_edge * ROW + offs_m * ROW + tl.store( + z_row[:, None] + (M0 + offs_n)[None, :], acc, mask=mm & n_mask[None, :] + ) + if HAS_GATE: + # Both |m| = 1 stripes of degree group g share gate group g. + sig_cols = ((offs_n // CF) % L) * CF + (offs_n % CF) + sig = tl.load( + sig_ptr + (fid * n_edge + offs_m)[:, None] * LG + sig_cols[None, :], + mask=mm & n_mask[None, :], + other=0.0, + ) + acc = acc * sig + u_t = tl.load( + u_row[:, None] + (M0 + offs_n)[None, :], + mask=mm & n_mask[None, :], + other=0.0, + ) + acc = acc + u_t + if APPLY_ALPHA: + alpha = tl.load(alpha_ptr + offs_m * n_focus + fid, mask=m_mask, other=0.0) + acc = acc * alpha[:, None] + if V_EDGE_MAJOR: + v_row = v_ptr + fid * ROW + offs_m * (n_focus * ROW) + else: + v_row = v_ptr + fid * n_edge * ROW + offs_m * ROW + tl.store( + v_row[:, None] + (M0 + offs_n)[None, :], acc, mask=mm & n_mask[None, :] + ) + + @triton.jit + def _stack_fp16x3_bwd_kernel( + gz_ptr, # (F, E, ROW), or the raw upstream gradient when FOLD_ALPHA + res_ptr, # (F, E, ROW) residual gradient source; unread if FOLD_ALPHA + wh_ptr, # (NL, F, MB, MB) fp16 transposed weight head of this block + wl_ptr, # (NL, F, MB, MB) fp16 transposed weight tail (2^11-scaled) + alpha_ptr, + gu_ptr, # (F, E, ROW) layer-input gradient + n_edge, + layer, + L: tl.constexpr, + CF: tl.constexpr, + IS_M1: tl.constexpr, # 0: m = 0 block (offset 0), 1: |m| = 1 block + G_EDGE_MAJOR: tl.constexpr, # gz is (E, F, ROW); else focus-major + FOLD_ALPHA: tl.constexpr, # gz = g * alpha on the fly; residual == gz + RES_IS_GZ: tl.constexpr, # residual equals gz (final layer, no alpha) + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + """Backward GEMM ``g_u = residual + gz @ W^T`` for one ``|m|`` block. + + The two blocks are separate launches (``IS_M1`` 0 / 1) so each + pipelines with its own swept schedule instead of sharing one + compromise configuration. + """ + M0: tl.constexpr = (L + 1) * CF + MB: tl.constexpr = (2 * L * CF) if IS_M1 else ((L + 1) * CF) + OFF: tl.constexpr = M0 if IS_M1 else 0 + ROW: tl.constexpr = (3 * L + 1) * CF + NT: tl.constexpr = (MB + BLOCK_N - 1) // BLOCK_N + + pid = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + pid_m = pid // NT + pid_n = pid % NT + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + offs_k = tl.arange(0, BLOCK_K) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < MB + + if G_EDGE_MAJOR: + gz_row = gz_ptr + fid * ROW + offs_m * (n_focus * ROW) + else: + gz_row = gz_ptr + fid * n_edge * ROW + offs_m * ROW + gu_row = gu_ptr + fid * n_edge * ROW + offs_m * ROW + if FOLD_ALPHA: + alpha = tl.load(alpha_ptr + offs_m * n_focus + fid, mask=m_mask, other=0.0) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + a_ptrs = gz_row[:, None] + (OFF + offs_k)[None, :] + w_off = ( + (layer * n_focus + fid) * MB * MB + offs_k[:, None] * MB + offs_n[None, :] + ) + for _ in range(0, MB, BLOCK_K): + a = tl.load(a_ptrs, mask=mm, other=0.0) + if FOLD_ALPHA: + a = a * alpha[:, None] + bh = tl.load(wh_ptr + w_off, mask=n_mask[None, :], other=0.0) + bl = tl.load(wl_ptr + w_off, mask=n_mask[None, :], other=0.0) + acc, acc2 = _dot_fp16x3(a, bh, bl, acc, acc2) + a_ptrs += BLOCK_K + w_off += BLOCK_K * MB + acc = acc * 16.0 + acc2 * 0.0078125 # 2^4 head, 2^4 * 2^-11 tail unscale + + col0 = OFF + offs_n + if FOLD_ALPHA: + res = tl.load( + gz_row[:, None] + col0[None, :], mask=mm & n_mask[None, :], other=0.0 + ) + res = res * alpha[:, None] + elif RES_IS_GZ: + res = tl.load( + gz_row[:, None] + col0[None, :], mask=mm & n_mask[None, :], other=0.0 + ) + else: + res_row = res_ptr + fid * n_edge * ROW + offs_m * ROW + res = tl.load( + res_row[:, None] + col0[None, :], mask=mm & n_mask[None, :], other=0.0 + ) + tl.store(gu_row[:, None] + col0[None, :], acc + res, mask=mm & n_mask[None, :]) + + +def _split_fp16(w: Tensor) -> tuple[Tensor, Tensor]: + """Two-term fp16 split ``w ~= hi + lo * 2^-11`` (contiguous halves). + + The tail is stored pre-scaled by ``2^11`` so it stays representable in + fp16 over the whole fp32-relevant magnitude range; the kernels apply the + matching inverse factor in their epilogues (exact powers of two). + + The split runs as a Triton kernel rather than as aten operations: the + tail is *defined* by an ``fp32 -> fp16 -> fp32`` rounding round-trip, and + Inductor's pointwise fusion keeps such intermediates in fp32 registers, + which folds the round-trip away and zeroes the tail on the compiled path. + """ + w = w.contiguous() + hi = torch.empty(w.shape, device=w.device, dtype=torch.float16) + lo = torch.empty(w.shape, device=w.device, dtype=torch.float16) + numel = w.numel() + block = 1024 + wrap_triton(_split_fp16_kernel)[(triton.cdiv(numel, block),)]( + w.view(-1), hi.view(-1), lo.view(-1), numel, BLOCK=block + ) + return hi, lo + + +def _mixing_stack_fp16x3_impl( + u0: Tensor, + alpha: Tensor, + w0_all: Tensor, + w1_all: Tensor, + gw_all: Tensor, + lmax: int, + focus_dim: int, + apply_alpha: bool, +) -> tuple[Tensor, Tensor]: + if not _use_triton(u0): + return _mixing_stack_reference( + u0, alpha, w0_all, w1_all, gw_all, lmax, focus_dim, apply_alpha + ) + n_focus, n_edge, row = u0.shape + lmax = int(lmax) + focus_dim = int(focus_dim) + configs = stack_fp16x3_configs(focus_dim, lmax) + if configs is None: + raise RuntimeError( + f"no validated fp16x3 configuration for (focus_dim={focus_dim}, " + f"lmax={lmax}); the caller must route unswept shapes to the fp32 " + "mixing stack" + ) + n_gated = gw_all.shape[0] + z_all = torch.empty( + (n_gated, n_focus, n_edge, row), device=u0.device, dtype=u0.dtype + ) + x_local = torch.empty((n_edge, n_focus, row), device=u0.device, dtype=u0.dtype) + if _has_no_edges(n_edge): + return x_local, z_all + + # Weight splits are parameter-only and negligible next to the GEMMs. + w0h, w0l = _split_fp16(w0_all) + w1h, w1l = _split_fp16(w1_all) + + (bm0, bn0, bk0, w0_warps, w0_stages), (bm1, bn1, bk1, w1_warps, w1_stages) = ( + configs[0], + configs[1], + ) + m0 = (lmax + 1) * focus_dim + m1 = 2 * lmax * focus_dim + gate_bm, gate_w, gate_s = gate_config(focus_dim, lmax) + sig_by_bmm = focus_dim >= GATE_BMM_MIN_FOCUS_DIM + sig = torch.empty( + (n_focus, n_edge, lmax * focus_dim), device=u0.device, dtype=torch.float32 + ) + grid_m0 = (triton.cdiv(n_edge, bm0) * triton.cdiv(m0, bn0), n_focus) + grid_m1 = (triton.cdiv(n_edge, bm1) * triton.cdiv(m1, bn1), n_focus) + + u = u0 + for layer in range(n_gated): + out = torch.empty_like(u) + wrap_triton(_stack_fp16x3_m0_kernel)[grid_m0]( + u, + w0h, + w0l, + u, + z_all, + n_edge, + layer, + L=lmax, + CF=focus_dim, + EPILOGUE=0, + V_EDGE_MAJOR=False, + APPLY_ALPHA=False, + BLOCK_M=bm0, + BLOCK_N=bn0, + BLOCK_K=bk0, + num_warps=w0_warps, + num_stages=w0_stages, + ) + if sig_by_bmm: + # Wide-channel regime: sigmoid projection as a cuBLAS bmm on the + # freshly written l = 0 scalar rows of the pre-activation. + torch.sigmoid( + torch.bmm(z_all[layer, :, :, :focus_dim], gw_all[layer]), out=sig + ) + wrap_triton(_stack_gate_kernel)[(triton.cdiv(n_edge, gate_bm), n_focus)]( + u, + z_all, + gw_all, + out, + sig, + n_edge, + layer, + L=lmax, + CF=focus_dim, + SIG_IN=sig_by_bmm, + BLOCK_M=gate_bm, + num_warps=gate_w, + num_stages=gate_s, + ) + wrap_triton(_stack_fp16x3_m1_kernel)[grid_m1]( + u, + w1h, + w1l, + sig, + u, + out, + z_all, + n_edge, + layer, + L=lmax, + CF=focus_dim, + HAS_GATE=True, + V_EDGE_MAJOR=False, + APPLY_ALPHA=False, + SAVE_Z=True, + BLOCK_M=bm1, + BLOCK_N=bn1, + BLOCK_K=bk1, + num_warps=w1_warps, + num_stages=w1_stages, + ) + u = out + + # Final identity layer streams straight into the edge-major output layout. + wrap_triton(_stack_fp16x3_m0_kernel)[grid_m0]( + u, + w0h, + w0l, + alpha, + x_local, + n_edge, + n_gated, + L=lmax, + CF=focus_dim, + EPILOGUE=1, + V_EDGE_MAJOR=True, + APPLY_ALPHA=apply_alpha, + BLOCK_M=bm0, + BLOCK_N=bn0, + BLOCK_K=bk0, + num_warps=w0_warps, + num_stages=w0_stages, + ) + wrap_triton(_stack_fp16x3_m1_kernel)[grid_m1]( + u, + w1h, + w1l, + sig, + alpha, + x_local, + u, + n_edge, + n_gated, + L=lmax, + CF=focus_dim, + HAS_GATE=False, + V_EDGE_MAJOR=True, + APPLY_ALPHA=apply_alpha, + SAVE_Z=False, + BLOCK_M=bm1, + BLOCK_N=bn1, + BLOCK_K=bk1, + num_warps=w1_warps, + num_stages=w1_stages, + ) + return x_local, z_all + + +def _mixing_stack_fp16x3_bwd_impl( + grad_out: Tensor, + x_local: Tensor, + z_all: Tensor, + alpha: Tensor, + w0t_all: Tensor, + w1t_all: Tensor, + gw_all: Tensor, + gwt_all: Tensor, + lmax: int, + focus_dim: int, + apply_alpha: bool, +) -> tuple[Tensor, Tensor]: + if not _use_triton(grad_out): + return _mixing_stack_backward_reference( + grad_out, + x_local, + z_all, + alpha, + w0t_all, + w1t_all, + gw_all, + gwt_all, + lmax, + focus_dim, + apply_alpha, + ) + n_gated, n_focus, n_edge, row = z_all.shape + lmax = int(lmax) + focus_dim = int(focus_dim) + configs = stack_fp16x3_configs(focus_dim, lmax) + if configs is None: + raise RuntimeError( + f"no validated fp16x3 configuration for (focus_dim={focus_dim}, " + f"lmax={lmax}); the caller must route unswept shapes to the fp32 " + "mixing stack" + ) + device, dtype = grad_out.device, grad_out.dtype + grad_alpha = torch.empty((n_edge, n_focus), device=device, dtype=dtype) + grad_u0 = torch.empty((n_focus, n_edge, row), device=device, dtype=dtype) + if _has_no_edges(n_edge): + return grad_u0, grad_alpha + + w0h, w0l = _split_fp16(w0t_all) + w1h, w1l = _split_fp16(w1t_all) + + m0 = (lmax + 1) * focus_dim + m1 = 2 * lmax * focus_dim + (bm0, bn0, bk0, w0_warps, w0_stages), (bm1, bn1, bk1, w1_warps, w1_stages) = ( + configs[2], + configs[3], + ) + grid_bwd0 = (triton.cdiv(n_edge, bm0) * triton.cdiv(m0, bn0), n_focus) + grid_bwd1 = (triton.cdiv(n_edge, bm1) * triton.cdiv(m1, bn1), n_focus) + point_bm, point_w, point_s = point_config(focus_dim, lmax) + + def launch_bwd_gemms(gz, res, gu, layer, g_edge_major, fold, res_is_gz): + wrap_triton(_stack_fp16x3_bwd_kernel)[grid_bwd0]( + gz, + res, + w0h, + w0l, + alpha, + gu, + n_edge, + layer, + L=lmax, + CF=focus_dim, + IS_M1=False, + G_EDGE_MAJOR=g_edge_major, + FOLD_ALPHA=fold, + RES_IS_GZ=res_is_gz, + BLOCK_M=bm0, + BLOCK_N=bn0, + BLOCK_K=bk0, + num_warps=w0_warps, + num_stages=w0_stages, + ) + wrap_triton(_stack_fp16x3_bwd_kernel)[grid_bwd1]( + gz, + res, + w1h, + w1l, + alpha, + gu, + n_edge, + layer, + L=lmax, + CF=focus_dim, + IS_M1=True, + G_EDGE_MAJOR=g_edge_major, + FOLD_ALPHA=fold, + RES_IS_GZ=res_is_gz, + BLOCK_M=bm1, + BLOCK_N=bn1, + BLOCK_K=bk1, + num_warps=w1_warps, + num_stages=w1_stages, + ) + + # === Final layer: g = gz + gz @ W^T with gz = grad [* alpha] on the fly === + g_cur = torch.empty((n_focus, n_edge, row), device=device, dtype=dtype) + launch_bwd_gemms(grad_out, grad_out, g_cur, n_gated, True, apply_alpha, True) + if apply_alpha: + a_bm, a_w, a_s = gate_config(focus_dim, lmax) + wrap_triton(_stack_grad_alpha_kernel)[(triton.cdiv(n_edge, a_bm), n_focus)]( + grad_out, + x_local, + alpha, + grad_alpha, + n_edge, + L=lmax, + CF=focus_dim, + BLOCK_M=a_bm, + num_warps=a_w, + num_stages=a_s, + ) + + # === Gated layers in reverse; sig / gz buffers are reused across layers === + gate_width = lmax * focus_dim + sig = torch.empty((n_focus, n_edge, gate_width), device=device, dtype=torch.float32) + gz = torch.empty((n_focus, n_edge, row), device=device, dtype=dtype) + use_bmm = focus_dim >= GATE_BMM_MIN_FOCUS_DIM + glogit = ( + torch.empty((n_focus, n_edge, gate_width), device=device, dtype=torch.float32) + if use_bmm + else sig + ) + r_bm, r_w, r_s = recompute_config(focus_dim, lmax) + for layer in range(n_gated - 1, -1, -1): + if use_bmm: + torch.sigmoid( + torch.bmm(z_all[layer, :, :, :focus_dim], gw_all[layer]), out=sig + ) + else: + wrap_triton(_stack_recompute_kernel)[(triton.cdiv(n_edge, r_bm), n_focus)]( + z_all, + gw_all, + sig, + n_edge, + layer, + L=lmax, + CF=focus_dim, + BLOCK_M=r_bm, + num_warps=r_w, + num_stages=r_s, + ) + wrap_triton(_stack_point_bwd_kernel)[(triton.cdiv(n_edge, point_bm), n_focus)]( + g_cur, + z_all, + sig, + gwt_all, + gz, + glogit, + n_edge, + layer, + L=lmax, + CF=focus_dim, + GLOGIT_OUT=use_bmm, + BLOCK_M=point_bm, + num_warps=point_w, + num_stages=point_s, + ) + if use_bmm: + # Gate-logit contraction back to the scalar rows via cuBLAS. + gz[:, :, :focus_dim] += torch.bmm(glogit, gwt_all[layer]) + g_next = torch.empty((n_focus, n_edge, row), device=device, dtype=dtype) + launch_bwd_gemms(gz, g_cur, g_next, layer, False, False, False) + g_cur = g_next + return g_cur, grad_alpha + + +# ====================================================================== +# Functional triton_op + fake + autograd registration +# ====================================================================== +_mixing_stack_fp16x3_op = torch.library.triton_op( + "sezm_triton::so2_mixing_stack_fp16x3", mutates_args=() +)(_mixing_stack_fp16x3_impl) +_mixing_stack_fp16x3_bwd_op = torch.library.triton_op( + "sezm_triton::so2_mixing_stack_fp16x3_bwd", mutates_args=() +)(_mixing_stack_fp16x3_bwd_impl) + + +@_mixing_stack_fp16x3_op.register_fake +def _(u0, alpha, w0_all, w1_all, gw_all, lmax, focus_dim, apply_alpha): + n_focus, n_edge, row = u0.shape + return ( + u0.new_empty((n_edge, n_focus, row)), + u0.new_empty((gw_all.shape[0], n_focus, n_edge, row)), + ) + + +@_mixing_stack_fp16x3_bwd_op.register_fake +def _( + grad_out, + x_local, + z_all, + alpha, + w0t_all, + w1t_all, + gw_all, + gwt_all, + lmax, + focus_dim, + apply_alpha, +): + n_gated, n_focus, n_edge, row = z_all.shape + return ( + z_all.new_empty((n_focus, n_edge, row)), + z_all.new_empty((n_edge, n_focus)), + ) + + +def _setup_context(ctx, inputs, output): + u0, alpha, w0_all, w1_all, gw_all, lmax, focus_dim, apply_alpha = inputs + x_local, z_all = output + ctx.save_for_backward(alpha, x_local, z_all, w0_all, w1_all, gw_all) + ctx.lmax = lmax + ctx.focus_dim = focus_dim + ctx.apply_alpha = apply_alpha + + +def _backward(ctx, grad_out, grad_z_unused): + alpha, x_local, z_all, w0_all, w1_all, gw_all = ctx.saved_tensors + grad_u0, grad_alpha = _mixing_stack_fp16x3_bwd_op( + grad_out.contiguous(), + x_local, + z_all, + alpha, + w0_all.transpose(2, 3).contiguous(), + w1_all.transpose(2, 3).contiguous(), + gw_all, + gw_all.transpose(2, 3).contiguous(), + ctx.lmax, + ctx.focus_dim, + ctx.apply_alpha, + ) + return ( + grad_u0, + grad_alpha if ctx.apply_alpha else None, + None, + None, + None, + None, + None, + None, + ) + + +_mixing_stack_fp16x3_op.register_autograd(_backward, setup_context=_setup_context) + + +def mixing_stack_fp16x3( + u0: Tensor, + alpha: Tensor, + w0_all: Tensor, + w1_all: Tensor, + gw_all: Tensor, + lmax: int, + focus_dim: int, + apply_alpha: bool, +) -> tuple[Tensor, Tensor]: + """Run the SO(2) mixing stack through the fp16x3 tensor-core operator. + + Drop-in replacement for ``sezm_triton::so2_mixing_stack`` on shapes whose + launch configuration passed the fp64 validation sweep (see the module + docstring for the numerical scheme and its accuracy contract). + + Parameters + ---------- + u0 : Tensor + Focus-major stack input with shape (n_focus, n_edge, row), where + ``row = (3 * lmax + 1) * focus_dim``. + alpha : Tensor + Cross-focus competition weight with shape (n_edge, n_focus). + w0_all : Tensor + Stacked ``m = 0`` block weights with shape + (n_layers, n_focus, M0, M0), (in, out) convention. + w1_all : Tensor + Stacked ``|m| = 1`` block weights with shape + (n_layers, n_focus, M1, M1). + gw_all : Tensor + Stacked gate projections with shape + (n_layers - 1, n_focus, focus_dim, lmax * focus_dim). + lmax : int + Maximum spherical harmonic degree. + focus_dim : int + Per-focus channel width ``Cf``. + apply_alpha : bool + Whether the competition weight is folded into the final store. + + Returns + ------- + tuple[Tensor, Tensor] + The edge-major local features with shape (n_edge, n_focus, row) and + the stacked gated-layer pre-activations with shape + (n_layers - 1, n_focus, n_edge, row). + """ + return _mixing_stack_fp16x3_op( + u0, alpha, w0_all, w1_all, gw_all, lmax, focus_dim, apply_alpha + ) diff --git a/deepmd/kernels/triton/sezm/so2_value_path.py b/deepmd/kernels/triton/sezm/so2_value_path.py new file mode 100644 index 0000000000..23c102b6c6 --- /dev/null +++ b/deepmd/kernels/triton/sezm/so2_value_path.py @@ -0,0 +1,2309 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN202, RUF005 +"""Fused Triton SO(2) value path for the SeZM/DPA4 descriptor. + +The SO(2) value path of :class:`SO2Convolution` -- rotate-to-local, radial +degree mixing, the multi-layer gated SO(2) mixing stack, and the cross-focus +competition -- dominates both the time and the activation memory of a SeZM +inference step. This module fuses it into two functional operators: + +``sezm_triton::so2_rotate_mix`` + One kernel per edge: gathers the source node features, applies the + block-diagonal Wigner rotation over the structural non-zeros only (kept in + registers), applies the edge-conditioned radial degree mixing, and stores + the result directly in the focus-major flat layout ``(F, E, ROW)`` with + ``ROW = (3 * lmax + 1) * Cf`` that the mixing stack consumes. The rotated + pre-mix intermediate is never materialized. The backward recomputes the + rotation in registers (nothing is saved besides the operator inputs) and + reduces the per-edge node gradient with a contention-free CSR segment sum + (``sezm_triton::segment_sum``) instead of ``index_add_``: at typical + neighbor counts (~10^2 colliding edges per atom) row-atomic scatters + serialize and are several times slower. On narrow hidden widths the + backward dispatches to an edge-block kernel that replaces the per-edge + cross-lane ``tl.sum`` chains with batched axis-1 reductions; the win-list + table :func:`~.tile_configs.rotate_mix_bwd_block_config` decides per + ``(C_wide, lmax)`` key. + +``sezm_triton::so2_mixing_stack`` + The whole mixing stack -- ``n_layers - 1`` gated layers followed by one + identity layer, with the optional cross-focus competition weight folded + into the final store -- as a single operator. Keeping the inter-layer + activations inside the op (ordinary caching-allocator tensors) instead of + graph-level intermediates minimizes the compiled graph's activation + footprint; only the tensors the backward needs surface as outputs (the + stacked gated-layer pre-activations ``z_all`` and the result itself). + Gate sigmoids are recomputed in the backward from the saved ``z``. + +Per gated layer the stack runs three launches: a pure block GEMM for the +``m = 0`` block, a pointwise kernel evaluating the sigmoid gates from the +``l = 0`` scalar slice and finishing the ``m = 0`` rows, and a ``|m| = 1`` +block GEMM with the gate/residual epilogue fused in. The final identity +layer is two GEMM launches whose epilogue adds the residual, applies the +competition weight, and stores straight into the edge-major ``(E, F, ROW)`` +layout the fused attention aggregation consumes -- no reassembly copy. + +Layout contract +--------------- +The focus-major activation ``(F, E, ROW)`` orders each row m-major: +subtiles ``r = 0..lmax`` hold ``m = 0`` degrees ``l = r``; subtiles +``r = lmax+1..2*lmax`` and ``r = 2*lmax+1..3*lmax`` hold the ``m = -1`` and +``m = +1`` degrees ``l = 1..lmax``. The sigmoid gate group of subtile +``r > 0`` is ``(r - 1) % lmax`` for the ``m = 0`` rows and +``(r - lmax - 1) % lmax`` for the ``|m| = 1`` rows, matching +:class:`GatedActivation` with one gate group per degree ``l >= 1``. + +Weight passing discipline +------------------------- +Per-layer weights are stacked along dim 0 -- ``(n_layers, F, M, M)`` -- and +kernels select a layer through an integer ``layer`` argument. Slicing the +stack in Python and handing ``select`` views to the Triton higher-order op +must be avoided: Inductor's ``decompose_triton_kernel_wrapper_functional`` +re-traces the op body with ``replace_by_example`` and asserts node-for-node +graph equality, which view-typed kernel arguments break (clone insertion +differs between the two traces on PyTorch 2.11). + +Numerics +-------- +Every ``tl.dot`` runs with ``input_precision="ieee"`` (no TF32), keeping the +potential-energy surface smooth. fp32 is the supported precision; the +factory refuses non-fp32 weights rather than silently down-casting. Launch +tile choices never affect results (they change the schedule, not any +reduction order); the swept tables live in :mod:`.tile_configs`. At +``DP_TRITON_INFER >= 3`` the mixing stack is replaced by the fp16x3 +tensor-core operator of :mod:`.so2_stack_fp16x3` on validated shapes, the +one deliberate exception to the exact-fp32 contract. + +Wide-channel regime +------------------- +For ``Cf >= GATE_BMM_MIN_FOCUS_DIM`` the per-group ``CP x CP`` register dot +of the gate forward/backward spills (``CP`` is ``Cf`` padded to a power of +two), so the sigmoid projection and the gate-logit contraction run as cuBLAS +batched matmuls inside the op while the Triton kernels keep the pointwise +work. Non-power-of-two focus widths (e.g. ``Cf = 96``) are supported by the +same padding plus column masks; block GEMM kernels handle any ``Cf`` through +their edge masks, and their K loops stay exact because ``(lmax + 1) * Cf`` +and ``2 * lmax * Cf`` remain multiples of the K tile. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, +) + +import torch +from torch import ( + Tensor, +) +from torch.library import ( + wrap_triton, +) + +from deepmd.pt.model.descriptor.sezm_nn.indexing import ( + build_m_major_index, +) + +from .tile_configs import ( + GATE_BMM_MIN_FOCUS_DIM, + gate_config, + point_config, + recompute_config, + rotate_mix_bwd_block_config, + rotate_mix_fwd_config, + stack_fp16x3_configs, +) + +if TYPE_CHECKING: + from deepmd.pt.model.descriptor.sezm_nn.edge_cache import ( + EdgeFeatureCache, + ) + from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + SO2Convolution, + ) + +__all__ = [ + "SO2_VALUE_PATH_TRITON_AVAILABLE", + "make_triton_value_path", +] + +try: + import triton + import triton.language as tl + + SO2_VALUE_PATH_TRITON_AVAILABLE = True +except ImportError: # pragma: no cover - exercised only without triton + SO2_VALUE_PATH_TRITON_AVAILABLE = False + +_SUPPORTED_FOCUS_DIMS = (32, 64, 96, 128) +_MAX_LMAX = 6 +_MAX_MIXER_RANK = 4 + +# Block GEMM tiling: 25 TFLOPS (~58% of the H20 FFMA peak) on the deployed +# block widths, at the measured efficiency ceiling of IEEE-fp32 tl.dot tiling. +# The configuration was confirmed optimal (or within 1%) across the whole +# swept (focus_dim, lmax) family, so it is a constant rather than a table. +_GEMM_CONFIG = (64, 64, 32, 4, 2) # (BLOCK_M, BLOCK_N, BLOCK_K, warps, stages) +_ROTATE_MIX_BWD_CONFIG = (1, 2) # per-edge backward (warps, stages) + + +# ====================================================================== +# Eager reference / fallback implementations +# ====================================================================== +def _rotate_mix_reference( + x: Tensor, + src: Tensor, + wigner: Tensor, + kc: Tensor, + cb: Tensor, + lmax: int, + n_focus: int, + rank: int, +) -> Tensor: + """Eager ground truth for ``so2_rotate_mix``. + + Rotates the gathered source features into the m-major ``mmax == 1`` + reduced layout, applies the radial degree mixing (rank-``R`` factorized + kernel, or the degree-wise multiply when ``rank == 0``), and returns the + focus-major ``(F, E, ROW)`` activation. + """ + n_edge = src.shape[0] + c_wide = x.shape[2] + focus_dim = c_wide // n_focus + dim = (lmax + 1) ** 2 + n_deg = lmax + 1 + reduced = 3 * lmax + 1 + coeff = build_m_major_index(lmax, 1, device=x.device) + d_to_m = wigner[:, :dim, :dim].index_select(1, coeff) + x_local = torch.bmm(d_to_m, x.index_select(0, src)) # (E, reduced, C_wide) + if rank == 0: + # kc holds per-degree radial features (E, lmax+1, C_wide); each reduced + # row is multiplied by the feature of its degree. + rad = kc.view(n_edge, n_deg, c_wide) + degree = torch.tensor( + list(range(n_deg)) + 2 * list(range(1, n_deg)), + device=x.device, + dtype=torch.long, + ) + y = x_local * rad.index_select(1, degree) + else: + kc_v = kc.view(n_edge, -1, rank) + k0 = kc_v[:, : n_deg * n_deg].view(n_edge, n_deg, n_deg, rank) + k1 = kc_v[:, n_deg * n_deg :].view(n_edge, lmax, lmax, rank) + cb_v = cb.view(rank, c_wide) + y = torch.empty_like(x_local) + y[:, :n_deg] = torch.einsum("eior,eic,rc->eoc", k0, x_local[:, :n_deg], cb_v) + y[:, n_deg : n_deg + lmax] = torch.einsum( + "eior,eic,rc->eoc", k1, x_local[:, n_deg : n_deg + lmax], cb_v + ) + y[:, n_deg + lmax :] = torch.einsum( + "eior,eic,rc->eoc", k1, x_local[:, n_deg + lmax :], cb_v + ) + return ( + y.view(n_edge, reduced, n_focus, focus_dim) + .permute(2, 0, 1, 3) + .reshape(n_focus, n_edge, reduced * focus_dim) + .contiguous() + ) + + +def _rotate_mix_backward_reference( + grad_u: Tensor, + x: Tensor, + src: Tensor, + wigner: Tensor, + kc: Tensor, + cb: Tensor, + lmax: int, + n_focus: int, + rank: int, +) -> tuple[Tensor, Tensor, Tensor]: + """Closed-form eager backward of ``so2_rotate_mix``. + + Returns ``(grad_x_edge, grad_wigner, grad_kc)`` where ``grad_x_edge`` is + the per-edge source gradient (the caller segment-sums it over ``src``). + A closed form (not a nested ``autograd.grad``) is required because the + backward operator is dispatched under ``_AutoDispatchBelowAutograd`` when + the frozen force graph replays under ``torch.no_grad``. + """ + n_edge = src.shape[0] + c_wide = x.shape[2] + focus_dim = c_wide // n_focus + dim = (lmax + 1) ** 2 + n_deg = lmax + 1 + reduced = 3 * lmax + 1 + coeff = build_m_major_index(lmax, 1, device=x.device) + d_to_m = wigner[:, :dim, :dim].index_select(1, coeff) + x_src = x.index_select(0, src) + x_local = torch.bmm(d_to_m, x_src) # (E, reduced, C_wide) + + g_y = ( + grad_u.view(n_focus, n_edge, reduced, focus_dim) + .permute(1, 2, 0, 3) + .reshape(n_edge, reduced, c_wide) + ) + if rank == 0: + rad = kc.view(n_edge, n_deg, c_wide) + degree = torch.tensor( + list(range(n_deg)) + 2 * list(range(1, n_deg)), + device=x.device, + dtype=torch.long, + ) + g_local = g_y * rad.index_select(1, degree) + prod = g_y * x_local + grad_kc = prod[:, :n_deg].clone() + grad_kc[:, 1:] += prod[:, n_deg : n_deg + lmax] + grad_kc[:, 1:] += prod[:, n_deg + lmax :] + grad_kc = grad_kc.reshape(kc.shape) + else: + kc_v = kc.view(n_edge, -1, rank) + k0 = kc_v[:, : n_deg * n_deg].view(n_edge, n_deg, n_deg, rank) + k1 = kc_v[:, n_deg * n_deg :].view(n_edge, lmax, lmax, rank) + cb_v = cb.view(rank, c_wide) + g_local = torch.empty_like(g_y) + g_local[:, :n_deg] = torch.einsum("eior,eoc,rc->eic", k0, g_y[:, :n_deg], cb_v) + g_local[:, n_deg : n_deg + lmax] = torch.einsum( + "eior,eoc,rc->eic", k1, g_y[:, n_deg : n_deg + lmax], cb_v + ) + g_local[:, n_deg + lmax :] = torch.einsum( + "eior,eoc,rc->eic", k1, g_y[:, n_deg + lmax :], cb_v + ) + gk0 = torch.einsum("eoc,eic,rc->eior", g_y[:, :n_deg], x_local[:, :n_deg], cb_v) + gk1 = torch.einsum( + "eoc,eic,rc->eior", + g_y[:, n_deg : n_deg + lmax], + x_local[:, n_deg : n_deg + lmax], + cb_v, + ) + torch.einsum( + "eoc,eic,rc->eior", + g_y[:, n_deg + lmax :], + x_local[:, n_deg + lmax :], + cb_v, + ) + grad_kc = torch.cat( + [gk0.reshape(n_edge, -1), gk1.reshape(n_edge, -1)], dim=1 + ).reshape(kc.shape) + + grad_x_edge = torch.bmm(d_to_m.transpose(1, 2), g_local) # (E, D, C_wide) + grad_rows = torch.bmm(g_local, x_src.transpose(1, 2)) # (E, reduced, D) + grad_block = wigner.new_zeros(n_edge, dim, dim) + grad_block.index_copy_(1, coeff, grad_rows) + grad_wigner = torch.zeros_like(wigner) + grad_wigner[:, :dim, :dim] = grad_block + return grad_x_edge, grad_wigner, grad_kc + + +def _mixing_stack_reference( + u0: Tensor, + alpha: Tensor, + w0_all: Tensor, + w1_all: Tensor, + gw_all: Tensor, + lmax: int, + focus_dim: int, + apply_alpha: bool, +) -> tuple[Tensor, Tensor]: + """Eager ground truth for ``so2_mixing_stack``. + + Returns the edge-major output ``(E, F, ROW)`` and the stacked gated-layer + pre-activations ``(n_gated, F, E, ROW)``. + """ + n_focus, n_edge, row = u0.shape + m0 = (lmax + 1) * focus_dim + n_gated = gw_all.shape[0] + u = u0 + z_saved = [] + for layer in range(n_gated): + z0 = torch.bmm(u[:, :, :m0], w0_all[layer]) + z1 = torch.bmm(u[:, :, m0:], w1_all[layer]) + z_saved.append(torch.cat([z0, z1], dim=-1)) + z_scalar = z0[:, :, :focus_dim] + sig = torch.sigmoid(torch.bmm(z_scalar, gw_all[layer])) # (F, E, lmax*Cf) + act = torch.cat( + [ + z_scalar * torch.sigmoid(z_scalar), + z0[:, :, focus_dim:] * sig, + z1 * sig.repeat(1, 1, 2), + ], + dim=-1, + ) + u = u + act + out = u.clone() + out[:, :, :m0] += torch.bmm(u[:, :, :m0], w0_all[n_gated]) + out[:, :, m0:] += torch.bmm(u[:, :, m0:], w1_all[n_gated]) + if apply_alpha: + out = out * alpha.transpose(0, 1).unsqueeze(-1).to(out.dtype) + x_local = out.permute(1, 0, 2).contiguous() + z_all = ( + torch.stack(z_saved) if n_gated > 0 else u0.new_empty(0, n_focus, n_edge, row) + ) + return x_local, z_all + + +def _mixing_stack_backward_reference( + grad_out: Tensor, + x_local: Tensor, + z_all: Tensor, + alpha: Tensor, + w0t_all: Tensor, + w1t_all: Tensor, + gw_all: Tensor, + gwt_all: Tensor, + lmax: int, + focus_dim: int, + apply_alpha: bool, +) -> tuple[Tensor, Tensor]: + """Closed-form eager backward of ``so2_mixing_stack``. + + Returns ``(grad_u0, grad_alpha)``; ``grad_alpha`` is meaningful only when + ``apply_alpha`` is set (the identity ``grad_alpha = sum(grad * out) / + alpha`` is exact because the final store is a plain scale). + """ + n_gated = gw_all.shape[0] + m0 = (lmax + 1) * focus_dim + g_edge = grad_out # (E, F, ROW) + if apply_alpha: + grad_alpha = (g_edge * x_local).sum(dim=-1) / alpha.clamp_min(1e-12) + g_edge = g_edge * alpha.unsqueeze(-1).to(g_edge.dtype) + else: + grad_alpha = torch.zeros_like(alpha) + g = g_edge.permute(1, 0, 2) # (F, E, ROW) + g_cur = g.clone() + g_cur[:, :, :m0] += torch.bmm(g[:, :, :m0], w0t_all[n_gated]) + g_cur[:, :, m0:] += torch.bmm(g[:, :, m0:], w1t_all[n_gated]) + for layer in range(n_gated - 1, -1, -1): + z = z_all[layer] + z0, z1 = z[:, :, :m0], z[:, :, m0:] + z_scalar = z0[:, :, :focus_dim] + sig = torch.sigmoid(torch.bmm(z_scalar, gw_all[layer])) + sig2 = sig.repeat(1, 1, 2) + s0 = torch.sigmoid(z_scalar) + gz0 = torch.cat( + [ + g_cur[:, :, :focus_dim] * s0 * (1.0 + z_scalar * (1.0 - s0)), + g_cur[:, :, focus_dim:m0] * sig, + ], + dim=-1, + ) + gz1 = g_cur[:, :, m0:] * sig2 + g_sig = (g_cur[:, :, focus_dim:m0] * z0[:, :, focus_dim:]).view(*sig.shape) + ( + g_cur[:, :, m0:] * z1 + ).view(sig.shape[0], sig.shape[1], 2, -1).sum(2) + g_logit = g_sig * sig * (1.0 - sig) + gz0 = torch.cat( + [ + gz0[:, :, :focus_dim] + torch.bmm(g_logit, gwt_all[layer]), + gz0[:, :, focus_dim:], + ], + dim=-1, + ) + g_next = g_cur.clone() + g_next[:, :, :m0] += torch.bmm(gz0, w0t_all[layer]) + g_next[:, :, m0:] += torch.bmm(gz1, w1t_all[layer]) + g_cur = g_next + return g_cur, grad_alpha + + +# ====================================================================== +# Triton kernels +# ====================================================================== +if SO2_VALUE_PATH_TRITON_AVAILABLE: + + @triton.jit + def _rotate_mix_fwd_kernel( + x_ptr, # (N, D, C_wide), strides (x_sn, x_sd, 1) + src_ptr, # (E,) + w_ptr, # (E, D, D) block-diagonal Wigner-D, contiguous + kc_ptr, # (E, KSZ * RANK) compact kernel, or (E, L+1, CW) when RANK == 0 + cb_ptr, # (RANK, CW) channel basis (unread when RANK == 0) + u_ptr, # (F, E, ROW) focus-major output + n_edge, + x_sn, + x_sd, + L: tl.constexpr, + CF: tl.constexpr, + CW: tl.constexpr, # true C_wide; BC = next_power_of_2(CW) lanes with mask + BC: tl.constexpr, + RANK: tl.constexpr, + ): + """One program per edge, channels vectorized. + + Phase 1 rotates the gathered source features over the structural + block-diagonal non-zeros only, holding the ``3 * L + 1`` reduced rows + in registers. Phase 2 applies the low-rank degree mixing + ``K_eff[i, o, c] = sum_r kc[i, o, r] * cb[r, c]`` (for ``RANK == 1`` + the channel basis factors out of the degree contraction and is applied + once per output row) and stores focus-major with channel decode + ``c = f * CF + cf``. ``RANK == 0`` is the mixer-free variant: each + reduced row is multiplied by the radial feature of its degree. + """ + NS0: tl.constexpr = L + 1 + RED: tl.constexpr = 3 * L + 1 + DIM: tl.constexpr = (L + 1) * (L + 1) + ROW: tl.constexpr = RED * CF + + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BC) + cmask = chan < CW + src = tl.load(src_ptr + edge).to(tl.int64) + x_base = x_ptr + src * x_sn + d_base = w_ptr + edge * DIM * DIM + + # === Phase 1. Rotate to the local frame (registers) === + xrows = () + for r in tl.static_range(DIM): + xrows = xrows + ( + tl.load(x_base + r * x_sd + chan, mask=cmask, other=0.0).to(tl.float32), + ) + rows0 = () + rows_m = () + rows_p = () + for l in tl.static_range(L + 1): + base = l * l + r0 = base + l + acc0 = tl.zeros((BC,), dtype=tl.float32) + accm = tl.zeros((BC,), dtype=tl.float32) + accp = tl.zeros((BC,), dtype=tl.float32) + for j in tl.static_range(2 * l + 1): + xv = xrows[l * l + j] + acc0 += tl.load(d_base + r0 * DIM + base + j) * xv + if l >= 1: + accm += tl.load(d_base + (r0 - 1) * DIM + base + j) * xv + accp += tl.load(d_base + (r0 + 1) * DIM + base + j) * xv + rows0 = rows0 + (acc0,) + if l >= 1: + rows_m = rows_m + (accm,) + rows_p = rows_p + (accp,) + xl = rows0 + rows_m + rows_p + + # === Phase 2. Degree mix (or degree-wise multiply), store focus-major === + f_off = (chan // CF).to(tl.int64) * n_edge * ROW + edge * ROW + (chan % CF) + if RANK == 0: + rad_base = kc_ptr + edge * NS0 * CW + for o in tl.static_range(NS0): + rad = tl.load(rad_base + o * CW + chan, mask=cmask, other=0.0).to( + tl.float32 + ) + tl.store(u_ptr + f_off + o * CF, xl[o] * rad, mask=cmask) + for o in tl.static_range(L): + rad = tl.load(rad_base + (o + 1) * CW + chan, mask=cmask, other=0.0).to( + tl.float32 + ) + tl.store(u_ptr + f_off + (NS0 + o) * CF, xl[NS0 + o] * rad, mask=cmask) + tl.store( + u_ptr + f_off + (NS0 + L + o) * CF, + xl[NS0 + L + o] * rad, + mask=cmask, + ) + return + cb = () + for r in tl.static_range(RANK): + cb = cb + ( + tl.load(cb_ptr + r * CW + chan, mask=cmask, other=0.0).to(tl.float32), + ) + kc_base = kc_ptr + edge * (NS0 * NS0 + L * L) * RANK + for o in tl.static_range(NS0): + acc = tl.zeros((BC,), dtype=tl.float32) + for i in tl.static_range(NS0): + if RANK == 1: + acc += tl.load(kc_base + i * NS0 + o) * xl[i] + else: + keff = tl.zeros((BC,), dtype=tl.float32) + for r in tl.static_range(RANK): + keff += tl.load(kc_base + (i * NS0 + o) * RANK + r) * cb[r] + acc += keff * xl[i] + if RANK == 1: + acc = acc * cb[0] + tl.store(u_ptr + f_off + o * CF, acc, mask=cmask) + for o in tl.static_range(L): + accn = tl.zeros((BC,), dtype=tl.float32) + accq = tl.zeros((BC,), dtype=tl.float32) + for i in tl.static_range(L): + if RANK == 1: + k_val = tl.load(kc_base + NS0 * NS0 + i * L + o) + accn += k_val * xl[NS0 + i] + accq += k_val * xl[NS0 + L + i] + else: + keff = tl.zeros((BC,), dtype=tl.float32) + for r in tl.static_range(RANK): + keff += ( + tl.load(kc_base + (NS0 * NS0 + i * L + o) * RANK + r) + * cb[r] + ) + accn += keff * xl[NS0 + i] + accq += keff * xl[NS0 + L + i] + if RANK == 1: + accn = accn * cb[0] + accq = accq * cb[0] + tl.store(u_ptr + f_off + (NS0 + o) * CF, accn, mask=cmask) + tl.store(u_ptr + f_off + (NS0 + L + o) * CF, accq, mask=cmask) + + @triton.jit + def _rotate_mix_bwd_kernel( + gu_ptr, # (F, E, ROW) upstream gradient (focus-major) + x_ptr, + src_ptr, + w_ptr, + kc_ptr, + cb_ptr, + gxe_ptr, # (E, D, CW) per-edge node gradient (segment-summed by the caller) + gw_ptr, # (E, D, D) Wigner gradient (structural non-zeros; pre-zeroed) + gkc_ptr, # gradient of kc, same layout as kc + n_edge, + x_sn, + x_sd, + L: tl.constexpr, + CF: tl.constexpr, + CW: tl.constexpr, + BC: tl.constexpr, + RANK: tl.constexpr, + ): + """Backward of the fused front end (one program per edge). + + The rotated pre-mix rows are recomputed from ``x`` / ``W`` in + registers (the program reads both anyway), so the forward saves no + per-edge intermediate. The node gradient is written densely per edge + and reduced by a segment sum outside: a direct row-atomic scatter + serializes on the colliding edges of each atom. ``RANK == 0``: the + degree-kernel phase becomes the degree-wise product rule on the radial + features. + """ + NS0: tl.constexpr = L + 1 + RED: tl.constexpr = 3 * L + 1 + DIM: tl.constexpr = (L + 1) * (L + 1) + ROW: tl.constexpr = RED * CF + + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BC) + cmask = chan < CW + src = tl.load(src_ptr + edge).to(tl.int64) + cb = () + for r in tl.static_range(RANK): + cb = cb + ( + tl.load(cb_ptr + r * CW + chan, mask=cmask, other=0.0).to(tl.float32), + ) + x_base = x_ptr + src * x_sn + d_base = w_ptr + edge * DIM * DIM + if RANK == 0: + kc_base = kc_ptr + edge * NS0 * CW + gkc_base = gkc_ptr + edge * NS0 * CW + else: + kc_base = kc_ptr + edge * (NS0 * NS0 + L * L) * RANK + gkc_base = gkc_ptr + edge * (NS0 * NS0 + L * L) * RANK + f_off = (chan // CF).to(tl.int64) * n_edge * ROW + edge * ROW + (chan % CF) + + # === Phase 0. Recompute the rotated rows; load the upstream rows === + xrows = () + for r in tl.static_range(DIM): + xrows = xrows + ( + tl.load(x_base + r * x_sd + chan, mask=cmask, other=0.0).to(tl.float32), + ) + rows0 = () + rows_m = () + rows_p = () + for l in tl.static_range(L + 1): + base = l * l + r0 = base + l + acc0 = tl.zeros((BC,), dtype=tl.float32) + accm = tl.zeros((BC,), dtype=tl.float32) + accp = tl.zeros((BC,), dtype=tl.float32) + for j in tl.static_range(2 * l + 1): + xv = xrows[l * l + j] + acc0 += tl.load(d_base + r0 * DIM + base + j) * xv + if l >= 1: + accm += tl.load(d_base + (r0 - 1) * DIM + base + j) * xv + accp += tl.load(d_base + (r0 + 1) * DIM + base + j) * xv + rows0 = rows0 + (acc0,) + if l >= 1: + rows_m = rows_m + (accm,) + rows_p = rows_p + (accp,) + xl = rows0 + rows_m + rows_p + # For RANK == 1 the channel basis is folded into the upstream rows + # once; the generic path applies cb inside the contractions. + gy = () + for r in tl.static_range(RED): + gval = tl.load(gu_ptr + f_off + r * CF, mask=cmask, other=0.0).to( + tl.float32 + ) + if RANK == 1: + gval = gval * cb[0] + gy = gy + (gval,) + + # === Phase 1. Degree-kernel (or radial-feature) gradient === + if RANK == 0: + tl.store(gkc_base + 0 * CW + chan, gy[0] * xl[0], mask=cmask) + for d in tl.static_range(1, NS0): + t = ( + gy[d] * xl[d] + + gy[NS0 + d - 1] * xl[NS0 + d - 1] + + gy[NS0 + L + d - 1] * xl[NS0 + L + d - 1] + ) + tl.store(gkc_base + d * CW + chan, t, mask=cmask) + for i in tl.static_range(NS0 if RANK > 0 else 0): + for o in tl.static_range(NS0): + if RANK == 1: + tl.store(gkc_base + i * NS0 + o, tl.sum(gy[o] * xl[i])) + else: + t = gy[o] * xl[i] + for r in tl.static_range(RANK): + tl.store(gkc_base + (i * NS0 + o) * RANK + r, tl.sum(t * cb[r])) + for i in tl.static_range(L if RANK > 0 else 0): + for o in tl.static_range(L): + if RANK == 1: + tl.store( + gkc_base + NS0 * NS0 + i * L + o, + tl.sum(gy[NS0 + o] * xl[NS0 + i]) + + tl.sum(gy[NS0 + L + o] * xl[NS0 + L + i]), + ) + else: + t = gy[NS0 + o] * xl[NS0 + i] + gy[NS0 + L + o] * xl[NS0 + L + i] + for r in tl.static_range(RANK): + tl.store( + gkc_base + (NS0 * NS0 + i * L + o) * RANK + r, + tl.sum(t * cb[r]), + ) + + # === Phase 2. Rotation backward with g_local formed on the fly === + gd_base = gw_ptr + edge * DIM * DIM + for l in tl.static_range(L + 1): + base = l * l + r0 = base + l + g0 = tl.zeros((BC,), dtype=tl.float32) + if RANK == 0: + rad_l = tl.load(kc_base + l * CW + chan, mask=cmask, other=0.0).to( + tl.float32 + ) + g0 = gy[l] * rad_l + for o in tl.static_range(NS0 if RANK > 0 else 0): + if RANK == 1: + g0 += tl.load(kc_base + l * NS0 + o) * gy[o] + else: + keff = tl.zeros((BC,), dtype=tl.float32) + for r in tl.static_range(RANK): + keff += tl.load(kc_base + (l * NS0 + o) * RANK + r) * cb[r] + g0 += keff * gy[o] + gm = tl.zeros((BC,), dtype=tl.float32) + gp = tl.zeros((BC,), dtype=tl.float32) + if l >= 1: + if RANK == 0: + gm = gy[NS0 + l - 1] * rad_l + gp = gy[NS0 + L + l - 1] * rad_l + for o in tl.static_range(L if RANK > 0 else 0): + if RANK == 1: + k_val = tl.load(kc_base + NS0 * NS0 + (l - 1) * L + o) + gm += k_val * gy[NS0 + o] + gp += k_val * gy[NS0 + L + o] + else: + keff = tl.zeros((BC,), dtype=tl.float32) + for r in tl.static_range(RANK): + keff += ( + tl.load( + kc_base + (NS0 * NS0 + (l - 1) * L + o) * RANK + r + ) + * cb[r] + ) + gm += keff * gy[NS0 + o] + gp += keff * gy[NS0 + L + o] + for j in tl.static_range(2 * l + 1): + col = base + j + xv = xrows[l * l + j] + w0 = tl.load(d_base + r0 * DIM + col) + gx_row = w0 * g0 + tl.store(gd_base + r0 * DIM + col, tl.sum(g0 * xv)) + if l >= 1: + wmv = tl.load(d_base + (r0 - 1) * DIM + col) + wpv = tl.load(d_base + (r0 + 1) * DIM + col) + gx_row += wmv * gm + wpv * gp + tl.store(gd_base + (r0 - 1) * DIM + col, tl.sum(gm * xv)) + tl.store(gd_base + (r0 + 1) * DIM + col, tl.sum(gp * xv)) + tl.store( + gxe_ptr + edge * DIM * CW + col * CW + chan, gx_row, mask=cmask + ) + + @triton.jit + def _rotate_mix_bwd_block_kernel( + gu_ptr, # (F, E, ROW) upstream gradient (focus-major) + x_ptr, # (N, D, CW) node features + src_ptr, # (E,) + w_ptr, # (E, D, D) block-diagonal Wigner-D + kc_ptr, # (E, KSZ) rank-1 compact kernel, or (E, L+1, CW) when RANK == 0 + cb_ptr, # (1, CW) channel basis (RANK == 1) + gxe_ptr, # (E, D, CW) per-edge node gradient out + gw_ptr, # (E, D, D) Wigner gradient out (structural non-zeros; pre-zeroed) + gkc_ptr, # gradient of kc out, same layout as kc + n_edge, + x_sn, + x_sd, + L: tl.constexpr, + CF: tl.constexpr, + CW: tl.constexpr, + CP: tl.constexpr, # next power of two >= CW (vector lane count) + RANK: tl.constexpr, + BLOCK_E: tl.constexpr, + ): + """Edge-block variant of the rotate+mix backward. + + The per-edge kernel closes one cross-lane ``tl.sum`` per ``grad_kc`` + entry and per structural Wigner non-zero -- serialized warp + shuffle-reduction chains that dominate its runtime on narrow hidden + widths. This variant processes ``BLOCK_E`` edges per program with + channels as the vector axis: every reduction becomes one batched + axis-1 reduction of a ``(BLOCK_E, CP)`` tile, and the per-edge Wigner + and kernel scalars are loaded as coalesced ``(BLOCK_E,)`` vectors. + The rotated rows are recomputed in registers, matching the per-edge + kernel's saved-nothing contract. Channels are padded to the + power-of-two lane count ``CP`` with masked lanes (masked lanes issue + no memory traffic; they only raise register pressure, which the + launch table absorbs with a smaller ``BLOCK_E``). + + The schedule wins only where the reduction overhead of the per-edge + kernel dominates; :func:`tile_configs.rotate_mix_bwd_block_config` + acts as the win list, and ``RANK`` must be at most 1 (the per-focus + upstream fold applies a single channel basis). + """ + NS0: tl.constexpr = L + 1 + RED: tl.constexpr = 3 * L + 1 + DIM: tl.constexpr = (L + 1) * (L + 1) + ROW: tl.constexpr = RED * CF + KSZ: tl.constexpr = NS0 * NS0 + L * L + PADDED: tl.constexpr = CP != CW + + pid = tl.program_id(0) + offs_e = (pid * BLOCK_E + tl.arange(0, BLOCK_E)).to(tl.int64) + e_mask = offs_e < n_edge + eq = tl.where(e_mask, offs_e, 0) + chan = tl.arange(0, CP) + if PADDED: + c_mask = chan < CW + em = e_mask[:, None] & c_mask[None, :] + chan_c = tl.where(c_mask, chan, 0) + else: + em = e_mask[:, None] + chan_c = chan + + src = tl.load(src_ptr + eq, mask=e_mask, other=0).to(tl.int64) + x_base = x_ptr + (src * x_sn)[:, None] + d_base = w_ptr + eq * DIM * DIM + gd_base = gw_ptr + eq * DIM * DIM + gxe_base = gxe_ptr + (eq * DIM * CW)[:, None] + if RANK == 0: + kc_base = kc_ptr + (eq * NS0 * CW)[:, None] + gkc_base = gkc_ptr + (eq * NS0 * CW)[:, None] + else: + kc_base = kc_ptr + eq * KSZ + gkc_base = gkc_ptr + eq * KSZ + + # Focus-major upstream offset of channel c = f * CF + cf. + f_off = ( + gu_ptr + + ((chan_c // CF).to(tl.int64) * n_edge * ROW + (chan_c % CF))[None, :] + + (eq * ROW)[:, None] + ) + + # === Phase 0. Upstream rows (channel basis folded once, RANK == 1) === + gy = () + if RANK == 1: + cbv = tl.load(cb_ptr + chan, mask=(chan < CW), other=0.0)[None, :] + for r in tl.static_range(RED): + gval = tl.load(f_off + r * CF, mask=em, other=0.0) + if RANK == 1: + gval = gval * cbv + gy = gy + (gval,) + + # === Phase 1. Per degree: recompute rotation, kernel grads, gx, gD === + for l in tl.static_range(L + 1): + base = l * l + r0 = base + l + + xrows = () + for j in tl.static_range(2 * l + 1): + xrows = xrows + ( + tl.load( + x_base + (base + j) * x_sd + chan_c[None, :], + mask=em, + other=0.0, + ), + ) + xl0 = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + xlm = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + xlp = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + for j in tl.static_range(2 * l + 1): + xv = xrows[j] + w0 = tl.load(d_base + r0 * DIM + base + j, mask=e_mask, other=0.0) + xl0 += w0[:, None] * xv + if l >= 1: + wm = tl.load( + d_base + (r0 - 1) * DIM + base + j, mask=e_mask, other=0.0 + ) + wp = tl.load( + d_base + (r0 + 1) * DIM + base + j, mask=e_mask, other=0.0 + ) + xlm += wm[:, None] * xv + xlp += wp[:, None] * xv + + # Kernel gradient rows of input degree l. + if RANK == 0: + if l == 0: + t = gy[0] * xl0 + else: + t = gy[l] * xl0 + gy[NS0 + l - 1] * xlm + gy[NS0 + L + l - 1] * xlp + tl.store(gkc_base + l * CW + chan[None, :], t, mask=em) + else: + for o in tl.static_range(NS0): + tl.store( + gkc_base + l * NS0 + o, + tl.sum(gy[o] * xl0, axis=1), + mask=e_mask, + ) + if l >= 1: + for o in tl.static_range(L): + tl.store( + gkc_base + NS0 * NS0 + (l - 1) * L + o, + tl.sum(gy[NS0 + o] * xlm + gy[NS0 + L + o] * xlp, axis=1), + mask=e_mask, + ) + + # Local-frame gradients of the reduced rows of degree l. + g0 = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + gm = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + gp = tl.zeros((BLOCK_E, CP), dtype=tl.float32) + if RANK == 0: + rad_l = tl.load(kc_base + l * CW + chan[None, :], mask=em, other=0.0) + g0 = gy[l] * rad_l + if l >= 1: + gm = gy[NS0 + l - 1] * rad_l + gp = gy[NS0 + L + l - 1] * rad_l + else: + for o in tl.static_range(NS0): + k_val = tl.load(kc_base + l * NS0 + o, mask=e_mask, other=0.0) + g0 += k_val[:, None] * gy[o] + if l >= 1: + for o in tl.static_range(L): + k_val = tl.load( + kc_base + NS0 * NS0 + (l - 1) * L + o, + mask=e_mask, + other=0.0, + ) + gm += k_val[:, None] * gy[NS0 + o] + gp += k_val[:, None] * gy[NS0 + L + o] + + # Rotation backward: node gradient rows and Wigner gradients. + for j in tl.static_range(2 * l + 1): + col = base + j + xv = xrows[j] + w0 = tl.load(d_base + r0 * DIM + col, mask=e_mask, other=0.0) + gx_row = w0[:, None] * g0 + tl.store(gd_base + r0 * DIM + col, tl.sum(g0 * xv, axis=1), mask=e_mask) + if l >= 1: + wm = tl.load(d_base + (r0 - 1) * DIM + col, mask=e_mask, other=0.0) + wp = tl.load(d_base + (r0 + 1) * DIM + col, mask=e_mask, other=0.0) + gx_row += wm[:, None] * gm + wp[:, None] * gp + tl.store( + gd_base + (r0 - 1) * DIM + col, + tl.sum(gm * xv, axis=1), + mask=e_mask, + ) + tl.store( + gd_base + (r0 + 1) * DIM + col, + tl.sum(gp * xv, axis=1), + mask=e_mask, + ) + tl.store(gxe_base + col * CW + chan[None, :], gx_row, mask=em) + + @triton.jit + def _segment_sum_kernel( + rows_ptr, # (E, P) per-edge rows + order_ptr, # (E,) edge ids sorted by segment key + row_ptr_ptr, # (N + 1,) CSR offsets into ``order`` + out_ptr, # (N, P) + P: tl.constexpr, + BC: tl.constexpr, + ): + """Indirect CSR segment sum: ``out[n] = sum_{i in seg(n)} rows[order[i]]``. + + Replaces the row-atomic scatter / ``index_add_`` of the edge-to-node + reduction; the contention-free segmented read is several times faster + than atomics at typical per-atom edge counts. + """ + node = tl.program_id(0).to(tl.int64) + chunk = tl.program_id(1) + cols = chunk * BC + tl.arange(0, BC) + col_mask = cols < P + beg = tl.load(row_ptr_ptr + node).to(tl.int64) + end = tl.load(row_ptr_ptr + node + 1).to(tl.int64) + acc = tl.zeros((BC,), dtype=tl.float32) + for i in range(beg, end): + e = tl.load(order_ptr + i).to(tl.int64) + acc += tl.load(rows_ptr + e * P + cols, mask=col_mask, other=0.0) + tl.store(out_ptr + node * P + cols, acc, mask=col_mask) + + @triton.jit + def _stack_gemm_m0_kernel( + u_ptr, # (F, E, ROW) layer input + w0_ptr, # (NL, F, M0, M0) stacked weights, layer selected by ``layer`` + alpha_ptr, # (E, F) competition weight (identity epilogue only) + v_ptr, # z_all stack (gated) or the final output (identity epilogue) + n_edge, + layer, + L: tl.constexpr, + CF: tl.constexpr, + EPILOGUE: tl.constexpr, # 0: store raw z; 1: residual (+ alpha) output + V_EDGE_MAJOR: tl.constexpr, # v is (E, F, ROW); else focus-major (F, E, ROW) + APPLY_ALPHA: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + """``m = 0`` block GEMM ``z = u[:, :M0] @ W0`` with an optional epilogue. + + Output strides are derived in-kernel from the layout flag on int64 + offsets: a host-side ``n_edge * ROW`` scalar argument would be + specialized to int32 by the first (small) compilation and overflow + on systems beyond ~2^31 / ROW edges. + """ + M0: tl.constexpr = (L + 1) * CF + ROW: tl.constexpr = (3 * L + 1) * CF + NT: tl.constexpr = (M0 + BLOCK_N - 1) // BLOCK_N + + pid = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + pid_m = pid // NT + pid_n = pid % NT + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + u_row = u_ptr + fid * n_edge * ROW + offs_m * ROW + offs_k = tl.arange(0, BLOCK_K) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < M0 + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + a_ptrs = u_row[:, None] + offs_k[None, :] + w_ptrs = ( + w0_ptr + + (layer * n_focus + fid) * M0 * M0 + + offs_k[:, None] * M0 + + offs_n[None, :] + ) + for _ in range(0, M0, BLOCK_K): + a = tl.load(a_ptrs, mask=mm, other=0.0) + w = tl.load(w_ptrs, mask=n_mask[None, :], other=0.0) + acc = tl.dot(a, w, acc, input_precision="ieee") + a_ptrs += BLOCK_K + w_ptrs += BLOCK_K * M0 + + if EPILOGUE == 1: + u_t = tl.load( + u_row[:, None] + offs_n[None, :], mask=mm & n_mask[None, :], other=0.0 + ) + acc = acc + u_t + if APPLY_ALPHA: + alpha = tl.load( + alpha_ptr + offs_m * n_focus + fid, mask=m_mask, other=0.0 + ) + acc = acc * alpha[:, None] + if V_EDGE_MAJOR: + v_row = v_ptr + fid * ROW + offs_m * (n_focus * ROW) + else: + v_row = v_ptr + fid * n_edge * ROW + offs_m * ROW + tl.store(v_row[:, None] + offs_n[None, :], acc, mask=mm & n_mask[None, :]) + else: + z_row = v_ptr + (layer * n_focus + fid) * n_edge * ROW + offs_m * ROW + tl.store(z_row[:, None] + offs_n[None, :], acc, mask=mm & n_mask[None, :]) + + @triton.jit + def _stack_gate_kernel( + u_ptr, + z_ptr, # z_all stack, layer selected by ``layer`` + gw_ptr, # (NL, F, CF, L*CF) stacked gate projections + v_ptr, # (F, E, ROW) layer output, focus-major + sig_ptr, # (F, E, L*CF); output when SIG_IN == 0, input when SIG_IN == 1 + n_edge, + layer, + L: tl.constexpr, + CF: tl.constexpr, + SIG_IN: tl.constexpr, + BLOCK_M: tl.constexpr, + ): + """Gate evaluation and ``m = 0`` finish: ``v = u + act(z)`` on the m0 rows. + + Register tiles are ``CP`` wide (``CF`` padded to a power of two) with a + column mask, so non-power-of-two focus widths are supported; padded dot + lanes carry zeros and are never stored. With ``SIG_IN`` the sigmoid + projection has already been produced by a cuBLAS bmm (wide-channel + regime) and this kernel only reads it. + """ + ROW: tl.constexpr = (3 * L + 1) * CF + LG: tl.constexpr = L * CF + CP: tl.constexpr = triton.next_power_of_2(CF) + + pid_m = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + nc = tl.arange(0, CP) + cm = mm & (nc < CF)[None, :] + wm = ((nc < CF)[:, None]) & ((nc < CF)[None, :]) + + u_row = u_ptr + fid * n_edge * ROW + offs_m * ROW + z_row = z_ptr + (layer * n_focus + fid) * n_edge * ROW + offs_m * ROW + v_row = v_ptr + fid * n_edge * ROW + offs_m * ROW + sig_row = sig_ptr + (fid * n_edge + offs_m) * LG + + # l = 0 scalar rows pass through silu. + z_s = tl.load(z_row[:, None] + nc[None, :], mask=cm, other=0.0) + u_s = tl.load(u_row[:, None] + nc[None, :], mask=cm, other=0.0) + tl.store(v_row[:, None] + nc[None, :], u_s + z_s * tl.sigmoid(z_s), mask=cm) + + # Per-group sigmoid gates and the gated m = 0 rows. + for g in tl.static_range(L): + if SIG_IN: + sig_g = tl.load( + sig_row[:, None] + (g * CF + nc)[None, :], mask=cm, other=0.0 + ) + else: + gw_g = tl.load( + gw_ptr + + (layer * n_focus + fid) * CF * LG + + nc[:, None] * LG + + (g * CF + nc)[None, :], + mask=wm, + other=0.0, + ) + sig_g = tl.sigmoid(tl.dot(z_s, gw_g, input_precision="ieee")) + tl.store(sig_row[:, None] + (g * CF + nc)[None, :], sig_g, mask=cm) + z_g = tl.load( + z_row[:, None] + ((1 + g) * CF + nc)[None, :], mask=cm, other=0.0 + ) + u_g = tl.load( + u_row[:, None] + ((1 + g) * CF + nc)[None, :], mask=cm, other=0.0 + ) + tl.store( + v_row[:, None] + ((1 + g) * CF + nc)[None, :], + u_g + z_g * sig_g, + mask=cm, + ) + + @triton.jit + def _stack_gemm_m1_kernel( + u_ptr, + w1_ptr, # (NL, F, M1, M1) stacked weights, layer selected by ``layer`` + sig_ptr, + alpha_ptr, + v_ptr, + z_ptr, # z_all stack, layer selected by ``layer`` + n_edge, + layer, + L: tl.constexpr, + CF: tl.constexpr, + HAS_GATE: tl.constexpr, + V_EDGE_MAJOR: tl.constexpr, # v is (E, F, ROW); else focus-major (F, E, ROW) + APPLY_ALPHA: tl.constexpr, + SAVE_Z: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + """``|m| = 1`` block GEMM with the gate / residual / alpha epilogue fused.""" + M0: tl.constexpr = (L + 1) * CF + M1: tl.constexpr = 2 * L * CF + ROW: tl.constexpr = (3 * L + 1) * CF + LG: tl.constexpr = L * CF + NT: tl.constexpr = (M1 + BLOCK_N - 1) // BLOCK_N + + pid = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + pid_m = pid // NT + pid_n = pid % NT + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + u_row = u_ptr + fid * n_edge * ROW + offs_m * ROW + offs_k = tl.arange(0, BLOCK_K) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < M1 + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + a_ptrs = u_row[:, None] + (M0 + offs_k)[None, :] + w_ptrs = ( + w1_ptr + + (layer * n_focus + fid) * M1 * M1 + + offs_k[:, None] * M1 + + offs_n[None, :] + ) + for _ in range(0, M1, BLOCK_K): + a = tl.load(a_ptrs, mask=mm, other=0.0) + w = tl.load(w_ptrs, mask=n_mask[None, :], other=0.0) + acc = tl.dot(a, w, acc, input_precision="ieee") + a_ptrs += BLOCK_K + w_ptrs += BLOCK_K * M1 + + if SAVE_Z: + z_row = z_ptr + (layer * n_focus + fid) * n_edge * ROW + offs_m * ROW + tl.store( + z_row[:, None] + (M0 + offs_n)[None, :], acc, mask=mm & n_mask[None, :] + ) + if HAS_GATE: + # Both |m| = 1 stripes of degree group g share gate group g. + sig_cols = ((offs_n // CF) % L) * CF + (offs_n % CF) + sig = tl.load( + sig_ptr + (fid * n_edge + offs_m)[:, None] * LG + sig_cols[None, :], + mask=mm & n_mask[None, :], + other=0.0, + ) + acc = acc * sig + u_t = tl.load( + u_row[:, None] + (M0 + offs_n)[None, :], + mask=mm & n_mask[None, :], + other=0.0, + ) + acc = acc + u_t + if APPLY_ALPHA: + alpha = tl.load(alpha_ptr + offs_m * n_focus + fid, mask=m_mask, other=0.0) + acc = acc * alpha[:, None] + if V_EDGE_MAJOR: + v_row = v_ptr + fid * ROW + offs_m * (n_focus * ROW) + else: + v_row = v_ptr + fid * n_edge * ROW + offs_m * ROW + tl.store( + v_row[:, None] + (M0 + offs_n)[None, :], acc, mask=mm & n_mask[None, :] + ) + + @triton.jit + def _stack_recompute_kernel( + z_ptr, # z_all stack (NL, F, E, ROW), layer selected by ``layer`` + gw_ptr, # (NL, F, CF, L*CF) stacked gate projections + sig_ptr, # (F, E, L*CF) output + n_edge, + layer, + L: tl.constexpr, + CF: tl.constexpr, + BLOCK_M: tl.constexpr, + ): + """Recompute the gate sigmoids from the saved pre-activation (backward).""" + ROW: tl.constexpr = (3 * L + 1) * CF + LG: tl.constexpr = L * CF + CP: tl.constexpr = triton.next_power_of_2(CF) + + pid_m = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + nc = tl.arange(0, CP) + cm = mm & (nc < CF)[None, :] + wm = ((nc < CF)[:, None]) & ((nc < CF)[None, :]) + + z_row = z_ptr + (layer * n_focus + fid) * n_edge * ROW + offs_m * ROW + z_s = tl.load(z_row[:, None] + nc[None, :], mask=cm, other=0.0) + sig_row = sig_ptr + (fid * n_edge + offs_m) * LG + for g in tl.static_range(L): + gw_g = tl.load( + gw_ptr + + (layer * n_focus + fid) * CF * LG + + nc[:, None] * LG + + (g * CF + nc)[None, :], + mask=wm, + other=0.0, + ) + sig_g = tl.sigmoid(tl.dot(z_s, gw_g, input_precision="ieee")) + tl.store(sig_row[:, None] + (g * CF + nc)[None, :], sig_g, mask=cm) + + @triton.jit + def _stack_point_bwd_kernel( + g_ptr, # (F, E, ROW) upstream gradient of the layer output + z_ptr, # z_all stack, layer selected by ``layer`` + sig_ptr, # (F, E, L*CF) gate sigmoids + gwt_ptr, # (NL, F, L*CF, CF) transposed gate projections + gz_ptr, # (F, E, ROW) pre-activation gradient output + gl_ptr, # (F, E, L*CF) gate-logit gradient output (GLOGIT_OUT only) + n_edge, + layer, + L: tl.constexpr, + CF: tl.constexpr, + GLOGIT_OUT: tl.constexpr, + BLOCK_M: tl.constexpr, + ): + """Pointwise part of the gated-layer backward. + + Produces the pre-activation gradient ``gz`` for the value rows and the + gate-path contribution to the ``l = 0`` scalar rows. The gate-logit + contraction back to the scalars is either folded in as a ``CP x CP`` + register dot (small ``CF``) or emitted to ``gl`` for an external + batched GEMM (wide-channel regime, where the register dot spills). + """ + ROW: tl.constexpr = (3 * L + 1) * CF + LG: tl.constexpr = L * CF + CP: tl.constexpr = triton.next_power_of_2(CF) + + pid_m = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + nc = tl.arange(0, CP) + cm = mm & (nc < CF)[None, :] + wm = ((nc < CF)[:, None]) & ((nc < CF)[None, :]) + + g_row = g_ptr + fid * n_edge * ROW + offs_m * ROW + z_row = z_ptr + (layer * n_focus + fid) * n_edge * ROW + offs_m * ROW + gz_row = gz_ptr + fid * n_edge * ROW + offs_m * ROW + sig_row = sig_ptr + (fid * n_edge + offs_m) * LG + gl_row = gl_ptr + (fid * n_edge + offs_m) * LG + + # l = 0 value path: silu backward. + z_s = tl.load(z_row[:, None] + nc[None, :], mask=cm, other=0.0) + g_s = tl.load(g_row[:, None] + nc[None, :], mask=cm, other=0.0) + s0 = tl.sigmoid(z_s) + gz_s = g_s * s0 * (1.0 + z_s * (1.0 - s0)) + + for g in tl.static_range(L): + sig_g = tl.load( + sig_row[:, None] + (g * CF + nc)[None, :], mask=cm, other=0.0 + ) + gr0 = tl.load( + g_row[:, None] + ((1 + g) * CF + nc)[None, :], mask=cm, other=0.0 + ) + zr0 = tl.load( + z_row[:, None] + ((1 + g) * CF + nc)[None, :], mask=cm, other=0.0 + ) + tl.store( + gz_row[:, None] + ((1 + g) * CF + nc)[None, :], gr0 * sig_g, mask=cm + ) + rn = (L + 1) + g + grn = tl.load(g_row[:, None] + (rn * CF + nc)[None, :], mask=cm, other=0.0) + zrn = tl.load(z_row[:, None] + (rn * CF + nc)[None, :], mask=cm, other=0.0) + tl.store(gz_row[:, None] + (rn * CF + nc)[None, :], grn * sig_g, mask=cm) + rp = (2 * L + 1) + g + grp = tl.load(g_row[:, None] + (rp * CF + nc)[None, :], mask=cm, other=0.0) + zrp = tl.load(z_row[:, None] + (rp * CF + nc)[None, :], mask=cm, other=0.0) + tl.store(gz_row[:, None] + (rp * CF + nc)[None, :], grp * sig_g, mask=cm) + # Gate path: three value rows share gate group g. + g_sig = gr0 * zr0 + grn * zrn + grp * zrp + g_logit = g_sig * sig_g * (1.0 - sig_g) + if GLOGIT_OUT: + tl.store(gl_row[:, None] + (g * CF + nc)[None, :], g_logit, mask=cm) + else: + gwt_g = tl.load( + gwt_ptr + + (layer * n_focus + fid) * LG * CF + + (g * CF + nc)[:, None] * CF + + nc[None, :], + mask=wm, + other=0.0, + ) + gz_s = tl.dot(g_logit, gwt_g, gz_s, input_precision="ieee") + + tl.store(gz_row[:, None] + nc[None, :], gz_s, mask=cm) + + @triton.jit + def _stack_gemm_bwd_kernel( + gz_ptr, # (F, E, ROW), or the raw upstream gradient when FOLD_ALPHA + res_ptr, # (F, E, ROW) residual gradient source; unread if FOLD_ALPHA + w0t_ptr, # (NL, F, M0, M0) stacked transposed weights + w1t_ptr, # (NL, F, M1, M1) stacked transposed weights + alpha_ptr, + gu_ptr, # (F, E, ROW) layer-input gradient + n_edge, + layer, + L: tl.constexpr, + CF: tl.constexpr, + G_EDGE_MAJOR: tl.constexpr, # gz is (E, F, ROW); else focus-major + FOLD_ALPHA: tl.constexpr, # gz = g * alpha on the fly; residual == gz + RES_IS_GZ: tl.constexpr, # residual equals gz (final layer, no alpha) + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + """Backward block GEMM ``g_u = residual + gz @ W^T`` over both blocks. + + The upstream gradient is edge-major only on the final layer, where + the residual aliases it (``RES_IS_GZ`` or ``FOLD_ALPHA``); an + explicit residual pointer is always focus-major. Strides are + derived in-kernel on int64 offsets (see ``_stack_gemm_m0_kernel``). + """ + M0: tl.constexpr = (L + 1) * CF + M1: tl.constexpr = 2 * L * CF + ROW: tl.constexpr = (3 * L + 1) * CF + NT0: tl.constexpr = (M0 + BLOCK_N - 1) // BLOCK_N + NT1: tl.constexpr = (M1 + BLOCK_N - 1) // BLOCK_N + NT: tl.constexpr = NT0 + NT1 + + pid = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + pid_m = pid // NT + pid_n = pid % NT + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + offs_k = tl.arange(0, BLOCK_K) + + if G_EDGE_MAJOR: + gz_row = gz_ptr + fid * ROW + offs_m * (n_focus * ROW) + else: + gz_row = gz_ptr + fid * n_edge * ROW + offs_m * ROW + gu_row = gu_ptr + fid * n_edge * ROW + offs_m * ROW + if FOLD_ALPHA: + alpha = tl.load(alpha_ptr + offs_m * n_focus + fid, mask=m_mask, other=0.0) + + if pid_n < NT0: + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < M0 + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + a_ptrs = gz_row[:, None] + offs_k[None, :] + w_ptrs = ( + w0t_ptr + + (layer * n_focus + fid) * M0 * M0 + + offs_k[:, None] * M0 + + offs_n[None, :] + ) + for _ in range(0, M0, BLOCK_K): + a = tl.load(a_ptrs, mask=mm, other=0.0) + if FOLD_ALPHA: + a = a * alpha[:, None] + w = tl.load(w_ptrs, mask=n_mask[None, :], other=0.0) + acc = tl.dot(a, w, acc, input_precision="ieee") + a_ptrs += BLOCK_K + w_ptrs += BLOCK_K * M0 + col0 = offs_n + col_mask = n_mask + else: + offs_n = (pid_n - NT0) * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < M1 + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + a_ptrs = gz_row[:, None] + (M0 + offs_k)[None, :] + w_ptrs = ( + w1t_ptr + + (layer * n_focus + fid) * M1 * M1 + + offs_k[:, None] * M1 + + offs_n[None, :] + ) + for _ in range(0, M1, BLOCK_K): + a = tl.load(a_ptrs, mask=mm, other=0.0) + if FOLD_ALPHA: + a = a * alpha[:, None] + w = tl.load(w_ptrs, mask=n_mask[None, :], other=0.0) + acc = tl.dot(a, w, acc, input_precision="ieee") + a_ptrs += BLOCK_K + w_ptrs += BLOCK_K * M1 + col0 = M0 + offs_n + col_mask = n_mask + + if FOLD_ALPHA: + res = tl.load( + gz_row[:, None] + col0[None, :], mask=mm & col_mask[None, :], other=0.0 + ) + res = res * alpha[:, None] + elif RES_IS_GZ: + res = tl.load( + gz_row[:, None] + col0[None, :], mask=mm & col_mask[None, :], other=0.0 + ) + else: + res_row = res_ptr + fid * n_edge * ROW + offs_m * ROW + res = tl.load( + res_row[:, None] + col0[None, :], mask=mm & col_mask[None, :], other=0.0 + ) + tl.store( + gu_row[:, None] + col0[None, :], acc + res, mask=mm & col_mask[None, :] + ) + + @triton.jit + def _stack_grad_alpha_kernel( + g_ptr, # (E, F, ROW) edge-major upstream gradient + out_ptr, # (E, F, ROW) forward output + alpha_ptr, # (E, F) + ga_ptr, # (E, F) + n_edge, + L: tl.constexpr, + CF: tl.constexpr, + BLOCK_M: tl.constexpr, + ): + """Competition-weight gradient from the identity ``grad_alpha = + sum(grad * out) / alpha`` -- exact because the final store is a plain + scale, saving the two pre-scale activation copies. + """ + ROW: tl.constexpr = (3 * L + 1) * CF + CP: tl.constexpr = triton.next_power_of_2(CF) + + pid_m = tl.program_id(0) + fid = tl.program_id(1).to(tl.int64) + n_focus = tl.num_programs(1) + + offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + m_mask = offs_m < n_edge + mm = m_mask[:, None] + nc = tl.arange(0, CP) + cm = mm & (nc < CF)[None, :] + + g_row = g_ptr + (offs_m * n_focus + fid) * ROW + o_row = out_ptr + (offs_m * n_focus + fid) * ROW + ga = tl.zeros((BLOCK_M,), dtype=tl.float32) + for r in tl.static_range(3 * L + 1): + g_r = tl.load(g_row[:, None] + (r * CF + nc)[None, :], mask=cm, other=0.0) + o_r = tl.load(o_row[:, None] + (r * CF + nc)[None, :], mask=cm, other=0.0) + ga += tl.sum(g_r * o_r, axis=1) + alpha = tl.load(alpha_ptr + offs_m * n_focus + fid, mask=m_mask, other=1.0) + tl.store( + ga_ptr + offs_m * n_focus + fid, + ga / tl.maximum(alpha, 1e-12), + mask=m_mask, + ) + + +# ====================================================================== +# Zero-edge guard and dispatch predicate +# ====================================================================== +def _has_no_edges(n_edge) -> bool: + """True only for eager zero-edge calls; never guards symbolic edge counts.""" + return type(n_edge) is int and n_edge == 0 + + +def _use_triton(tensor: Tensor) -> bool: + return ( + SO2_VALUE_PATH_TRITON_AVAILABLE + and tensor.is_cuda + and tensor.dtype is torch.float32 + ) + + +# ====================================================================== +# Operator implementations (Triton on CUDA fp32, eager reference otherwise) +# ====================================================================== +def _rotate_mix_impl( + x: Tensor, + src: Tensor, + wigner: Tensor, + kc: Tensor, + cb: Tensor, + lmax: int, + n_focus: int, + rank: int, +) -> Tensor: + if not _use_triton(x): + return _rotate_mix_reference(x, src, wigner, kc, cb, lmax, n_focus, rank) + n_edge = src.shape[0] + c_wide = int(x.shape[2]) + focus_dim = c_wide // int(n_focus) + row = (3 * int(lmax) + 1) * focus_dim + u = torch.empty(n_focus, n_edge, row, device=x.device, dtype=x.dtype) + if _has_no_edges(n_edge): + return u + warps, stages = rotate_mix_fwd_config(c_wide, int(lmax)) + wrap_triton(_rotate_mix_fwd_kernel)[(n_edge,)]( + x, + src, + wigner, + kc, + cb, + u, + n_edge, + x.stride(0), + x.stride(1), + L=int(lmax), + CF=focus_dim, + CW=c_wide, + BC=triton.next_power_of_2(c_wide), + RANK=int(rank), + num_warps=warps, + num_stages=stages, + ) + return u + + +def _rotate_mix_bwd_impl( + grad_u: Tensor, + x: Tensor, + src: Tensor, + wigner: Tensor, + kc: Tensor, + cb: Tensor, + lmax: int, + n_focus: int, + rank: int, +) -> tuple[Tensor, Tensor, Tensor]: + if not _use_triton(x): + return _rotate_mix_backward_reference( + grad_u, x, src, wigner, kc, cb, lmax, n_focus, rank + ) + n_edge = src.shape[0] + c_wide = int(x.shape[2]) + dim = (int(lmax) + 1) ** 2 + grad_x_edge = torch.empty(n_edge, dim, c_wide, device=x.device, dtype=x.dtype) + grad_wigner = torch.zeros_like(wigner) + grad_kc = torch.empty_like(kc) + if _has_no_edges(n_edge): + return grad_x_edge, grad_wigner, grad_kc + # The edge-block schedule engages on swept-and-winning (C_wide, lmax) + # keys (RANK <= 1 -- the block kernel folds a single channel basis); + # every other shape keeps the per-edge kernel. The branch resolves at + # trace time, so exactly one kernel reaches the compiled graph. + block_cfg = ( + rotate_mix_bwd_block_config(c_wide, int(lmax)) if int(rank) <= 1 else None + ) + if block_cfg is not None: + block_e, warps, stages = block_cfg + wrap_triton(_rotate_mix_bwd_block_kernel)[(triton.cdiv(n_edge, block_e),)]( + grad_u, + x, + src, + wigner, + kc, + cb, + grad_x_edge, + grad_wigner, + grad_kc, + n_edge, + x.stride(0), + x.stride(1), + L=int(lmax), + CF=c_wide // int(n_focus), + CW=c_wide, + CP=triton.next_power_of_2(c_wide), + RANK=int(rank), + BLOCK_E=block_e, + num_warps=warps, + num_stages=stages, + ) + return grad_x_edge, grad_wigner, grad_kc + warps, stages = _ROTATE_MIX_BWD_CONFIG + wrap_triton(_rotate_mix_bwd_kernel)[(n_edge,)]( + grad_u, + x, + src, + wigner, + kc, + cb, + grad_x_edge, + grad_wigner, + grad_kc, + n_edge, + x.stride(0), + x.stride(1), + L=int(lmax), + CF=c_wide // int(n_focus), + CW=c_wide, + BC=triton.next_power_of_2(c_wide), + RANK=int(rank), + num_warps=warps, + num_stages=stages, + ) + return grad_x_edge, grad_wigner, grad_kc + + +def _segment_sum_impl(rows: Tensor, order: Tensor, row_ptr: Tensor) -> Tensor: + n_rows = rows.shape[0] + n_seg = row_ptr.shape[0] - 1 + if not _use_triton(rows): + counts = row_ptr[1:] - row_ptr[:-1] + seg_of_sorted = torch.repeat_interleave( + torch.arange(n_seg, device=rows.device, dtype=order.dtype), counts + ) + out = rows.new_zeros((n_seg, rows.shape[1], rows.shape[2])) + out.index_add_(0, seg_of_sorted, rows.index_select(0, order)) + return out + out = torch.empty( + (n_seg, rows.shape[1], rows.shape[2]), device=rows.device, dtype=rows.dtype + ) + if _has_no_edges(n_rows): + return out.zero_() + per_row = int(rows.shape[1]) * int(rows.shape[2]) + block = 256 + wrap_triton(_segment_sum_kernel)[(n_seg, triton.cdiv(per_row, block))]( + rows, + order, + row_ptr, + out, + P=per_row, + BC=block, + num_warps=4, + num_stages=2, + ) + return out + + +def _mixing_stack_impl( + u0: Tensor, + alpha: Tensor, + w0_all: Tensor, + w1_all: Tensor, + gw_all: Tensor, + lmax: int, + focus_dim: int, + apply_alpha: bool, +) -> tuple[Tensor, Tensor]: + if not _use_triton(u0): + return _mixing_stack_reference( + u0, alpha, w0_all, w1_all, gw_all, lmax, focus_dim, apply_alpha + ) + n_focus, n_edge, row = u0.shape + lmax = int(lmax) + focus_dim = int(focus_dim) + n_gated = gw_all.shape[0] + z_all = torch.empty( + (n_gated, n_focus, n_edge, row), device=u0.device, dtype=u0.dtype + ) + x_local = torch.empty((n_edge, n_focus, row), device=u0.device, dtype=u0.dtype) + if _has_no_edges(n_edge): + return x_local, z_all + + block_m, block_n, block_k, warps, stages = _GEMM_CONFIG + m0 = (lmax + 1) * focus_dim + m1 = 2 * lmax * focus_dim + gate_bm, gate_w, gate_s = gate_config(focus_dim, lmax) + sig_by_bmm = focus_dim >= GATE_BMM_MIN_FOCUS_DIM + sig = torch.empty( + (n_focus, n_edge, lmax * focus_dim), device=u0.device, dtype=torch.float32 + ) + + u = u0 + for layer in range(n_gated): + out = torch.empty_like(u) + wrap_triton(_stack_gemm_m0_kernel)[ + (triton.cdiv(n_edge, block_m) * triton.cdiv(m0, block_n), n_focus) + ]( + u, + w0_all, + u, + z_all, + n_edge, + layer, + L=lmax, + CF=focus_dim, + EPILOGUE=0, + V_EDGE_MAJOR=False, + APPLY_ALPHA=False, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=warps, + num_stages=stages, + ) + if sig_by_bmm: + # Wide-channel regime: sigmoid projection as a cuBLAS bmm on the + # freshly written l = 0 scalar rows of the pre-activation. + torch.sigmoid( + torch.bmm(z_all[layer, :, :, :focus_dim], gw_all[layer]), out=sig + ) + wrap_triton(_stack_gate_kernel)[(triton.cdiv(n_edge, gate_bm), n_focus)]( + u, + z_all, + gw_all, + out, + sig, + n_edge, + layer, + L=lmax, + CF=focus_dim, + SIG_IN=sig_by_bmm, + BLOCK_M=gate_bm, + num_warps=gate_w, + num_stages=gate_s, + ) + wrap_triton(_stack_gemm_m1_kernel)[ + (triton.cdiv(n_edge, block_m) * triton.cdiv(m1, block_n), n_focus) + ]( + u, + w1_all, + sig, + u, + out, + z_all, + n_edge, + layer, + L=lmax, + CF=focus_dim, + HAS_GATE=True, + V_EDGE_MAJOR=False, + APPLY_ALPHA=False, + SAVE_Z=True, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=warps, + num_stages=stages, + ) + u = out + + # Final identity layer streams straight into the edge-major output layout. + wrap_triton(_stack_gemm_m0_kernel)[ + (triton.cdiv(n_edge, block_m) * triton.cdiv(m0, block_n), n_focus) + ]( + u, + w0_all, + alpha, + x_local, + n_edge, + n_gated, + L=lmax, + CF=focus_dim, + EPILOGUE=1, + V_EDGE_MAJOR=True, + APPLY_ALPHA=apply_alpha, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=warps, + num_stages=stages, + ) + wrap_triton(_stack_gemm_m1_kernel)[ + (triton.cdiv(n_edge, block_m) * triton.cdiv(m1, block_n), n_focus) + ]( + u, + w1_all, + u, + alpha, + x_local, + u, + n_edge, + n_gated, + L=lmax, + CF=focus_dim, + HAS_GATE=False, + V_EDGE_MAJOR=True, + APPLY_ALPHA=apply_alpha, + SAVE_Z=False, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=warps, + num_stages=stages, + ) + return x_local, z_all + + +def _mixing_stack_bwd_impl( + grad_out: Tensor, + x_local: Tensor, + z_all: Tensor, + alpha: Tensor, + w0t_all: Tensor, + w1t_all: Tensor, + gw_all: Tensor, + gwt_all: Tensor, + lmax: int, + focus_dim: int, + apply_alpha: bool, +) -> tuple[Tensor, Tensor]: + if not _use_triton(grad_out): + return _mixing_stack_backward_reference( + grad_out, + x_local, + z_all, + alpha, + w0t_all, + w1t_all, + gw_all, + gwt_all, + lmax, + focus_dim, + apply_alpha, + ) + n_gated, n_focus, n_edge, row = z_all.shape + lmax = int(lmax) + focus_dim = int(focus_dim) + device, dtype = grad_out.device, grad_out.dtype + grad_alpha = torch.empty((n_edge, n_focus), device=device, dtype=dtype) + grad_u0 = torch.empty((n_focus, n_edge, row), device=device, dtype=dtype) + if _has_no_edges(n_edge): + return grad_u0, grad_alpha + + block_m, block_n, block_k, warps, stages = _GEMM_CONFIG + m0 = (lmax + 1) * focus_dim + m1 = 2 * lmax * focus_dim + n_tiles = triton.cdiv(m0, block_n) + triton.cdiv(m1, block_n) + point_bm, point_w, point_s = point_config(focus_dim, lmax) + + # === Final layer: g = gz + gz @ W^T with gz = grad [* alpha] on the fly === + g_cur = torch.empty((n_focus, n_edge, row), device=device, dtype=dtype) + wrap_triton(_stack_gemm_bwd_kernel)[ + (triton.cdiv(n_edge, block_m) * n_tiles, n_focus) + ]( + grad_out, + grad_out, + w0t_all, + w1t_all, + alpha, + g_cur, + n_edge, + n_gated, + L=lmax, + CF=focus_dim, + G_EDGE_MAJOR=True, + FOLD_ALPHA=apply_alpha, + RES_IS_GZ=True, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=warps, + num_stages=stages, + ) + if apply_alpha: + a_bm, a_w, a_s = gate_config(focus_dim, lmax) + wrap_triton(_stack_grad_alpha_kernel)[(triton.cdiv(n_edge, a_bm), n_focus)]( + grad_out, + x_local, + alpha, + grad_alpha, + n_edge, + L=lmax, + CF=focus_dim, + BLOCK_M=a_bm, + num_warps=a_w, + num_stages=a_s, + ) + + # === Gated layers in reverse; sig / gz buffers are reused across layers === + gate_width = lmax * focus_dim + sig = torch.empty((n_focus, n_edge, gate_width), device=device, dtype=torch.float32) + gz = torch.empty((n_focus, n_edge, row), device=device, dtype=dtype) + use_bmm = focus_dim >= GATE_BMM_MIN_FOCUS_DIM + glogit = ( + torch.empty((n_focus, n_edge, gate_width), device=device, dtype=torch.float32) + if use_bmm + else sig + ) + r_bm, r_w, r_s = recompute_config(focus_dim, lmax) + for layer in range(n_gated - 1, -1, -1): + if use_bmm: + torch.sigmoid( + torch.bmm(z_all[layer, :, :, :focus_dim], gw_all[layer]), out=sig + ) + else: + wrap_triton(_stack_recompute_kernel)[(triton.cdiv(n_edge, r_bm), n_focus)]( + z_all, + gw_all, + sig, + n_edge, + layer, + L=lmax, + CF=focus_dim, + BLOCK_M=r_bm, + num_warps=r_w, + num_stages=r_s, + ) + wrap_triton(_stack_point_bwd_kernel)[(triton.cdiv(n_edge, point_bm), n_focus)]( + g_cur, + z_all, + sig, + gwt_all, + gz, + glogit, + n_edge, + layer, + L=lmax, + CF=focus_dim, + GLOGIT_OUT=use_bmm, + BLOCK_M=point_bm, + num_warps=point_w, + num_stages=point_s, + ) + if use_bmm: + # Gate-logit contraction back to the scalar rows via cuBLAS. + gz[:, :, :focus_dim] += torch.bmm(glogit, gwt_all[layer]) + g_next = torch.empty((n_focus, n_edge, row), device=device, dtype=dtype) + wrap_triton(_stack_gemm_bwd_kernel)[ + (triton.cdiv(n_edge, block_m) * n_tiles, n_focus) + ]( + gz, + g_cur, + w0t_all, + w1t_all, + gz, + g_next, + n_edge, + layer, + L=lmax, + CF=focus_dim, + G_EDGE_MAJOR=False, + FOLD_ALPHA=False, + RES_IS_GZ=False, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=warps, + num_stages=stages, + ) + g_cur = g_next + return g_cur, grad_alpha + + +# ====================================================================== +# Functional triton_op + fake + autograd registration +# ====================================================================== +_rotate_mix_op = torch.library.triton_op( + "sezm_triton::so2_rotate_mix", mutates_args=() +)(_rotate_mix_impl) +_rotate_mix_bwd_op = torch.library.triton_op( + "sezm_triton::so2_rotate_mix_bwd", mutates_args=() +)(_rotate_mix_bwd_impl) +_segment_sum_op = torch.library.triton_op("sezm_triton::segment_sum", mutates_args=())( + _segment_sum_impl +) +_mixing_stack_op = torch.library.triton_op( + "sezm_triton::so2_mixing_stack", mutates_args=() +)(_mixing_stack_impl) +_mixing_stack_bwd_op = torch.library.triton_op( + "sezm_triton::so2_mixing_stack_bwd", mutates_args=() +)(_mixing_stack_bwd_impl) + + +@_rotate_mix_op.register_fake +def _(x, src, wigner, kc, cb, lmax, n_focus, rank): + focus_dim = x.shape[2] // n_focus + return x.new_empty((n_focus, src.shape[0], (3 * lmax + 1) * focus_dim)) + + +@_rotate_mix_bwd_op.register_fake +def _(grad_u, x, src, wigner, kc, cb, lmax, n_focus, rank): + return ( + x.new_empty((src.shape[0], (lmax + 1) ** 2, x.shape[2])), + torch.empty_like(wigner), + torch.empty_like(kc), + ) + + +@_segment_sum_op.register_fake +def _(rows, order, row_ptr): + return rows.new_empty((row_ptr.shape[0] - 1, rows.shape[1], rows.shape[2])) + + +@_mixing_stack_op.register_fake +def _(u0, alpha, w0_all, w1_all, gw_all, lmax, focus_dim, apply_alpha): + n_focus, n_edge, row = u0.shape + return ( + u0.new_empty((n_edge, n_focus, row)), + u0.new_empty((gw_all.shape[0], n_focus, n_edge, row)), + ) + + +@_mixing_stack_bwd_op.register_fake +def _( + grad_out, + x_local, + z_all, + alpha, + w0t_all, + w1t_all, + gw_all, + gwt_all, + lmax, + focus_dim, + apply_alpha, +): + n_gated, n_focus, n_edge, row = z_all.shape + return ( + z_all.new_empty((n_focus, n_edge, row)), + z_all.new_empty((n_edge, n_focus)), + ) + + +def _rotate_mix_setup_context(ctx, inputs, output): + x, src, wigner, kc, cb, lmax, n_focus, rank = inputs + ctx.save_for_backward(x, src, wigner, kc, cb) + ctx.lmax = lmax + ctx.n_focus = n_focus + ctx.rank = rank + + +def _rotate_mix_backward(ctx, grad_u): + x, src, wigner, kc, cb = ctx.saved_tensors + grad_x_edge, grad_wigner, grad_kc = _rotate_mix_bwd_op( + grad_u.contiguous(), x, src, wigner, kc, cb, ctx.lmax, ctx.n_focus, ctx.rank + ) + # Contention-free segmented reduction of the per-edge node gradient; the + # integer topology (argsort + CSR offsets) traces as ordinary aten ops. + order = torch.argsort(src) + boundaries = torch.arange(x.shape[0] + 1, device=src.device, dtype=src.dtype) + row_ptr = torch.searchsorted(src.index_select(0, order), boundaries) + grad_x = _segment_sum_op(grad_x_edge, order, row_ptr) + return grad_x, None, grad_wigner, grad_kc, None, None, None, None + + +_rotate_mix_op.register_autograd( + _rotate_mix_backward, setup_context=_rotate_mix_setup_context +) + + +def _mixing_stack_setup_context(ctx, inputs, output): + u0, alpha, w0_all, w1_all, gw_all, lmax, focus_dim, apply_alpha = inputs + x_local, z_all = output + ctx.save_for_backward(alpha, x_local, z_all, w0_all, w1_all, gw_all) + ctx.lmax = lmax + ctx.focus_dim = focus_dim + ctx.apply_alpha = apply_alpha + + +def _mixing_stack_backward(ctx, grad_out, grad_z_unused): + alpha, x_local, z_all, w0_all, w1_all, gw_all = ctx.saved_tensors + grad_u0, grad_alpha = _mixing_stack_bwd_op( + grad_out.contiguous(), + x_local, + z_all, + alpha, + w0_all.transpose(2, 3).contiguous(), + w1_all.transpose(2, 3).contiguous(), + gw_all, + gw_all.transpose(2, 3).contiguous(), + ctx.lmax, + ctx.focus_dim, + ctx.apply_alpha, + ) + return ( + grad_u0, + grad_alpha if ctx.apply_alpha else None, + None, + None, + None, + None, + None, + None, + ) + + +_mixing_stack_op.register_autograd( + _mixing_stack_backward, setup_context=_mixing_stack_setup_context +) + + +# ====================================================================== +# Per-convolution entry point +# ====================================================================== +class _TritonSO2ValuePath: + """Per-convolution entry running the SO(2) value path through the fused ops. + + The call contract mirrors the reference ``so2_message(..., + return_local=True)``: it returns the post-focus-compete local features + ``(E, F, D_m, Cf)`` and the projected radial features whose ``l = 0`` + slice feeds the attention aggregation. + + The stacked weights are assembled from the live parameters on every call + and must not be cached across calls: the first call may run inside a + ``make_fx`` fake-tensor trace, where a cache would capture fake weights, + and eager weights may change when a checkpoint is loaded after + construction. The assembly is a short chain of parameter-only aten ops + that the compile pipeline constant-folds out of the hot path. + + At ``DP_TRITON_INFER >= 3`` the mixing stack runs through the fp16x3 + tensor-core operator when the ``(focus_dim, lmax)`` key carries a + validated configuration (see :mod:`.so2_stack_fp16x3`); the selection is + fixed at construction, so exactly one stack operator reaches the traced + graph. + """ + + def __init__(self, conv: SO2Convolution) -> None: + self._conv = conv + self._stack_op = _mixing_stack_op + if ( + conv.triton_infer_level >= 3 + and stack_fp16x3_configs(conv.so2_focus_dim, conv.lmax) is not None + ): + from .so2_stack_fp16x3 import ( + mixing_stack_fp16x3, + ) + + self._stack_op = mixing_stack_fp16x3 + + def _pack_weights(self) -> tuple[Tensor, Tensor, Tensor]: + """Stack the SO(2) block weights and gate projections per layer. + + Returns ``(w0_all, w1_all, gw_all)`` with shapes + ``(n_layers, F, M0, M0)``, ``(n_layers, F, M1, M1)`` and + ``(n_gated, F, Cf, lmax * Cf)``, all in the ``(in, out)`` convention. + """ + conv = self._conv + m0 = (conv.lmax + 1) * conv.so2_focus_dim + w0_list, w1_list, gw_list = [], [], [] + for layer, linear in enumerate(conv.so2_linears): + weight = ( + linear._build_so2_weight().detach().permute(1, 0, 2).contiguous() + ) # (F, D_m*Cf, D_m*Cf) + w0_list.append(weight[:, :m0, :m0]) + w1_list.append(weight[:, m0:, m0:]) + non_linear = conv.non_linearities[layer] + if type(non_linear).__name__ == "GatedActivation": + gw_list.append( + non_linear.gate_linear.weight.detach() + .view( + conv.so2_focus_dim, + conv.n_focus, + conv.lmax * conv.so2_focus_dim, + ) + .permute(1, 0, 2) + ) + return ( + torch.stack(w0_list).contiguous(), + torch.stack(w1_list).contiguous(), + torch.stack(gw_list).contiguous(), + ) + + def __call__( + self, + x: Tensor, + edge_cache: EdgeFeatureCache, + radial_feat: Tensor, + ) -> tuple[Tensor, Tensor]: + """Compute the SO(2) local features and radial features via the fused ops. + + Parameters + ---------- + x : Tensor + Node features with shape (N, D, C_wide). + edge_cache : EdgeFeatureCache + Precomputed edge cache (provides ``src`` and the Wigner ``D_full``). + radial_feat : Tensor + Per-edge radial features with shape (E, lmax+1, C). + + Returns + ------- + x_local : Tensor + Post-focus-compete local features with shape (E, F, D_m, Cf). + rad_feat : Tensor + Projected radial features with shape (E, lmax+1, C_wide); its + ``l = 0`` slice is consumed by the attention aggregation. The + degree-expanded ``(E, D_m, C_wide)`` layout of the reference path + is never materialized: the mixer projection and the mixer-free + multiply read only the ``lmax + 1`` per-degree rows. + """ + conv = self._conv + src = edge_cache.src + w0_all, w1_all, gw_all = self._pack_weights() + + # === Step 1. Radial features and the compact degree kernel === + if conv.radial_hidden_proj is not None: + rad_feat = conv.radial_hidden_proj(radial_feat) # (E, lmax+1, C_wide) + else: + rad_feat = radial_feat + mixer = conv.radial_degree_mixer + if mixer is None: + kc = rad_feat + cb = rad_feat.new_zeros(1) + rank = 0 + else: + kc = torch.matmul( + rad_feat.reshape(rad_feat.shape[0], -1), mixer.weight + ) # (E, degree_kernel_size * rank) + cb = mixer.channel_basis.reshape(-1) + rank = mixer.rank + + # === Step 2. Fused rotate-to-local + degree mixing (focus-major) === + u0 = _rotate_mix_op( + x.contiguous(), + src, + edge_cache.D_full, + kc.contiguous(), + cb.contiguous(), + conv.lmax, + conv.n_focus, + rank, + ) + + # === Step 3. Cross-focus competition weight from the l = 0 scalars === + apply_alpha = bool(conv.focus_compete and conv.n_focus > 1) + if apply_alpha: + # The small (E, F, Cf) copy keeps the softmax backward from + # retaining a view of the whole focus-major activation. + gate_src = u0[:, :, : conv.so2_focus_dim].permute(1, 0, 2).contiguous() + alpha = conv._focus_alpha(gate_src).to(u0.dtype).contiguous() + else: + alpha = torch.ones( + src.shape[0], conv.n_focus, device=u0.device, dtype=u0.dtype + ) + + # === Step 4. Fused mixing stack (identity layer stores edge-major) === + x_local, _ = self._stack_op( + u0, + alpha, + w0_all, + w1_all, + gw_all, + conv.lmax, + conv.so2_focus_dim, + apply_alpha, + ) + n_edge = src.shape[0] + reduced_dim = 3 * conv.lmax + 1 + return ( + x_local.view(n_edge, conv.n_focus, reduced_dim, conv.so2_focus_dim), + rad_feat, + ) + + +def _is_supported(conv: SO2Convolution) -> bool: + """Return whether ``conv`` matches the fused value-path configuration.""" + if ( + conv.mmax != 1 + or not 1 <= conv.lmax <= _MAX_LMAX + or conv.mixing_layers < 2 + or conv.so2_focus_dim not in _SUPPORTED_FOCUS_DIMS + or conv.node_wise_grid_product is not None + or conv.use_so2_attn_res + or conv.layer_scale + # Kernels accumulate in fp32; refuse other precisions rather than + # silently down-casting a double-precision model. + or conv.so2_linears[0].weight_m0.dtype is not torch.float32 + ): + return False + mixer = conv.radial_degree_mixer + if mixer is not None and ( + mixer.mode != "degree_channel" or not 1 <= mixer.rank <= _MAX_MIXER_RANK + ): + return False + if any(type(norm).__name__ != "Identity" for norm in conv.so2_inter_norms): + return False + if any(linear.bias0 is not None for linear in conv.so2_linears): + return False + if any( + linear.in_channels != conv.so2_focus_dim + or linear.out_channels != conv.so2_focus_dim + for linear in conv.so2_linears + ): + return False + non_linears = conv.non_linearities + if any( + type(non_linears[layer]).__name__ != "GatedActivation" + or ( + getattr(non_linears[layer].scalar_act, "activation", None) + or getattr(non_linears[layer], "activation_function", None) + ) + != "silu" + for layer in range(conv.mixing_layers - 1) + ): + return False + return type(non_linears[conv.mixing_layers - 1]).__name__ == "Identity" + + +def make_triton_value_path(conv: SO2Convolution) -> _TritonSO2ValuePath | None: + """Build the fused Triton value-path entry for a convolution block. + + Parameters + ---------- + conv : SO2Convolution + The convolution block to accelerate. + + Returns + ------- + _TritonSO2ValuePath or None + The entry callable when Triton is available and ``conv`` matches the + supported configuration (``mmax == 1``, ``lmax`` 1..6, focus width in + {32, 64, 96, 128}, gated stack with an identity final layer, radial + mixer absent or ``degree_channel`` with rank 1..4, fp32 weights); + otherwise ``None`` and the caller falls back to the reference path. + """ + if not SO2_VALUE_PATH_TRITON_AVAILABLE or not _is_supported(conv): + return None + return _TritonSO2ValuePath(conv) diff --git a/deepmd/kernels/triton/sezm/sweep_tile_configs.py b/deepmd/kernels/triton/sezm/sweep_tile_configs.py new file mode 100644 index 0000000000..02a080588f --- /dev/null +++ b/deepmd/kernels/triton/sezm/sweep_tile_configs.py @@ -0,0 +1,1283 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# ruff: noqa: T201 +r"""Sweep the launch-configuration tables of the shape-tuned SeZM kernels. + +Every table served by :mod:`.tile_configs` is generated by this module. A +sweep measures the candidate launch configurations of one kernel family for +one shape key on synthetic tensors at production edge counts, applies the +family's win rule, and returns the resulting entries (a configuration +tuple, or ``None`` where the family default is the measured optimum). + +The module is used in three ways: + +1. *Freeze auto-tuning* -- :func:`tune_missing_configs` sweeps exactly the + shape keys a model needs that are absent from both the built-in tables + and the current process registrations, and registers the winners via + :func:`~.tile_configs.register_tile_configs`. The SeZM ``.pt2`` freeze + path calls this on the local GPU before tracing, so the frozen artifact + bakes launches tuned for the deployment hardware even when the GPU has + no built-in coverage. Keys whose sweep concluded "keep the default" + are registered as ``None`` and are not re-swept on later freezes. +2. *Manual sweeps* -- the CLI sweeps one explicit shape key + (``--cf/--lmax``) or every key a checkpoint needs (``--model``), and + prints the resulting entries as a Python fragment in the layout of + :mod:`.tile_config_data`, ready to be merged under a new GPU name when + extending the built-in tables. +3. *Table regeneration* -- rerunning a sweep after any kernel-body change; + see the mandatory-validation note below. + +Usage +----- +:: + + python -m deepmd.kernels.triton.sezm.sweep_tile_configs \\ + --cf CF --lmax LMAX [--kernels FAMILY[,FAMILY...]] [--focus F] + [--heads H] [--edges E] [--device cuda:0] + + python -m deepmd.kernels.triton.sezm.sweep_tile_configs \\ + --model model.pt [--level 3] [--edges E] [--device cuda:0] + +``--kernels`` selects the sweep groups (default: all): + +``pointwise`` + The gate / recompute / backward-pointwise kernels, keyed + ``(focus_dim, lmax)`` -> the ``gate`` / ``recompute`` / ``point`` + families. For ``cf >= GATE_BMM_MIN_FOCUS_DIM`` the gate projection + runs as a cuBLAS bmm (reported once for reference) and the recompute + kernel is not swept because it is never launched in that regime. +``rotate_fwd`` + The rotate+mix forward kernel, keyed ``(C_wide, lmax)`` -> + ``rotate_mix_fwd``; ``None`` records that the upstream default won. +``rotate_bwd`` + The edge-block rotate+mix backward against the per-edge kernel, keyed + ``(C_wide, lmax)`` -> ``rotate_mix_bwd_block``. The family is a win + list: the entry is a configuration only when the speedup exceeds the + 3% noise margin, otherwise ``None`` keeps the per-edge kernel. +``flash_bwd`` + The edge-block flash-attention backward against the per-edge kernel, + keyed ``(C_wide, lmax)`` -> ``flash_bwd_block``; the same win-list + rule applies. +``fp16x3`` + The four fp16x3 mixing-stack GEMMs, keyed ``(focus_dim, lmax)`` -> + ``stack_fp16x3``. Every candidate is ranked by standalone kernel time + and then validated end to end against an fp64 reference of the whole + stack operator before it may win; ``None`` records that no candidate + validated and the fp32 stack stays in charge. + +Interpretation guidance +----------------------- +The winning ``BLOCK_M`` of the pointwise kernels and the winning ``BLOCK_E`` +of the edge-block kernels shrink monotonically as the register-pressure +products ``lmax * next_power_of_2(cf)`` and ``lmax * next_power_of_2(C_wide)`` +grow; a candidate on the wrong side of the spill point can be an order of +magnitude slower, which is why the tables are exact-keyed rather than +heuristic. For the fp32 kernels, tile choices never affect numerical +results, so a sweep only ever changes speed. Winning configurations are +insensitive to the edge count once the device is saturated (the defaults +sweep at 6.5e5 edges); small systems are launch-bound and insensitive to +the choice altogether. + +fp16x3 validation is mandatory +------------------------------ +Some ``(num_warps, num_stages)`` combinations of the fp16x3 three-``tl.dot`` +k-loop are miscompiled by the Triton software pipeliner into silent NaN rows. +The affected set shifts both with any change to the kernel body *and with the +edge count*: a config finite at one count can produce NaN at another, because +the miscompilation follows the launch grid rather than the data. The fp16x3 +sweep therefore validates every candidate through the full operator (forward +and force backward) in two ways before timing may crown it: an fp64 accuracy +check at a fixed count, and a finiteness check across an intermediate spread +of edge counts (``_FP16X3_FINITE_EDGES``) plus the main sweep count. A +candidate that is inaccurate or non-finite at any of these is skipped. The +edge-count spread reduces but cannot eliminate the risk (no finite sample +certifies every count), which is why ``num_stages == 1`` -- the +pipeliner-free family that is structurally NaN-free at any count -- stays in +the candidate grid. Writing a hand-picked configuration into the tables +without this validation risks silently corrupted inference. Sweeps at +``lmax >= 5`` compile deeply unrolled kernels and can take several minutes per +candidate; the reduced candidate grid keeps this tractable. +""" + +from __future__ import ( + annotations, +) + +import argparse +import itertools +import logging +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +import triton +from torch.library import ( + wrap_triton, +) + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + ) + +from deepmd.kernels.triton.sezm.flash_atten import ( + _flash_bwd_block_kernel, + _flash_bwd_op, +) +from deepmd.kernels.triton.sezm.so2_stack_fp16x3 import ( + _mixing_stack_fp16x3_op, + _split_fp16, + _stack_fp16x3_bwd_kernel, + _stack_fp16x3_m0_kernel, + _stack_fp16x3_m1_kernel, +) +from deepmd.kernels.triton.sezm.so2_value_path import ( + _mixing_stack_op, + _mixing_stack_reference, + _rotate_mix_bwd_block_kernel, + _rotate_mix_bwd_op, + _rotate_mix_fwd_kernel, + _stack_gate_kernel, + _stack_point_bwd_kernel, + _stack_recompute_kernel, +) +from deepmd.kernels.triton.sezm.tile_configs import ( + GATE_BMM_MIN_FOCUS_DIM, + _runtime_tile_configs, + has_tile_config, + register_tile_configs, +) + +__all__ = [ + "collect_model_shape_keys", + "tune_missing_configs", +] + +log = logging.getLogger(__name__) + +_BLOCK_M_CANDIDATES = (8, 16, 32, 64) +_WARP_CANDIDATES = (4, 8, 16) +_STAGE_CANDIDATES = (1, 2) + +# Edge-block winners never exceed BLOCK_E = 8 on any swept width (the +# per-edge register tile scales with BLOCK_E * C_wide, so larger blocks +# spill first), and every excluded candidate still costs a full Triton +# compilation -- minutes apiece on wide channels at high lmax. +_EDGE_BLOCK_CANDIDATES = tuple(itertools.product((2, 4, 8), (2, 4, 8), (1, 2))) +_ROTATE_FWD_DEFAULT = (2, 2) +_ROTATE_FWD_CANDIDATES = ( + (1, 1), + (1, 2), + (2, 1), + _ROTATE_FWD_DEFAULT, + (4, 1), + (4, 2), + (8, 2), +) +_FP16X3_GEMM_CANDIDATES = tuple( + itertools.product((64, 128), (64, 128), (32, 64), (4, 8), (1, 2, 3)) +) +# Whole-op pin candidates used to isolate one fp16x3 kernel during +# validation; the first combination that validates serves as the pin. +_FP16X3_PIN_CANDIDATES = ( + (64, 64, 32, 4, 1), + (32, 64, 32, 4, 1), + (64, 64, 64, 4, 1), + (128, 64, 32, 8, 1), + (64, 64, 32, 8, 2), +) +# Validated end-to-end error budgets, relative to the fp32 reference error +# on identical data (the fp32 rounding itself grows with the reduction +# width, so absolute thresholds mis-fire at wide shapes). +_FP16X3_TOL_FACTOR = 3.0 +_FP16X3_CHECK_EDGES = 8192 +# Edge counts at which every candidate's finiteness is re-checked (forward +# and backward), in addition to the accuracy check and the finiteness at the +# main sweep count. The Triton software pipeliner can miscompile the +# three-``tl.dot`` k-loop into silent NaN at *some* edge counts while +# remaining finite at others -- the affected set shifts with the launch grid, +# so a single-edge-count check does not certify a config for the arbitrary +# edge counts an MD run presents. This intermediate spread (below the main +# count, which is checked separately) samples the band where the +# miscompilation was observed; it reduces rather than eliminates the risk, so +# ``num_stages == 1`` configs (pipeliner off, structurally NaN-free at any +# edge count) remain in the candidate grid as the safe fallback. +_FP16X3_FINITE_EDGES = (4096, 20000, 65537, 131072) + +# Win margin of the edge-block families against the per-edge kernels; a +# speedup below this noise floor keeps the per-edge kernel. +_WIN_MARGIN = 1.03 + +_DEFAULT_EDGES = 650000 + + +def _saturating_edges(width: int) -> int: + """Return a device-saturating edge count for the given channel width. + + Winning configurations are insensitive to the edge count once the device + is saturated (measured winners are stable down to ~2e5 edges and start + to drift below ~5e4), and the per-edge tensor footprint grows linearly + with the channel width, so the sweep scales the synthetic edge count + inversely with the width to keep peak memory roughly constant. The + reference point is 6.5e5 edges at width 64 -- the production scale the + built-in tables were swept at. + + The count is additionally capped so the sweep's peak allocation stays + within roughly a third of the device's total memory, keeping the freeze + auto-tuner viable on small-memory GPUs and beside a resident model. + The per-edge cost model (~160 KB per edge at width 64, scaling with the + width) is the measured peak of the heaviest group -- the fp16x3 sweep, + whose ranking buffers, fp64 validation reference and autograd graphs + dwarf the synthetic tensors themselves. + + The memory budget always wins over tuning quality: measured winners are + stable down to ~2e5 edges and drift only into neighboring near-optimal + candidates below that (a few percent of kernel time, never correctness + -- fp16x3 validity comes from the fp64 check, which is edge-count + independent), so a reduced count merely costs a sliver of speed. A + sweep below 1e5 edges logs a notice to that effect; the hard floor of + 2e4 edges only keeps the timing signal above launch overhead. + """ + edges = _DEFAULT_EDGES * 64 // max(width, 64) + if torch.cuda.is_available(): + total = torch.cuda.get_device_properties( + torch.cuda.current_device() + ).total_memory + bytes_per_edge = 160000 * max(width, 64) // 64 + cap = max(int(total / 3) // bytes_per_edge, 20000) + if cap < min(edges, 100000): + log.info( + "Sweep edge count capped at %d by the device memory budget " + "(device total %.1f GB); winning configurations may be a few " + "percent off the saturated optimum.", + cap, + total / 2**30, + ) + edges = min(edges, cap) + return edges + + +# Entries returned by one sweep group: family name -> {shape key: entry}. +SweepResult = dict[str, dict[tuple[int, int], "tuple | None"]] + + +def _bench(fn: Callable[[], object], iters: int = 20) -> float: + """Return the mean kernel time of ``fn`` in milliseconds.""" + for _ in range(3): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + stop = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + stop.record() + torch.cuda.synchronize() + return start.elapsed_time(stop) / iters + + +def _relerr(a: torch.Tensor, b: torch.Tensor) -> float: + """Return the max relative error of ``a`` against reference ``b``.""" + return float((a - b).abs().max() / b.abs().max().clamp_min(1e-30)) + + +def _block_diag_wigner(n_edge: int, lmax: int, device: torch.device) -> torch.Tensor: + """Return a random block-diagonal Wigner-D stand-in with shape (E, D, D).""" + dim = (lmax + 1) ** 2 + wigner = torch.zeros(n_edge, dim, dim, device=device, dtype=torch.float32) + for degree in range(lmax + 1): + base, size = degree * degree, 2 * degree + 1 + wigner[:, base : base + size, base : base + size] = torch.randn( + n_edge, size, size, device=device + ) + return wigner + + +# ====================================================================== +# Pointwise group (gate / recompute / backward pointwise) +# ====================================================================== +def sweep_pointwise( + cf: int, + lmax: int, + *, + n_focus: int = 2, + n_edge: int | None = None, + device: torch.device | str = "cuda", +) -> SweepResult: + """Sweep the gate / recompute / backward-pointwise launch triples. + + Parameters + ---------- + cf : int + Per-focus channel width ``Cf``. + lmax : int + Maximum spherical harmonic degree. + n_focus : int + Focus count of the synthetic tensors; the winners are valid for any + focus count (the focus stream rides the grid batch axis). + n_edge : int + Edge count of the synthetic tensors. + device : torch.device or str + CUDA device to sweep on. + + Returns + ------- + SweepResult + Entries for the ``gate``, ``point`` and (below the bmm regime) + ``recompute`` families under the key ``(cf, lmax)``. + """ + device = torch.device(device) + if n_edge is None: + n_edge = _saturating_edges(cf) + row = (3 * lmax + 1) * cf + gate_width = lmax * cf + use_bmm = cf >= GATE_BMM_MIN_FOCUS_DIM + + u = torch.randn(n_focus, n_edge, row, device=device) + z_all = torch.randn(1, n_focus, n_edge, row, device=device) + gw_all = torch.randn(1, n_focus, cf, gate_width, device=device) * 0.05 + gwt_all = gw_all.transpose(2, 3).contiguous() + v = torch.empty_like(u) + sig = torch.rand(n_focus, n_edge, gate_width, device=device, dtype=torch.float32) + grad = torch.randn(n_focus, n_edge, row, device=device) + gz = torch.empty_like(grad) + glogit = torch.empty_like(sig) if use_bmm else sig + + kernels = { + "gate": lambda bm, w, s: _stack_gate_kernel[(triton.cdiv(n_edge, bm), n_focus)]( + u, + z_all, + gw_all, + v, + sig, + n_edge, + 0, + L=lmax, + CF=cf, + SIG_IN=use_bmm, + BLOCK_M=bm, + num_warps=w, + num_stages=s, + ), + "recompute": lambda bm, w, s: _stack_recompute_kernel[ + (triton.cdiv(n_edge, bm), n_focus) + ]( + z_all, + gw_all, + sig, + n_edge, + 0, + L=lmax, + CF=cf, + BLOCK_M=bm, + num_warps=w, + num_stages=s, + ), + "point": lambda bm, w, s: _stack_point_bwd_kernel[ + (triton.cdiv(n_edge, bm), n_focus) + ]( + grad, + z_all, + sig, + gwt_all, + gz, + glogit, + n_edge, + 0, + L=lmax, + CF=cf, + GLOGIT_OUT=use_bmm, + BLOCK_M=bm, + num_warps=w, + num_stages=s, + ), + } + if use_bmm: + projection_ms = _bench( + lambda: torch.sigmoid(torch.bmm(z_all[0, :, :, :cf], gw_all[0]), out=sig) + ) + print(f"[bmm sigmoid projection: {projection_ms:.3f} ms]") + kernels.pop("recompute") + + result: SweepResult = {} + for family, launch in kernels.items(): + best_ms, best_cfg = float("inf"), None + for bm, w, s in itertools.product( + _BLOCK_M_CANDIDATES, _WARP_CANDIDATES, _STAGE_CANDIDATES + ): + try: + ms = _bench(lambda: launch(bm, w, s)) + except triton.runtime.errors.OutOfResources: + print(f" BM={bm:3d} warps={w:2d} stages={s}: out of resources") + continue + marker = "" + if ms < best_ms: + best_ms, best_cfg = ms, (bm, w, s) + marker = " <-" + print(f" BM={bm:3d} warps={w:2d} stages={s}: {ms:8.3f} ms{marker}") + print(f"BEST {family}[({cf}, {lmax})] = {best_cfg} # {best_ms:.3f} ms") + result[family] = {(cf, lmax): best_cfg} + return result + + +# ====================================================================== +# Rotate+mix forward +# ====================================================================== +def sweep_rotate_fwd( + cf: int, + lmax: int, + *, + n_focus: int = 2, + n_edge: int | None = None, + device: torch.device | str = "cuda", +) -> SweepResult: + """Sweep the rotate+mix forward ``(num_warps, num_stages)`` pair. + + Returns the ``rotate_mix_fwd`` entry under the key + ``(n_focus * cf, lmax)``; ``None`` records that the upstream default + ``(2, 2)`` is the measured optimum. + """ + device = torch.device(device) + c_wide = n_focus * cf + if n_edge is None: + n_edge = _saturating_edges(c_wide) + row = (3 * lmax + 1) * cf + kernel_size = (lmax + 1) ** 2 + lmax**2 + n_nodes = max(n_edge // 128, 1) + + x = torch.randn(n_nodes, (lmax + 1) ** 2, c_wide, device=device) + src = torch.randint(0, n_nodes, (n_edge,), device=device) + wigner = _block_diag_wigner(n_edge, lmax, device) + kc = torch.randn(n_edge, kernel_size, device=device) + cb = torch.randn(1, c_wide, device=device) + u = torch.empty(n_focus, n_edge, row, device=device, dtype=torch.float32) + + def launch(warps: int, stages: int) -> None: + _rotate_mix_fwd_kernel[(n_edge,)]( + x, + src, + wigner, + kc, + cb, + u, + n_edge, + x.stride(0), + x.stride(1), + L=lmax, + CF=cf, + CW=c_wide, + BC=triton.next_power_of_2(c_wide), + RANK=1, + num_warps=warps, + num_stages=stages, + ) + + best_ms, best_cfg = float("inf"), None + for warps, stages in _ROTATE_FWD_CANDIDATES: + try: + ms = _bench(lambda: launch(warps, stages)) + except triton.runtime.errors.OutOfResources: + continue + marker = "" + if ms < best_ms: + best_ms, best_cfg = ms, (warps, stages) + marker = " <-" + print(f" warps={warps:2d} stages={stages}: {ms:8.3f} ms{marker}") + entry = None if best_cfg == _ROTATE_FWD_DEFAULT else best_cfg + print( + f"BEST rotate_mix_fwd[({c_wide}, {lmax})] = {entry} # {best_ms:.3f} ms" + + (" (upstream default)" if entry is None else "") + ) + return {"rotate_mix_fwd": {(c_wide, lmax): entry}} + + +# ====================================================================== +# Edge-block backward groups (win lists against the per-edge kernels) +# ====================================================================== +def _win_list_entry( + family: str, key: tuple[int, int], best: tuple[float, tuple] | None, base_ms: float +) -> tuple | None: + """Apply the win rule: record the entry only above the 3% margin.""" + if best is None: + print(f"BEST {family}[{key}]: no valid candidate; keep the per-edge kernel") + return None + speedup = base_ms / best[0] + verdict = "RECORD" if speedup >= _WIN_MARGIN else "keep the per-edge kernel" + print( + f"BEST {family}[{key}] = {best[1]} # {best[0]:.3f} ms vs per-edge " + f"{base_ms:.3f} ms ({speedup:.2f}x) -> {verdict}" + ) + return best[1] if speedup >= _WIN_MARGIN else None + + +def sweep_rotate_bwd( + cf: int, + lmax: int, + *, + n_focus: int = 2, + n_edge: int | None = None, + device: torch.device | str = "cuda", +) -> SweepResult: + """Sweep the edge-block rotate+mix backward against the per-edge kernel. + + Returns the ``rotate_mix_bwd_block`` win-list entry under the key + ``(n_focus * cf, lmax)``. + """ + device = torch.device(device) + c_wide = n_focus * cf + if n_edge is None: + n_edge = _saturating_edges(c_wide) + dim = (lmax + 1) ** 2 + row = (3 * lmax + 1) * cf + kernel_size = (lmax + 1) ** 2 + lmax**2 + n_nodes = max(n_edge // 128, 1) + + grad_u = torch.randn(n_focus, n_edge, row, device=device) + x = torch.randn(n_nodes, dim, c_wide, device=device) + src = torch.randint(0, n_nodes, (n_edge,), device=device) + wigner = _block_diag_wigner(n_edge, lmax, device) + kc = torch.randn(n_edge, kernel_size, device=device) + cb = torch.randn(1, c_wide, device=device) + + reference = _rotate_mix_bwd_op(grad_u, x, src, wigner, kc, cb, lmax, n_focus, 1) + base_ms = _bench( + lambda: _rotate_mix_bwd_op(grad_u, x, src, wigner, kc, cb, lmax, n_focus, 1), + iters=8, + ) + + def launch(block_e: int, warps: int, stages: int) -> tuple[torch.Tensor, ...]: + gxe = torch.empty(n_edge, dim, c_wide, device=device, dtype=torch.float32) + gw = torch.zeros_like(wigner) + gkc = torch.empty_like(kc) + wrap_triton(_rotate_mix_bwd_block_kernel)[(triton.cdiv(n_edge, block_e),)]( + grad_u, + x, + src, + wigner, + kc, + cb, + gxe, + gw, + gkc, + n_edge, + x.stride(0), + x.stride(1), + L=lmax, + CF=cf, + CW=c_wide, + CP=triton.next_power_of_2(c_wide), + RANK=1, + BLOCK_E=block_e, + num_warps=warps, + num_stages=stages, + ) + return gxe, gw, gkc + + best = None + for cfg in _EDGE_BLOCK_CANDIDATES: + try: + outputs = launch(*cfg) + torch.cuda.synchronize() + err = max(_relerr(o, r) for o, r in zip(outputs, reference)) + if err > 5e-6: + continue + ms = _bench(lambda: launch(*cfg), iters=8) + except triton.runtime.errors.OutOfResources: + continue + if best is None or ms < best[0]: + best = (ms, cfg) + print(f" BE={cfg[0]:3d} warps={cfg[1]} stages={cfg[2]}: {ms:8.3f} ms <-") + key = (c_wide, lmax) + return { + "rotate_mix_bwd_block": { + key: _win_list_entry("rotate_mix_bwd_block", key, best, base_ms) + } + } + + +def sweep_flash_bwd( + cf: int, + lmax: int, + *, + n_focus: int = 2, + n_head: int = 1, + n_edge: int | None = None, + device: torch.device | str = "cuda", +) -> SweepResult: + """Sweep the edge-block flash-attention backward against the per-edge kernel. + + Returns the ``flash_bwd_block`` win-list entry under the key + ``(n_focus * cf, lmax)``. ``n_head`` only specializes the kernel binary + through a ``constexpr``; the winning schedule is shared across head + counts of the same width. + """ + device = torch.device(device) + c_wide = n_focus * cf + if n_edge is None: + n_edge = _saturating_edges(c_wide) + dim = (lmax + 1) ** 2 + reduced_dim = 3 * lmax + 1 + n_nodes = max(n_edge // 128, 1) + + grad_pre_gate = torch.randn(n_nodes, dim, c_wide, device=device) + x_local = torch.randn(n_edge, n_focus, reduced_dim, cf, device=device) + wigner_dt = _block_diag_wigner(n_edge, lmax, device) + rescale = torch.rand(dim, device=device, dtype=torch.float32) + 0.5 + alpha = torch.rand(n_edge, n_focus, n_head, device=device, dtype=torch.float32) + dst = torch.randint(0, n_nodes, (n_edge,), device=device) + + reference = _flash_bwd_op( + grad_pre_gate, x_local, wigner_dt, rescale, alpha, dst, lmax, n_head + ) + base_ms = _bench( + lambda: _flash_bwd_op( + grad_pre_gate, x_local, wigner_dt, rescale, alpha, dst, lmax, n_head + ), + iters=8, + ) + + def launch(block_e: int, warps: int, stages: int) -> tuple[torch.Tensor, ...]: + gxl = torch.empty_like(x_local) + gdt = torch.zeros_like(wigner_dt) + gw = torch.empty_like(alpha) + wrap_triton(_flash_bwd_block_kernel)[(triton.cdiv(n_edge, block_e),)]( + grad_pre_gate, + x_local, + wigner_dt, + rescale, + alpha, + dst, + gxl, + gdt, + gw, + n_edge, + grad_pre_gate.stride(0), + grad_pre_gate.stride(1), + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + x_local.stride(3), + gxl.stride(0), + gxl.stride(1), + gxl.stride(2), + gxl.stride(3), + L=lmax, + CF=cf, + CW=c_wide, + CP=triton.next_power_of_2(c_wide), + HEAD_DIM=cf // n_head, + NHEAD=n_head, + BLOCK_E=block_e, + num_warps=warps, + num_stages=stages, + ) + return gxl, gdt, gw + + best = None + for cfg in _EDGE_BLOCK_CANDIDATES: + try: + outputs = launch(*cfg) + torch.cuda.synchronize() + err = max(_relerr(o, r) for o, r in zip(outputs, reference)) + if err > 5e-6: + continue + ms = _bench(lambda: launch(*cfg), iters=8) + except triton.runtime.errors.OutOfResources: + continue + if best is None or ms < best[0]: + best = (ms, cfg) + print(f" BE={cfg[0]:3d} warps={cfg[1]} stages={cfg[2]}: {ms:8.3f} ms <-") + key = (c_wide, lmax) + return { + "flash_bwd_block": {key: _win_list_entry("flash_bwd_block", key, best, base_ms)} + } + + +# ====================================================================== +# fp16x3 mixing-stack GEMMs (fp64-validated) +# ====================================================================== +def sweep_fp16x3( + cf: int, + lmax: int, + *, + n_focus: int = 2, + n_edge: int | None = None, + device: torch.device | str = "cuda", +) -> SweepResult: + """Sweep and validate the four fp16x3 stack GEMM configurations. + + Returns the ``stack_fp16x3`` entry under the key ``(cf, lmax)``: + the fastest fp64-validated combination, or ``None`` when no candidate + validates (the fp32 stack stays in charge). Validation runs through + the production operator, so trial configurations are registered into + the runtime table during the sweep; the final conclusion is registered + unconditionally on exit. + """ + device = torch.device(device) + if n_edge is None: + n_edge = _saturating_edges(cf) + n_layers = 3 + m0 = (lmax + 1) * cf + m1 = 2 * lmax * cf + half = lmax * cf + row = (3 * lmax + 1) * cf + + u0 = torch.randn(n_focus, n_edge, row, device=device) + alpha = torch.rand(n_edge, n_focus, device=device, dtype=torch.float32) + 0.1 + w0_all = torch.randn(n_layers, n_focus, m0, m0, device=device) * 0.2 + block_u = torch.randn(n_layers, n_focus, half, half, device=device) * 0.2 + block_v = torch.randn(n_layers, n_focus, half, half, device=device) * 0.2 + # The |m| = 1 weight carries the [[U, V], [-V, U]] complex structure of + # SO2Linear so the synthetic stack matches the production operator. + w1_all = torch.cat( + [ + torch.cat([block_u, block_v], dim=3), + torch.cat([-block_v, block_u], dim=3), + ], + dim=2, + ).contiguous() + gw_all = torch.randn(n_layers - 1, n_focus, cf, half, device=device) * 0.3 + w0h, w0l = _split_fp16(w0_all) + w1h, w1l = _split_fp16(w1_all) + w0t = w0_all.transpose(2, 3).contiguous() + w1t = w1_all.transpose(2, 3).contiguous() + w0th, w0tl = _split_fp16(w0t) + w1th, w1tl = _split_fp16(w1t) + z_all = torch.empty( + n_layers - 1, n_focus, n_edge, row, device=device, dtype=torch.float32 + ) + out = torch.empty_like(u0) + sig = torch.full((n_focus, n_edge, half), 0.5, device=device, dtype=torch.float32) + gz = torch.randn(n_focus, n_edge, row, device=device) + g_res = torch.randn(n_focus, n_edge, row, device=device) + gu = torch.empty_like(u0) + + def launch_m0(bm: int, bn: int, bk: int, warps: int, stages: int) -> None: + wrap_triton(_stack_fp16x3_m0_kernel)[ + (triton.cdiv(n_edge, bm) * triton.cdiv(m0, bn), n_focus) + ]( + u0, + w0h, + w0l, + alpha, + z_all, + n_edge, + 0, + L=lmax, + CF=cf, + EPILOGUE=0, + V_EDGE_MAJOR=False, + APPLY_ALPHA=False, + BLOCK_M=bm, + BLOCK_N=bn, + BLOCK_K=bk, + num_warps=warps, + num_stages=stages, + ) + + def launch_m1(bm: int, bn: int, bk: int, warps: int, stages: int) -> None: + wrap_triton(_stack_fp16x3_m1_kernel)[ + (triton.cdiv(n_edge, bm) * triton.cdiv(m1, bn), n_focus) + ]( + u0, + w1h, + w1l, + sig, + alpha, + out, + z_all, + n_edge, + 0, + L=lmax, + CF=cf, + HAS_GATE=True, + V_EDGE_MAJOR=False, + APPLY_ALPHA=False, + SAVE_Z=True, + BLOCK_M=bm, + BLOCK_N=bn, + BLOCK_K=bk, + num_warps=warps, + num_stages=stages, + ) + + def launch_bwd( + is_m1: bool, bm: int, bn: int, bk: int, warps: int, stages: int + ) -> None: + width = m1 if is_m1 else m0 + wh, wl = (w1th, w1tl) if is_m1 else (w0th, w0tl) + wrap_triton(_stack_fp16x3_bwd_kernel)[ + (triton.cdiv(n_edge, bm) * triton.cdiv(width, bn), n_focus) + ]( + gz, + g_res, + wh, + wl, + alpha, + gu, + n_edge, + 0, + L=lmax, + CF=cf, + IS_M1=is_m1, + G_EDGE_MAJOR=False, + FOLD_ALPHA=False, + RES_IS_GZ=False, + BLOCK_M=bm, + BLOCK_N=bn, + BLOCK_K=bk, + num_warps=warps, + num_stages=stages, + ) + + launchers = { + 0: launch_m0, + 1: launch_m1, + 2: lambda *cfg: launch_bwd(False, *cfg), + 3: lambda *cfg: launch_bwd(True, *cfg), + } + names = ("forward m0", "forward |m|=1", "backward m0", "backward |m|=1") + + # === Step 1. Rank every candidate per kernel by standalone time === + ranked: dict[int, list[tuple[float, tuple]]] = {} + for slot, launch in launchers.items(): + results = [] + for cfg in _FP16X3_GEMM_CANDIDATES: + try: + ms = _bench(lambda: launch(*cfg), iters=10) + except triton.runtime.errors.OutOfResources: + continue + results.append((ms, cfg)) + results.sort() + ranked[slot] = results + + # === Step 2. fp64 whole-op reference and validation harness === + n_check = min(_FP16X3_CHECK_EDGES, n_edge) + truth_x, _ = _mixing_stack_reference( + u0[:, :n_check].double(), + alpha[:n_check].double(), + w0_all.double(), + w1_all.double(), + gw_all.double(), + lmax, + cf, + True, + ) + truth_x = truth_x.float() + u0_ref = u0[:, :n_check].double().requires_grad_(True) + alpha_ref = alpha[:n_check].double().requires_grad_(True) + x_ref, _ = _mixing_stack_reference( + u0_ref, + alpha_ref, + w0_all.double(), + w1_all.double(), + gw_all.double(), + lmax, + cf, + True, + ) + grad_seed = torch.randn(n_check, n_focus, row, device=device) + gu_ref, _ = torch.autograd.grad(x_ref, [u0_ref, alpha_ref], grad_seed.double()) + gu_ref = gu_ref.float() + grad_pad = torch.zeros(n_edge, n_focus, row, device=device, dtype=torch.float32) + grad_pad[:n_check] = grad_seed + + # Pre-built inputs for the cross-edge-count finiteness check (see + # ``_FP16X3_FINITE_EDGES``). Weights are edge-count independent, so only + # the activation and the backward seed are rebuilt per count; counts at or + # above the main sweep count are covered by ``whole_op_errors`` and skipped. + finite_edges = [m for m in _FP16X3_FINITE_EDGES if m < n_edge] + finite_inputs = [ + ( + torch.randn(n_focus, m, row, device=device), + torch.rand(m, n_focus, device=device, dtype=torch.float32) + 0.1, + torch.randn(m, n_focus, row, device=device), + ) + for m in finite_edges + ] + + def whole_op_errors() -> tuple[float, float, bool]: + u0_run = u0.clone().requires_grad_(True) + alpha_run = alpha.clone().requires_grad_(True) + x_run, z_run = _mixing_stack_fp16x3_op( + u0_run, alpha_run, w0_all, w1_all, gw_all, lmax, cf, True + ) + finite = bool(torch.isfinite(x_run).all()) and bool(torch.isfinite(z_run).all()) + gu_run, _ = torch.autograd.grad(x_run, [u0_run, alpha_run], grad_pad) + finite = finite and bool(torch.isfinite(gu_run).all()) + return ( + _relerr(x_run[:n_check], truth_x), + _relerr(gu_run[:, :n_check], gu_ref), + finite, + ) + + def finite_across_edges() -> bool: + """Forward + backward finiteness of the installed config at the spread.""" + for u0_m, alpha_m, grad_m in finite_inputs: + u0_r = u0_m.clone().requires_grad_(True) + alpha_r = alpha_m.clone().requires_grad_(True) + x_m, z_m = _mixing_stack_fp16x3_op( + u0_r, alpha_r, w0_all, w1_all, gw_all, lmax, cf, True + ) + if not ( + bool(torch.isfinite(x_m).all()) and bool(torch.isfinite(z_m).all()) + ): + return False + gu_m, _ = torch.autograd.grad(x_m, [u0_r, alpha_r], grad_m) + if not bool(torch.isfinite(gu_m).all()): + return False + return True + + u0_fp32 = u0.clone().requires_grad_(True) + alpha_fp32 = alpha.clone().requires_grad_(True) + x_fp32, _ = _mixing_stack_op( + u0_fp32, alpha_fp32, w0_all, w1_all, gw_all, lmax, cf, True + ) + fp32_fwd_err = _relerr(x_fp32[:n_check], truth_x) + gu_fp32, _ = torch.autograd.grad(x_fp32, [u0_fp32, alpha_fp32], grad_pad) + fp32_bwd_err = _relerr(gu_fp32[:, :n_check], gu_ref) + tol_fwd = max(_FP16X3_TOL_FACTOR * fp32_fwd_err, 2e-6) + tol_bwd = max(_FP16X3_TOL_FACTOR * fp32_bwd_err, 8e-6) + print( + f"[fp32 reference errors: fwd {fp32_fwd_err:.2e}, bwd {fp32_bwd_err:.2e}; " + f"tolerances fwd {tol_fwd:.2e}, bwd {tol_bwd:.2e}]" + ) + + def validate() -> bool: + err_fwd, err_bwd, finite = whole_op_errors() + if not (finite and err_fwd < tol_fwd and err_bwd < tol_bwd): + return False + # Accuracy and main-count finiteness hold; require finiteness across + # the intermediate edge-count spread before accepting the config. + return finite_across_edges() + + key = (cf, lmax) + + def install(entry: tuple | None) -> None: + register_tile_configs("stack_fp16x3", {key: entry}) + + def conclude() -> tuple | None: + """Validate candidates through the live table; return the winner.""" + # === Step 3. Find a pin combination that validates as a whole === + pin = None + for candidate in _FP16X3_PIN_CANDIDATES: + install((candidate,) * 4) + if validate(): + pin = candidate + break + if pin is None: + print( + f"BEST stack_fp16x3[{key}]: no pin combination validates; " + "the fp32 stack stays in charge" + ) + return None + print(f"[pin combination: {pin}]") + + # === Step 4. Per kernel, walk the speed ranking to the first + # candidate that validates with the other three pinned === + chosen: list[tuple] = [] + for slot in range(4): + picked = None + for rank, (ms, cfg) in enumerate(ranked[slot]): + entry = [pin] * 4 + entry[slot] = cfg + install(tuple(entry)) + if validate(): + picked = (ms, cfg, rank) + break + if picked is None: + print( + f"BEST stack_fp16x3[{key}]: {names[slot]} has no " + "validating candidate; the fp32 stack stays in charge" + ) + return None + chosen.append(picked) + print( + f" {names[slot]}: rank #{picked[2]} {picked[1]} @ {picked[0]:.3f} ms" + ) + + # === Step 5. Joint validation; fall back to the pin on failure === + install(tuple(cfg for _ms, cfg, _rank in chosen)) + if not validate(): + print("[joint validation failed; falling back to the pin combination]") + install((pin,) * 4) + if not validate(): + print( + f"BEST stack_fp16x3[{key}]: pin combination regressed; " + "the fp32 stack stays in charge" + ) + return None + chosen = [(float("nan"), pin, -1)] * 4 + conclusion = tuple(cfg for _ms, cfg, _rank in chosen) + print(f"BEST stack_fp16x3[{key}] = {conclusion}") + return conclusion + + # Trial installs must never outlive the sweep. A completed run replaces + # them with its conclusion (a validated combination, or None recording + # that the fp32 stack won); an aborted run (OOM, compilation failure) + # restores the pre-sweep entry so the key does not read as swept. + runtime = _runtime_tile_configs("stack_fp16x3") + had_prior = key in runtime + prior = runtime.get(key) + try: + conclusion = conclude() + except BaseException: + if had_prior: + install(prior) + else: + runtime.pop(key, None) + raise + install(conclusion) + return {"stack_fp16x3": {key: conclusion}} + + +_SWEEPS: dict[str, Callable[..., SweepResult]] = { + "pointwise": sweep_pointwise, + "rotate_fwd": sweep_rotate_fwd, + "rotate_bwd": sweep_rotate_bwd, + "flash_bwd": sweep_flash_bwd, + "fp16x3": sweep_fp16x3, +} + +# Sentinel family per sweep group: the group has run for a key exactly when +# its sentinel family carries the key (``sweep_pointwise`` skips the +# recompute kernel in the bmm regime, so ``gate`` is the group sentinel). +_GROUP_SENTINELS = { + "pointwise": "gate", + "rotate_fwd": "rotate_mix_fwd", + "rotate_bwd": "rotate_mix_bwd_block", + "flash_bwd": "flash_bwd_block", + "fp16x3": "stack_fp16x3", +} + + +# ====================================================================== +# Model-driven tuning +# ====================================================================== +def collect_model_shape_keys(model: torch.nn.Module) -> list[tuple[int, int, int, int]]: + """Collect the shape keys of every fused-value-path convolution in ``model``. + + Parameters + ---------- + model : torch.nn.Module + A constructed SeZM model (weights need not be loaded; the shapes are + fixed by the hyperparameters). + + Returns + ------- + list[tuple[int, int, int, int]] + Deduplicated ``(focus_dim, lmax, n_focus, n_head)`` tuples, one per + distinct convolution shape whose layout is supported by the fused + value path. Convolutions outside the supported layout never query + the tables and contribute no keys. + """ + from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + SO2Convolution, + ) + + from .so2_value_path import ( + _is_supported, + ) + + keys: list[tuple[int, int, int, int]] = [] + for module in model.modules(): + if not isinstance(module, SO2Convolution) or not _is_supported(module): + continue + n_head = module.n_atten_head if module.n_atten_head > 0 else 1 + key = (module.so2_focus_dim, module.lmax, module.n_focus, n_head) + if key not in keys: + keys.append(key) + return keys + + +def tune_missing_configs( + shape_keys: list[tuple[int, int, int, int]], + *, + level: int, + device: torch.device | str = "cuda", + n_edge: int | None = None, +) -> SweepResult: + """Sweep and register every table key the given shapes need but lack. + + For each ``(focus_dim, lmax, n_focus, n_head)`` shape, every sweep group + relevant at the given ``DP_TRITON_INFER`` level (the table-configured + kernels at level 2, plus the fp16x3 stack at level 3) is checked against + :func:`~.tile_configs.has_tile_config` and run only when its sentinel + key is absent from both the built-in tables and the current process + registrations. Results -- including "the default won" conclusions -- + are registered for the current process, so a subsequent trace bakes + them and a subsequent call skips them. + + Parameters + ---------- + shape_keys : list[tuple[int, int, int, int]] + Shapes as returned by :func:`collect_model_shape_keys`. + level : int + The active Triton inference level; levels below 2 tune nothing. + device : torch.device or str + CUDA device to sweep on. + n_edge : int or None + Edge count of the synthetic sweep tensors; ``None`` selects a + width-scaled saturating count per kernel family. + + Returns + ------- + SweepResult + All newly registered entries, empty when every key was covered. + """ + registered: SweepResult = {} + if level < 2: + return registered + for cf, lmax, n_focus, n_head in shape_keys: + c_wide = n_focus * cf + pending: list[tuple[str, dict[str, Any]]] = [] + for group, sentinel in _GROUP_SENTINELS.items(): + if group == "fp16x3" and level < 3: + continue + key = (cf, lmax) if group in ("pointwise", "fp16x3") else (c_wide, lmax) + if has_tile_config(sentinel, key): + continue + kwargs: dict[str, Any] = { + "n_focus": n_focus, + "n_edge": n_edge, + "device": device, + } + if group == "flash_bwd": + kwargs["n_head"] = n_head + pending.append((group, kwargs)) + if not pending: + continue + log.info( + "Tuning SeZM Triton launch configurations for shape " + "(focus_dim=%d, lmax=%d, n_focus=%d) on %s: groups %s. " + "This runs once per uncovered shape; expect minutes on narrow " + "widths and up to tens of minutes on wide channels at high lmax " + "(kernel compilation dominates).", + cf, + lmax, + n_focus, + torch.cuda.get_device_name(torch.device(device)), + [group for group, _ in pending], + ) + for group, kwargs in pending: + result = _SWEEPS[group](cf, lmax, **kwargs) + for family, entries in result.items(): + register_tile_configs(family, entries) + registered.setdefault(family, {}).update(entries) + # Synthetic sweep tensors approach a third of the device memory; + # returning them to the driver between groups keeps consecutive + # sweeps (and the model being frozen) from fragmenting into OOM. + torch.cuda.empty_cache() + return registered + + +def _format_data_fragment(entries: SweepResult) -> str: + """Render registered entries as a :mod:`.tile_config_data` fragment.""" + if not torch.cuda.is_available(): + return "" + device_name = torch.cuda.get_device_name(torch.cuda.current_device()) + lines = [f' "{device_name}": {{'] + for family in sorted(entries): + lines.append(f' "{family}": {{') + for key in sorted(entries[family]): + lines.append(f" {key}: {entries[family][key]},") + lines.append(" },") + lines.append(" },") + return "\n".join(lines) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Sweep launch configurations of the shape-tuned SeZM " + "Triton kernels, either for one explicit (cf, lmax) key or for " + "every key a checkpoint needs." + ) + parser.add_argument("--cf", type=int, help="per-focus width Cf") + parser.add_argument("--lmax", type=int, help="maximum degree") + parser.add_argument("--model", help="checkpoint path; sweeps its missing keys") + parser.add_argument( + "--level", + type=int, + default=3, + help="tune for this DP_TRITON_INFER level (with --model)", + ) + parser.add_argument("--focus", type=int, default=2, help="focus count F") + parser.add_argument("--heads", type=int, default=1, help="attention head count") + parser.add_argument( + "--edges", + type=int, + default=None, + help="edge count E (default: width-scaled saturating count)", + ) + parser.add_argument("--device", default="cuda", help="CUDA device string") + parser.add_argument( + "--kernels", + default=",".join(_SWEEPS), + help=f"comma list of sweep groups (with --cf/--lmax), from {sorted(_SWEEPS)}", + ) + args = parser.parse_args() + + torch.manual_seed(0) + torch.backends.cuda.matmul.allow_tf32 = False + device = torch.device(args.device) + if device.type == "cuda" and device.index is not None: + torch.cuda.set_device(device) + + if args.model is not None: + from deepmd.pt.entrypoints.freeze_pt2 import ( + _extract_state_and_params, + ) + from deepmd.pt.model.model import ( + get_model, + ) + + raw = torch.load(args.model, map_location="cpu", weights_only=False) + _, params = _extract_state_and_params(raw) + branch_params = ( + list(params["model_dict"].values()) if "model_dict" in params else [params] + ) + shape_keys = [] + for branch in branch_params: + for key in collect_model_shape_keys(get_model(branch)): + if key not in shape_keys: + shape_keys.append(key) + print(f"model shape keys (cf, lmax, F, H): {shape_keys}") + registered = tune_missing_configs( + shape_keys, level=args.level, device=device, n_edge=args.edges + ) + else: + if args.cf is None or args.lmax is None: + parser.error("either --model or both --cf and --lmax are required") + groups = [name.strip() for name in args.kernels.split(",") if name.strip()] + unknown = sorted(set(groups) - set(_SWEEPS)) + if unknown: + parser.error(f"unknown sweep groups: {unknown}") + print( + f"cf={args.cf} lmax={args.lmax} focus={args.focus} heads={args.heads} " + f"edges={args.edges}" + ) + registered = {} + for name in groups: + print(f"== {name} ==") + kwargs: dict[str, Any] = { + "n_focus": args.focus, + "n_edge": args.edges, + "device": device, + } + if name == "flash_bwd": + kwargs["n_head"] = args.heads + result = _SWEEPS[name](args.cf, args.lmax, **kwargs) + for family, entries in result.items(): + register_tile_configs(family, entries) + registered.setdefault(family, {}).update(entries) + + if registered: + print("\n# tile_config_data.py fragment (merge under the GPU name):") + print(_format_data_fragment(registered)) + else: + print("all requested keys are already covered; nothing swept") + + +if __name__ == "__main__": + main() diff --git a/deepmd/kernels/triton/sezm/tile_config_data.py b/deepmd/kernels/triton/sezm/tile_config_data.py new file mode 100644 index 0000000000..fa1ec073a6 --- /dev/null +++ b/deepmd/kernels/triton/sezm/tile_config_data.py @@ -0,0 +1,230 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Built-in launch-configuration data for the shape-tuned SeZM Triton kernels. + +This module is pure data: one nested mapping per GPU model, keyed by the +exact device name reported by :func:`torch.cuda.get_device_name`. The +query layer in :mod:`.tile_configs` selects the sub-mapping of the running +GPU and resolves individual keys; devices without an entry here fall back +to the conservative defaults of every kernel family (correct on any CUDA +device, merely not tuned). + +Entry semantics +--------------- +Every per-family table maps an exact shape key to either a launch +configuration tuple or ``None``: + +- a tuple is the winning configuration measured by the sweep; +- ``None`` records that the sweep ran and the tuned kernel did **not** beat + its baseline for this key (win-list families) or that the default + configuration itself won (default-keyed families) -- the fallback is the + measured optimum, not a guess; +- an absent key means the shape was never swept on this GPU. The freeze + auto-tuner (:func:`.sweep_tile_configs.tune_missing_configs`) treats only + absent keys as work. + +Key conventions and value layouts are documented in :mod:`.tile_configs`; +regeneration is documented in :mod:`.sweep_tile_configs`. All entries +below were swept at production edge counts (3e5 to 6.5e5 edges) with the +``(C_wide, lmax)``-keyed families measured at ``n_focus = 2``. +""" + +from __future__ import ( + annotations, +) + +__all__ = ["BUILTIN_TILE_CONFIGS"] + +# fmt: off +BUILTIN_TILE_CONFIGS: dict[ + str, dict[str, dict[tuple[int, int], tuple | None]] +] = { + "NVIDIA H20": { + # (Cf, lmax) -> (BLOCK_M, num_warps, num_stages) + "gate": { + (32, 1): (32, 4, 2), + (32, 2): (64, 4, 2), + (32, 3): (64, 4, 2), + (32, 4): (64, 4, 1), + (32, 5): (64, 4, 1), + (32, 6): (64, 4, 1), + (64, 1): (32, 16, 2), + (64, 2): (64, 8, 1), + (64, 3): (64, 8, 2), + (64, 4): (64, 8, 1), + (64, 5): (16, 8, 2), + (64, 6): (16, 8, 2), + (96, 1): (8, 4, 2), + (96, 2): (16, 8, 1), + (96, 3): (8, 8, 2), + (96, 4): (8, 8, 2), + (96, 5): (8, 8, 1), + (96, 6): (8, 8, 1), + (128, 1): (16, 16, 1), + (128, 2): (16, 16, 1), + (128, 3): (32, 16, 1), + (128, 4): (16, 16, 1), + (128, 5): (16, 16, 1), + (128, 6): (16, 16, 2), + }, + # (Cf, lmax) -> (BLOCK_M, num_warps, num_stages); keys with + # Cf >= GATE_BMM_MIN_FOCUS_DIM are structurally absent (the gate + # projection runs as a cuBLAS bmm there and the recompute kernel is + # never launched). + "recompute": { + (32, 1): (64, 4, 1), + (32, 2): (32, 4, 1), + (32, 3): (64, 4, 2), + (32, 4): (32, 4, 1), + (32, 5): (32, 4, 1), + (32, 6): (32, 4, 2), + (64, 1): (32, 4, 2), + (64, 2): (64, 8, 2), + (64, 3): (64, 8, 1), + (64, 4): (64, 8, 2), + (64, 5): (64, 8, 2), + (64, 6): (16, 8, 1), + }, + # (Cf, lmax) -> (BLOCK_M, num_warps, num_stages) + "point": { + (32, 1): (64, 8, 1), + (32, 2): (16, 4, 1), + (32, 3): (64, 8, 2), + (32, 4): (16, 4, 1), + (32, 5): (16, 4, 2), + (32, 6): (16, 4, 2), + (64, 1): (16, 4, 1), + (64, 2): (16, 8, 1), + (64, 3): (32, 8, 2), + (64, 4): (32, 8, 2), + (64, 5): (16, 8, 2), + (64, 6): (16, 8, 1), + (96, 1): (8, 4, 2), + (96, 2): (8, 8, 2), + (96, 3): (8, 8, 2), + (96, 4): (8, 8, 2), + (96, 5): (8, 8, 2), + (96, 6): (8, 8, 1), + (128, 1): (8, 8, 2), + (128, 2): (8, 8, 2), + (128, 3): (8, 8, 1), + (128, 4): (8, 8, 2), + (128, 5): (8, 8, 1), + (128, 6): (8, 8, 1), + }, + # (C_wide, lmax) -> (num_warps, num_stages); None records keys where + # the upstream default (2, 2) itself won the sweep. + "rotate_mix_fwd": { + (64, 1): (1, 2), + (64, 2): (1, 2), + (64, 3): (1, 2), + (64, 4): (1, 2), + (64, 5): (1, 2), + (64, 6): (1, 2), + (128, 1): (1, 2), + (128, 2): (1, 2), + (128, 3): (1, 2), + (128, 4): (2, 1), + (128, 5): None, + (128, 6): (1, 2), + (192, 1): (1, 1), + (192, 2): None, + (192, 3): (2, 1), + (192, 4): (1, 2), + (192, 5): None, + (192, 6): (1, 2), + (256, 1): (1, 1), + (256, 2): None, + (256, 3): (1, 1), + (256, 4): (1, 2), + (256, 5): (4, 1), + (256, 6): (1, 2), + }, + # (C_wide, lmax) -> (BLOCK_E, num_warps, num_stages); win list + # against the per-edge kernel, None keeps the per-edge kernel. + "flash_bwd_block": { + (64, 1): (4, 2, 1), + (64, 2): (4, 2, 1), + (64, 3): (4, 2, 2), + (64, 4): (4, 2, 2), + (64, 5): (4, 2, 1), + (64, 6): (4, 2, 1), + (128, 1): None, + (128, 2): (2, 2, 1), + (128, 3): None, + (128, 4): None, + (128, 5): (2, 2, 1), + (128, 6): None, + (192, 1): None, + (192, 2): None, + (192, 3): None, + (192, 4): None, + (192, 5): None, + (192, 6): None, + (256, 1): None, + (256, 2): None, + (256, 3): None, + (256, 4): None, + (256, 5): None, + (256, 6): None, + }, + # (C_wide, lmax) -> (BLOCK_E, num_warps, num_stages); win list + # against the per-edge kernel, None keeps the per-edge kernel. + "rotate_mix_bwd_block": { + (64, 1): (8, 2, 1), + (64, 2): (8, 4, 1), + (64, 3): (4, 2, 2), + (64, 4): (4, 2, 1), + (64, 5): (4, 2, 2), + (64, 6): (4, 2, 1), + (128, 1): None, + (128, 2): None, + (128, 3): None, + (128, 4): (4, 4, 1), + (128, 5): (2, 2, 1), + (128, 6): (2, 2, 1), + (192, 1): None, + (192, 2): None, + (192, 3): None, + (192, 4): None, + (192, 5): None, + (192, 6): None, + (256, 1): None, + (256, 2): None, + (256, 3): None, + (256, 4): None, + (256, 5): None, + (256, 6): None, + }, + # (Cf, lmax) -> four (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, + # num_stages) GEMM configurations in the order (forward m0, + # forward |m|=1, backward m0, backward |m|=1). Every tuple entry + # passed the fp64 exactness sweep; None would keep the fp32 stack. + "stack_fp16x3": { + (32, 1): ((128, 64, 64, 4, 1), (64, 64, 32, 4, 1), (128, 64, 32, 8, 1), (64, 64, 64, 4, 1)), + (32, 2): ((64, 64, 32, 4, 3), (64, 64, 32, 4, 3), (64, 64, 32, 4, 3), (64, 64, 32, 4, 3)), + (32, 3): ((64, 64, 32, 4, 3), (64, 64, 32, 4, 1), (64, 64, 32, 4, 3), (128, 64, 32, 8, 1)), + (32, 4): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1)), + (32, 5): ((128, 64, 32, 8, 1), (64, 64, 32, 4, 1), (64, 64, 64, 4, 1), (64, 64, 64, 4, 1)), + (32, 6): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1)), + (64, 1): ((64, 64, 32, 4, 3), (64, 64, 32, 4, 3), (64, 64, 32, 4, 3), (64, 64, 32, 4, 3)), + (64, 2): ((128, 64, 32, 8, 1), (64, 64, 32, 4, 1), (128, 64, 32, 8, 1), (64, 64, 32, 4, 1)), + (64, 3): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1)), + (64, 4): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1)), + (64, 5): ((64, 128, 64, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1), (64, 128, 64, 4, 1)), + (64, 6): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1)), + (96, 1): ((128, 64, 32, 8, 1), (64, 64, 32, 4, 1), (128, 64, 32, 8, 1), (128, 64, 32, 8, 1)), + (96, 2): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1)), + (96, 3): ((64, 128, 64, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1), (64, 64, 32, 4, 1)), + (96, 4): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1)), + (96, 5): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1)), + (96, 6): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1)), + (128, 1): ((64, 128, 64, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1), (64, 128, 64, 4, 1)), + (128, 2): ((64, 128, 64, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1), (64, 64, 32, 4, 1)), + (128, 3): ((64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1)), + (128, 4): ((64, 128, 64, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1), (64, 64, 32, 4, 1)), + (128, 5): ((64, 128, 64, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1), (64, 128, 64, 4, 1)), + (128, 6): ((64, 128, 64, 4, 1), (64, 64, 32, 4, 1), (64, 128, 64, 4, 1), (64, 128, 64, 4, 1)), + }, + }, +} +# fmt: on diff --git a/deepmd/kernels/triton/sezm/tile_configs.py b/deepmd/kernels/triton/sezm/tile_configs.py new file mode 100644 index 0000000000..2167f71862 --- /dev/null +++ b/deepmd/kernels/triton/sezm/tile_configs.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +r"""Launch-configuration lookup for the shape-tuned SeZM Triton kernels. + +Configurations are resolved through two layers: + +1. *Built-in tables* (:mod:`.tile_config_data`), keyed by the exact GPU name + reported by :func:`torch.cuda.get_device_name`. These ship with the + package and hold the sweep results for the GPUs the maintainers have + tuned; a device without a built-in table resolves every key to the + conservative default of its kernel family (correct on any CUDA device, + merely not tuned). +2. *Runtime registrations* (:func:`register_tile_configs`), which take + precedence over the built-in tables in the current process. The freeze + auto-tuner (:func:`.sweep_tile_configs.tune_missing_configs`) sweeps the + shape keys of the checkpoint being frozen on the local GPU and registers + the winners here, so the traced ``.pt2`` bakes tuned launches even on + devices without built-in coverage. Registrations are process-local by + design: a ``.pt2`` is not portable across GPU models, so its tuning does + not need to be either. + +Two shape-key conventions are used: + +- ``(focus_dim, lmax)`` for kernels whose register pressure is per focus + stream (the value-path pointwise kernels and the fp16x3 stack GEMMs); + entries are valid for any focus count ``F``. +- ``(C_wide, lmax)`` with ``C_wide = n_focus * focus_dim`` for kernels that + vectorize over the full hidden width (the rotate+mix kernels and the + edge-block flash-attention backward). + +Fallback behaviour on an unresolved key depends on the family: + +- ``gate`` / ``recompute`` / ``point`` fall back to a spill-safe + configuration of the same kernel and ``rotate_mix_fwd`` to the upstream + default: tile choices never affect numerical results (they change the + schedule, not any reduction order), and the conservative end degrades + gracefully. +- ``flash_bwd_block`` and ``rotate_mix_bwd_block`` are win lists: a key + resolves to a configuration only where the edge-block schedule beat the + per-edge kernel by at least 3% in the sweep, and anything else keeps the + per-edge kernel. The edge-block schedule wins on narrow hidden widths + (large per-edge cross-lane reduction overhead) and loses badly on wide + ones (register-tile pressure), so the win list is the routing criterion, + not merely a tuning hint. +- ``stack_fp16x3`` is a validated win list: every entry passed the fp64 + exactness sweep for the exact kernel binary it launches, and an + unresolved key keeps the fp32 mixing stack. These entries are + load-bearing for correctness -- some ``(num_warps, num_stages)`` + combinations of the three-``tl.dot`` k-loop are miscompiled by the Triton + software pipeliner into silent NaN rows at production edge counts, and + the affected set shifts with any change to the kernel body. Never edit + an fp16x3 entry by hand; always regenerate it through the sweep. + +Register-pressure guidance +-------------------------- +The winning ``BLOCK_M`` of the pointwise kernels shrinks monotonically as +the register-pressure product ``lmax * next_power_of_2(Cf)`` grows (wide +64-row tiles for ``Cf = 32``, narrow 8..16-row tiles at ``Cf >= 96``); a +candidate on the wrong side of the spill point can be an order of magnitude +slower, which is why the tables are exact-keyed rather than heuristic. The +same product governs the edge-block backward kernels through their +``(BLOCK_E, C_wide)`` register tiles. + +Wide-channel regime +------------------- +At ``Cf >= GATE_BMM_MIN_FOCUS_DIM`` the per-group ``CP x CP`` register dot of +the gate forward/backward spills regardless of the tile choice (a padded 96 +behaves like 128). In that regime the sigmoid projection and the gate-logit +contraction run as cuBLAS batched matmuls and the Triton kernels keep only +the pointwise work, so ``gate`` entries for those keys were swept with the +projection disabled and ``recompute`` entries do not exist. +""" + +from __future__ import ( + annotations, +) + +import functools + +import torch + +from .tile_config_data import ( + BUILTIN_TILE_CONFIGS, +) + +__all__ = [ + "GATE_BMM_MIN_FOCUS_DIM", + "TILE_CONFIG_FAMILIES", + "flash_bwd_block_config", + "gate_config", + "has_tile_config", + "point_config", + "recompute_config", + "register_tile_configs", + "rotate_mix_bwd_block_config", + "rotate_mix_fwd_config", + "stack_fp16x3_configs", +] + +# Per-focus channel width at or above which the gate sigmoid projection and +# the gate-logit contraction are delegated to cuBLAS batched matmuls. +GATE_BMM_MIN_FOCUS_DIM = 96 + +TILE_CONFIG_FAMILIES = ( + "gate", + "recompute", + "point", + "rotate_mix_fwd", + "flash_bwd_block", + "rotate_mix_bwd_block", + "stack_fp16x3", +) + +_POINTWISE_FALLBACK = (16, 8, 2) +_ROTATE_MIX_FWD_DEFAULT = (2, 2) + +# Runtime registrations, highest lookup precedence. Populated by the freeze +# auto-tuner and by manual sweep runs in the same process. +_RUNTIME: dict[str, dict[tuple[int, int], tuple | None]] = { + family: {} for family in TILE_CONFIG_FAMILIES +} + + +@functools.cache +def _builtin_tables() -> dict[str, dict[tuple[int, int], tuple | None]]: + """Return the built-in tables of the running GPU (empty when untuned).""" + if not torch.cuda.is_available(): + return {} + device_name = torch.cuda.get_device_name(torch.cuda.current_device()) + return BUILTIN_TILE_CONFIGS.get(device_name, {}) + + +def _lookup(family: str, key: tuple[int, int]) -> tuple | None: + """Resolve ``key`` through the runtime and built-in layers. + + A ``None`` result folds together an explicit ``None`` entry (the sweep + ran and the family default is the measured optimum) and an absent key + (never swept on this GPU): the caller behaves identically in both cases. + """ + runtime = _RUNTIME[family] + if key in runtime: + return runtime[key] + return _builtin_tables().get(family, {}).get(key) + + +def _runtime_tile_configs(family: str) -> dict[tuple[int, int], tuple | None]: + """Return the mutable runtime table of ``family``. + + Internal accessor for the sweep (which must restore pre-sweep entries + when a run aborts) and for tests; regular callers register through + :func:`register_tile_configs` only. + """ + if family not in TILE_CONFIG_FAMILIES: + raise ValueError( + f"unknown tile-config family {family!r}; expected one of " + f"{TILE_CONFIG_FAMILIES}" + ) + return _RUNTIME[family] + + +def register_tile_configs( + family: str, entries: dict[tuple[int, int], tuple | None] +) -> None: + """Register swept launch configurations for the current process. + + Registered entries take precedence over the built-in tables and feed the + same lookup functions, so a registration made before model construction + is picked up by the construction-time operator bindings and baked into + any subsequent trace. + + Parameters + ---------- + family : str + One of :data:`TILE_CONFIG_FAMILIES`. + entries : dict[tuple[int, int], tuple or None] + Shape keys mapped to the winning configuration, or to ``None`` to + record that the sweep ran and the family default is the measured + optimum for that key. + + Raises + ------ + ValueError + If ``family`` is not a known kernel family. + """ + if family not in TILE_CONFIG_FAMILIES: + raise ValueError( + f"unknown tile-config family {family!r}; expected one of " + f"{TILE_CONFIG_FAMILIES}" + ) + _RUNTIME[family].update(entries) + + +def has_tile_config(family: str, key: tuple[int, int]) -> bool: + """Return whether ``key`` has been swept on this GPU. + + An explicit ``None`` entry counts as swept (the default configuration is + the measured optimum); only keys absent from both the runtime and the + built-in layer report ``False``. The freeze auto-tuner uses this to + decide which keys still need work. + """ + if family not in TILE_CONFIG_FAMILIES: + raise ValueError( + f"unknown tile-config family {family!r}; expected one of " + f"{TILE_CONFIG_FAMILIES}" + ) + return key in _RUNTIME[family] or key in _builtin_tables().get(family, {}) + + +def gate_config(focus_dim: int, lmax: int) -> tuple[int, int, int]: + """Return ``(BLOCK_M, num_warps, num_stages)`` for the gate forward kernel. + + Parameters + ---------- + focus_dim : int + Per-focus channel width ``Cf``. + lmax : int + Maximum spherical harmonic degree. + + Returns + ------- + tuple[int, int, int] + The swept launch configuration, or the spill-safe fallback for + unresolved keys. + """ + return _lookup("gate", (focus_dim, lmax)) or _POINTWISE_FALLBACK + + +def recompute_config(focus_dim: int, lmax: int) -> tuple[int, int, int]: + """Return ``(BLOCK_M, num_warps, num_stages)`` for the gate recompute kernel. + + Parameters + ---------- + focus_dim : int + Per-focus channel width ``Cf``. + lmax : int + Maximum spherical harmonic degree. + + Returns + ------- + tuple[int, int, int] + The swept launch configuration, or the spill-safe fallback for + unresolved keys. + """ + return _lookup("recompute", (focus_dim, lmax)) or _POINTWISE_FALLBACK + + +def point_config(focus_dim: int, lmax: int) -> tuple[int, int, int]: + """Return ``(BLOCK_M, num_warps, num_stages)`` for the backward pointwise kernel. + + Parameters + ---------- + focus_dim : int + Per-focus channel width ``Cf``. + lmax : int + Maximum spherical harmonic degree. + + Returns + ------- + tuple[int, int, int] + The swept launch configuration, or the spill-safe fallback for + unresolved keys. + """ + return _lookup("point", (focus_dim, lmax)) or _POINTWISE_FALLBACK + + +def rotate_mix_fwd_config(c_wide: int, lmax: int) -> tuple[int, int]: + """Return ``(num_warps, num_stages)`` for the rotate+mix forward kernel. + + Parameters + ---------- + c_wide : int + Full hidden width ``n_focus * focus_dim``. + lmax : int + Maximum spherical harmonic degree. + + Returns + ------- + tuple[int, int] + The swept launch configuration, or the upstream default ``(2, 2)`` + for unresolved keys. + """ + return _lookup("rotate_mix_fwd", (c_wide, lmax)) or _ROTATE_MIX_FWD_DEFAULT + + +def flash_bwd_block_config(c_wide: int, lmax: int) -> tuple[int, int, int] | None: + """Return the edge-block flash-attention backward config, or ``None``. + + Parameters + ---------- + c_wide : int + Full hidden width ``n_focus * focus_dim``. + lmax : int + Maximum spherical harmonic degree. + + Returns + ------- + tuple[int, int, int] or None + ``(BLOCK_E, num_warps, num_stages)`` when the edge-block schedule won + the sweep for this key; ``None`` keeps the per-edge kernel. + """ + return _lookup("flash_bwd_block", (c_wide, lmax)) + + +def rotate_mix_bwd_block_config(c_wide: int, lmax: int) -> tuple[int, int, int] | None: + """Return the edge-block rotate+mix backward config, or ``None``. + + Parameters + ---------- + c_wide : int + Full hidden width ``n_focus * focus_dim``. + lmax : int + Maximum spherical harmonic degree. + + Returns + ------- + tuple[int, int, int] or None + ``(BLOCK_E, num_warps, num_stages)`` when the edge-block schedule won + the sweep for this key; ``None`` keeps the per-edge kernel. + """ + return _lookup("rotate_mix_bwd_block", (c_wide, lmax)) + + +def stack_fp16x3_configs( + focus_dim: int, lmax: int +) -> ( + tuple[ + tuple[int, int, int, int, int], + tuple[int, int, int, int, int], + tuple[int, int, int, int, int], + tuple[int, int, int, int, int], + ] + | None +): + """Return the validated fp16x3 stack GEMM configs, or ``None``. + + Parameters + ---------- + focus_dim : int + Per-focus channel width ``Cf``. + lmax : int + Maximum spherical harmonic degree. + + Returns + ------- + tuple or None + The four ``(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages)`` + configurations in the order (forward m0, forward |m|=1, backward m0, + backward |m|=1) when the key passed the fp64 validation sweep; + ``None`` keeps the fp32 mixing stack. There is deliberately no + fallback configuration: an unvalidated configuration may be + miscompiled into silent NaN (see the module docstring). + """ + return _lookup("stack_fp16x3", (focus_dim, lmax)) diff --git a/deepmd/kernels/triton/sezm/wigner_monomials.py b/deepmd/kernels/triton/sezm/wigner_monomials.py new file mode 100644 index 0000000000..a7b3c82e2e --- /dev/null +++ b/deepmd/kernels/triton/sezm/wigner_monomials.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN202, RUF005 +"""Quaternion monomial design matrices with compile-time exponent tables. + +The Wigner-D construction for degrees ``l >= 2`` evaluates, per edge, a fixed +monomial basis of the unit quaternion + + ``M[e, m] = q0^a_m * q1^b_m * q2^c_m * q3^d_m``, + +with ``a_m + b_m + c_m + d_m`` equal to the kernel degree, followed by one +matrix multiply against a precomputed coefficient table. The reference chain +(power table, ``gather``, ``prod``) materializes three ``(4, P + 1, E)`` +intermediates per degree kernel, and its ``prod`` backward lowers to a +``cumprod`` scan pair -- several milliseconds per model call at typical edge +counts. Here the exponent table is a compile-time constant: the kernel +builds the four scalar power ladders in registers and emits every monomial +(and, in the backward, its four leave-one-out derivatives +``d M_m / d q_i = e_i * q_i^{e_i - 1} * prod_{j != i} q_j^{e_j}``) as an +unrolled register product. No intermediate ever touches DRAM. + +The operator is functional (``mutates_args=()``) with a fake kernel and an +autograd formula whose backward is itself a ``triton_op``, so it composes +with the SeZM ``make_fx`` lowering and the AOTInductor freeze exactly like +the other ``sezm_triton`` operators. The exponent table is passed as a +Python ``list[int]`` and must be extracted from the coefficient buffers in +eager context (module construction), never at trace time: a trace-time +``.tolist()`` on a tensor creates unbacked symbols and aborts export. +""" + +from __future__ import ( + annotations, +) + +import torch +from torch import ( + Tensor, +) +from torch.library import ( + wrap_triton, +) + +__all__ = [ + "WIGNER_MONOMIALS_TRITON_AVAILABLE", + "wigner_monomials", +] + +try: + import triton + import triton.language as tl + + WIGNER_MONOMIALS_TRITON_AVAILABLE = True +except ImportError: # pragma: no cover - exercised only without triton + WIGNER_MONOMIALS_TRITON_AVAILABLE = False + +_BLOCK_EDGES = 256 + + +# ====================================================================== +# Eager reference / fallback implementations +# ====================================================================== +def _monomials_reference(q: Tensor, exponents: list[int], max_power: int) -> Tensor: + """Eager ground truth: explicit power ladder and per-monomial products.""" + n_mono = len(exponents) // 4 + powers = [torch.ones_like(q)] + for _ in range(max_power): + powers.append(powers[-1] * q) + table = torch.stack(powers, dim=1) # (E, max_power + 1, 4) + columns = [ + (table[:, exponents[4 * m + 0], 0] * table[:, exponents[4 * m + 1], 1]) + * (table[:, exponents[4 * m + 2], 2] * table[:, exponents[4 * m + 3], 3]) + for m in range(n_mono) + ] + return torch.stack(columns, dim=1) + + +def _monomials_backward_reference( + grad_out: Tensor, q: Tensor, exponents: list[int], max_power: int +) -> Tensor: + """Closed-form eager backward returning ``grad_q`` with shape (E, 4).""" + n_mono = len(exponents) // 4 + powers = [torch.ones_like(q)] + for _ in range(max_power): + powers.append(powers[-1] * q) + table = torch.stack(powers, dim=1) # (E, max_power + 1, 4) + grad_q = torch.zeros_like(q) + for m in range(n_mono): + e = exponents[4 * m : 4 * m + 4] + g = grad_out[:, m] + for i in range(4): + if e[i] == 0: + continue + partial = g * float(e[i]) * table[:, e[i] - 1, i] + for j in range(4): + if j != i: + partial = partial * table[:, e[j], j] + grad_q[:, i] += partial + return grad_q + + +# ====================================================================== +# Triton kernels +# ====================================================================== +if WIGNER_MONOMIALS_TRITON_AVAILABLE: + + @triton.jit + def _monomials_fwd_kernel( + q_ptr, # (E, 4) contiguous + out_ptr, # (E, M) + n_edge, + EXPS: tl.constexpr, # flat exponent tuple (a0, b0, c0, d0, a1, ...) + M: tl.constexpr, + MAXP: tl.constexpr, + BLOCK_M: tl.constexpr, + ): + """Register power ladders and fully unrolled monomial products.""" + pid = tl.program_id(0) + offs = (pid * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + mask = offs < n_edge + + q0 = tl.load(q_ptr + offs * 4 + 0, mask=mask, other=0.0) + q1 = tl.load(q_ptr + offs * 4 + 1, mask=mask, other=0.0) + q2 = tl.load(q_ptr + offs * 4 + 2, mask=mask, other=0.0) + q3 = tl.load(q_ptr + offs * 4 + 3, mask=mask, other=0.0) + + ones = tl.full((BLOCK_M,), 1.0, dtype=tl.float32) + p0 = (ones,) + p1 = (ones,) + p2 = (ones,) + p3 = (ones,) + for _ in tl.static_range(MAXP): + p0 = p0 + (p0[-1] * q0,) + p1 = p1 + (p1[-1] * q1,) + p2 = p2 + (p2[-1] * q2,) + p3 = p3 + (p3[-1] * q3,) + + # ``+ 0`` forces the tuple index to a constexpr expression, which the + # Triton frontend requires for subscripting loop-carried tuples. + for m in tl.static_range(M): + val = (p0[EXPS[4 * m + 0] + 0] * p1[EXPS[4 * m + 1] + 0]) * ( + p2[EXPS[4 * m + 2] + 0] * p3[EXPS[4 * m + 3] + 0] + ) + tl.store(out_ptr + offs * M + m, val, mask=mask) + + @triton.jit + def _monomials_bwd_kernel( + g_ptr, # (E, M) + q_ptr, # (E, 4) + gq_ptr, # (E, 4) + n_edge, + EXPS: tl.constexpr, + M: tl.constexpr, + MAXP: tl.constexpr, + BLOCK_M: tl.constexpr, + ): + """Analytic leave-one-out backward accumulated in registers.""" + pid = tl.program_id(0) + offs = (pid * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + mask = offs < n_edge + + q0 = tl.load(q_ptr + offs * 4 + 0, mask=mask, other=0.0) + q1 = tl.load(q_ptr + offs * 4 + 1, mask=mask, other=0.0) + q2 = tl.load(q_ptr + offs * 4 + 2, mask=mask, other=0.0) + q3 = tl.load(q_ptr + offs * 4 + 3, mask=mask, other=0.0) + + ones = tl.full((BLOCK_M,), 1.0, dtype=tl.float32) + p0 = (ones,) + p1 = (ones,) + p2 = (ones,) + p3 = (ones,) + for _ in tl.static_range(MAXP): + p0 = p0 + (p0[-1] * q0,) + p1 = p1 + (p1[-1] * q1,) + p2 = p2 + (p2[-1] * q2,) + p3 = p3 + (p3[-1] * q3,) + + g0 = tl.zeros((BLOCK_M,), dtype=tl.float32) + g1 = tl.zeros((BLOCK_M,), dtype=tl.float32) + g2 = tl.zeros((BLOCK_M,), dtype=tl.float32) + g3 = tl.zeros((BLOCK_M,), dtype=tl.float32) + for m in tl.static_range(M): + g = tl.load(g_ptr + offs * M + m, mask=mask, other=0.0) + if EXPS[4 * m + 0] > 0: + g0 += (g * (EXPS[4 * m + 0] + 0.0)) * ( + (p0[EXPS[4 * m + 0] - 1] * p1[EXPS[4 * m + 1] + 0]) + * (p2[EXPS[4 * m + 2] + 0] * p3[EXPS[4 * m + 3] + 0]) + ) + if EXPS[4 * m + 1] > 0: + g1 += (g * (EXPS[4 * m + 1] + 0.0)) * ( + (p0[EXPS[4 * m + 0] + 0] * p1[EXPS[4 * m + 1] - 1]) + * (p2[EXPS[4 * m + 2] + 0] * p3[EXPS[4 * m + 3] + 0]) + ) + if EXPS[4 * m + 2] > 0: + g2 += (g * (EXPS[4 * m + 2] + 0.0)) * ( + (p0[EXPS[4 * m + 0] + 0] * p1[EXPS[4 * m + 1] + 0]) + * (p2[EXPS[4 * m + 2] - 1] * p3[EXPS[4 * m + 3] + 0]) + ) + if EXPS[4 * m + 3] > 0: + g3 += (g * (EXPS[4 * m + 3] + 0.0)) * ( + (p0[EXPS[4 * m + 0] + 0] * p1[EXPS[4 * m + 1] + 0]) + * (p2[EXPS[4 * m + 2] + 0] * p3[EXPS[4 * m + 3] - 1]) + ) + + tl.store(gq_ptr + offs * 4 + 0, g0, mask=mask) + tl.store(gq_ptr + offs * 4 + 1, g1, mask=mask) + tl.store(gq_ptr + offs * 4 + 2, g2, mask=mask) + tl.store(gq_ptr + offs * 4 + 3, g3, mask=mask) + + +# ====================================================================== +# Dispatch, operator registration and public API +# ====================================================================== +def _use_triton(tensor: Tensor) -> bool: + return ( + WIGNER_MONOMIALS_TRITON_AVAILABLE + and tensor.is_cuda + and tensor.dtype is torch.float32 + ) + + +def _forward_impl(q: Tensor, exponents: list[int], max_power: int) -> Tensor: + if not _use_triton(q): + return _monomials_reference(q, exponents, int(max_power)) + n_edge = q.shape[0] + n_mono = len(exponents) // 4 + out = torch.empty((n_edge, n_mono), device=q.device, dtype=q.dtype) + if type(n_edge) is int and n_edge == 0: + return out + wrap_triton(_monomials_fwd_kernel)[(triton.cdiv(n_edge, _BLOCK_EDGES),)]( + q.contiguous(), + out, + n_edge, + EXPS=tuple(exponents), + M=n_mono, + MAXP=int(max_power), + BLOCK_M=_BLOCK_EDGES, + num_warps=4, + num_stages=2, + ) + return out + + +def _backward_impl( + grad_out: Tensor, q: Tensor, exponents: list[int], max_power: int +) -> Tensor: + if not _use_triton(q): + return _monomials_backward_reference(grad_out, q, exponents, int(max_power)) + n_edge = q.shape[0] + grad_q = torch.empty((n_edge, 4), device=q.device, dtype=q.dtype) + if type(n_edge) is int and n_edge == 0: + return grad_q + wrap_triton(_monomials_bwd_kernel)[(triton.cdiv(n_edge, _BLOCK_EDGES),)]( + grad_out.contiguous(), + q.contiguous(), + grad_q, + n_edge, + EXPS=tuple(exponents), + M=len(exponents) // 4, + MAXP=int(max_power), + BLOCK_M=_BLOCK_EDGES, + num_warps=4, + num_stages=2, + ) + return grad_q + + +_monomials_op = torch.library.triton_op( + "sezm_triton::wigner_monomials", mutates_args=() +)(_forward_impl) + +_monomials_bwd_op = torch.library.triton_op( + "sezm_triton::wigner_monomials_bwd", mutates_args=() +)(_backward_impl) + + +@_monomials_op.register_fake +def _(q, exponents, max_power): + return q.new_empty((q.shape[0], len(exponents) // 4)) + + +@_monomials_bwd_op.register_fake +def _(grad_out, q, exponents, max_power): + return q.new_empty((q.shape[0], 4)) + + +def _setup_context(ctx, inputs, output): + q, exponents, max_power = inputs + ctx.save_for_backward(q) + ctx.exponents = exponents + ctx.max_power = max_power + + +def _backward(ctx, grad_out): + (q,) = ctx.saved_tensors + grad_q = _monomials_bwd_op(grad_out.contiguous(), q, ctx.exponents, ctx.max_power) + return grad_q, None, None + + +_monomials_op.register_autograd(_backward, setup_context=_setup_context) + + +def wigner_monomials(q: Tensor, exponents: list[int], max_power: int) -> Tensor: + """Evaluate a fixed quaternion monomial basis per edge. + + Parameters + ---------- + q : Tensor + Unit quaternions with shape (E, 4). + exponents : list[int] + Flattened exponent table ``(a0, b0, c0, d0, a1, ...)`` with + ``4 * M`` entries; must be a Python list of compile-time constants + (extracted in eager context, never at trace time). + max_power : int + Largest exponent appearing in the table (the power-ladder depth). + + Returns + ------- + Tensor + Monomial design matrix with shape (E, M), where column ``m`` is + ``q0^a_m * q1^b_m * q2^c_m * q3^d_m``. + """ + return _monomials_op(q, exponents, int(max_power)) diff --git a/deepmd/kernels/utils.py b/deepmd/kernels/utils.py new file mode 100644 index 0000000000..08837e8fde --- /dev/null +++ b/deepmd/kernels/utils.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Environment-variable gates for the SeZM/DPA4 hardware-accelerated kernels. + +This module centralizes the opt-in selectors that route inference through the +custom Triton and CuTe kernel packages. The gates are read once at model +construction time so that they become compile-time constants in the traced +(``make_fx``) graph. +""" + +from __future__ import ( + annotations, +) + +import os + +_INFER_TRUE = ("1", "true", "yes", "on") + +TRITON_INFER_LEVELS = (0, 1, 2, 3) + + +def triton_infer_level() -> int: + """Return the opt-in Triton inference level from ``DP_TRITON_INFER``. + + The level is read at module construction time so that it becomes a + compile-time constant in the traced (``make_fx``) graph. It only takes + effect during inference; training always uses the dense reference path. + + Levels are cumulative: + + - ``0`` -- Triton disabled; every operation uses the dense reference path. + - ``1`` -- universal kernels that need no launch-configuration table: + block-diagonal rotation, radial degree mixing, the ``SO2Linear`` + block GEMM, Wigner monomials, flash-attention aggregation, and the + segmented force assembly. These are either runtime-autotuned or run a + single shape-independent configuration. + - ``2`` -- adds kernels whose launch configuration is resolved from the + swept ``(focus_dim, lmax)`` / ``(C_wide, lmax)`` tables in + :mod:`.triton.tile_configs`: the fused SO(2) value path and the + edge-block backward kernels. A key absent from a table falls back to + the level-1 kernel (or a spill-safe configuration) for that operation, + so unswept shapes never regress below level 1. + - ``3`` -- adds the fp16x3 split-compensated mixing-stack GEMMs on + tensor cores. Entries exist only for table keys whose configuration + passed the fp64 validation sweep; unswept shapes keep the level-2 fp32 + stack. This level trades a bounded accuracy perturbation for speed + (see :mod:`.triton.so2_stack_fp16x3`). + + Returns + ------- + int + The configured level in ``{0, 1, 2, 3}``. + + Raises + ------ + ValueError + If ``DP_TRITON_INFER`` is not an integer in ``{0, 1, 2, 3}``. + """ + raw = os.environ.get("DP_TRITON_INFER", "0").strip() + try: + level = int(raw) + except ValueError: + raise ValueError( + f"DP_TRITON_INFER must be an integer in {TRITON_INFER_LEVELS}, got {raw!r}" + ) from None + if level not in TRITON_INFER_LEVELS: + raise ValueError( + f"DP_TRITON_INFER must be one of {TRITON_INFER_LEVELS}, got {level}" + ) + return level + + +def use_cute_infer() -> bool: + """Return whether the opt-in CuTe inference operator is enabled. + + The flag is controlled by the ``DP_CUTE_INFER`` environment variable and is + read at module construction time. It selects the fused CuTe SO(2) value-path + operator (an independent path from ``DP_TRITON_INFER``) and only takes effect + during inference; training always uses the dense reference path. + + Returns + ------- + bool + ``True`` when ``DP_CUTE_INFER`` is set to a truthy value. + """ + return os.environ.get("DP_CUTE_INFER", "0").strip().lower() in _INFER_TRUE + + +def use_amp_infer() -> bool: + """Return whether bf16 autocast is enabled for inference. + + The flag is controlled by the ``DP_AMP_INFER`` environment variable and is + read at module construction time. It only affects inference when the + descriptor's ``use_amp`` option is also enabled; training follows + ``use_amp`` regardless of this environment variable. + + Returns + ------- + bool + ``True`` when ``DP_AMP_INFER`` is set to a truthy value. + """ + return os.environ.get("DP_AMP_INFER", "0").strip().lower() in _INFER_TRUE diff --git a/deepmd/pt/entrypoints/freeze_pt2.py b/deepmd/pt/entrypoints/freeze_pt2.py index 78daf3ef65..77988d7c91 100644 --- a/deepmd/pt/entrypoints/freeze_pt2.py +++ b/deepmd/pt/entrypoints/freeze_pt2.py @@ -46,7 +46,11 @@ from deepmd.dpmodel.utils.region import ( normalize_coord, ) +from deepmd.kernels.utils import ( + triton_infer_level, +) from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + SO2Convolution, SO2Linear, ) from deepmd.pt.model.model import ( @@ -270,7 +274,7 @@ def _collect_metadata( "ntypes": _get_model_ntypes(model), "rcut": float(model.get_rcut()), "sel": [int(s) for s in model.get_sel()], - "lower_input_kind": "nlist" if is_spin else "edge_vec", + "lower_input_kind": model.export_lower_input_kind(), "dim_fparam": int(model.get_dim_fparam()), "dim_aparam": int(model.get_dim_aparam()), "dim_chg_spin": int(model.get_dim_chg_spin()), @@ -295,6 +299,70 @@ def _collect_metadata( return metadata +def _tune_triton_configs(model: torch.nn.Module, target_device: torch.device) -> None: + """Tune the shape-keyed Triton launch tables for this checkpoint's shapes. + + At ``DP_TRITON_INFER >= 2`` the traced graph bakes launch configurations + resolved from the tables in ``deepmd.kernels.triton.sezm.tile_configs``. Shape keys + absent from the built-in tables (an untuned GPU model, or an untuned + width/degree) are swept here on the local GPU -- the exact hardware the + ``.pt2`` will run on, since AOTInductor artifacts are not portable across + GPU models -- and registered for the current process before tracing. + Keys already covered cost nothing. + + The fused value-path entries are then rebound: the mixing-stack operator + selection (fp32 versus fp16x3) is fixed at construction time, which + predates the registrations made here. + """ + if triton_infer_level() < 2: + return + if target_device.type != "cuda" or not torch.cuda.is_available(): + return + from deepmd.kernels.triton.sezm.so2_value_path import ( + SO2_VALUE_PATH_TRITON_AVAILABLE, + make_triton_value_path, + ) + + if not SO2_VALUE_PATH_TRITON_AVAILABLE: + return + from deepmd.kernels.triton.sezm.sweep_tile_configs import ( + collect_model_shape_keys, + tune_missing_configs, + ) + from deepmd.kernels.triton.sezm.tile_configs import ( + _builtin_tables, + ) + + # The built-in tables and the sweep both resolve against the current + # device; pin it to the AOTI target so a freeze aimed at a secondary GPU + # tunes and looks up the right hardware (mixed-model hosts). + if target_device.index is not None: + torch.cuda.set_device(target_device) + _builtin_tables.cache_clear() + + shape_keys = collect_model_shape_keys(model) + registered = tune_missing_configs( + shape_keys, level=triton_infer_level(), device=target_device + ) + if registered: + log.info( + "Registered freshly tuned Triton launch configurations: %s", + {family: sorted(entries) for family, entries in registered.items()}, + ) + else: + log.info( + "Triton launch tables already cover this checkpoint's shapes on %s; " + "no tuning needed.", + torch.cuda.get_device_name(target_device), + ) + # Rebind unconditionally: the fp32-versus-fp16x3 stack selection was made + # at construction time, possibly against a different current device's + # tables, and must reflect the target device and any fresh registrations. + for module in model.modules(): + if isinstance(module, SO2Convolution) and module.triton_infer_level >= 2: + module._triton_value_path = make_triton_value_path(module) + + # The trace-time sendlist for the with-comm artifact embeds the address of a # numpy array (``int**`` contract of ``border_op``). The array must outlive the # trace + export call; the exported graph never reads it at runtime (the op is @@ -406,8 +474,14 @@ def _make_sample_inputs( ) -> tuple[torch.Tensor | None, ...]: """Build representative ``forward_common_lower`` inputs for tracing. - The spin path returns the nlist lower signature; the energy path returns the - single-domain edge schema (folded ``edge_index``, extended scatter indices). + Three lower ABIs are produced, selected by ``model.export_lower_input_kind()`` + and whether the model carries spin: + + - virtual spin (``nlist``): the DeepSpin extended-input signature, since the + graph expands virtual atoms internally; + - native spin (``edge_vec``): the energy edge schema plus the owned-atom + spins (the first ``nloc`` extended rows, where ``mapping`` is identity); + - energy (``edge_vec``): the plain single-domain edge schema. """ ( ext_coord, @@ -419,7 +493,7 @@ def _make_sample_inputs( aparam, charge_spin, ) = _build_sample_extended(model, nframes, nloc, device, has_spin) - if has_spin: + if has_spin and model.export_lower_input_kind() == "nlist": return ( ext_coord, ext_atype, @@ -437,6 +511,19 @@ def _make_sample_inputs( formatted_nlist, mapping_t, ) + if has_spin: + return ( + edge_schema.coord, + edge_schema.atype, + edge_schema.edge_index, + edge_schema.edge_vec, + edge_schema.edge_scatter_index, + edge_schema.edge_mask, + ext_spin[:, :nloc], + fparam, + aparam, + charge_spin, + ) return ( edge_schema.coord, edge_schema.atype, @@ -491,19 +578,22 @@ def _make_comm_sample_inputs( The parallel path indexes the extended node set directly, so ``edge_index`` coincides with ``edge_scatter_index`` (both extended) and ghost features are refreshed via ``border_op`` rather than gathered through a folded mapping. - The frame axis is fixed at one, matching LAMMPS single-frame inference. + The frame axis is fixed at one, matching LAMMPS single-frame inference. The + native spin scheme threads the EXTENDED per-node spin (ghost spins ride the + same exchange), inserted after ``edge_mask`` to match its with-comm signature. """ + has_spin = _model_has_spin(model) ( ext_coord, ext_atype, nlist_t, mapping_t, - _ext_spin, + ext_spin, fparam, aparam, charge_spin, ) = _build_sample_extended( - model, nframes=1, nloc=nloc, device=device, has_spin=False + model, nframes=1, nloc=nloc, device=device, has_spin=has_spin ) formatted_nlist: torch.Tensor = model.format_nlist(ext_coord, ext_atype, nlist_t) edge_schema = edge_schema_from_extended( @@ -512,7 +602,7 @@ def _make_comm_sample_inputs( formatted_nlist, mapping_t, ) - return ( + edge_inputs = ( edge_schema.coord, # (1, nall, 3) edge_schema.atype, # (1, nloc) ext_atype, # (1, nall) @@ -520,11 +610,11 @@ def _make_comm_sample_inputs( edge_schema.edge_vec, edge_schema.edge_scatter_index, # edge_scatter_index: extended (2, E) edge_schema.edge_mask, - fparam, - aparam, - charge_spin, - *_make_edge_comm_tensors(mapping_t, nloc, device), ) + comm_tensors = _make_edge_comm_tensors(mapping_t, nloc, device) + if has_spin: + return (*edge_inputs, ext_spin, fparam, aparam, charge_spin, *comm_tensors) + return (*edge_inputs, fparam, aparam, charge_spin, *comm_tensors) def _resolve_nframes( @@ -570,17 +660,24 @@ def _resolve_nframes( def _build_dynamic_shapes( sample_inputs: tuple[torch.Tensor | None, ...], ) -> tuple: - """Build positional dynamic-shape constraints for the traced lower input.""" + """Build positional dynamic-shape constraints for the traced lower input. + + The lower ABI is recovered from the sample structure: a floating-point + tensor at index 2 is the extended spin of the deepspin-scheme nlist contract, + while an integer ``edge_index`` there marks the edge contract. A native-spin + edge sample carries the extra per-local-atom spin tensor, giving it ten + positional entries against the energy contract's nine. + """ nframes_dim = torch.export.Dim("nframes", min=1) - has_spin = ( - len(sample_inputs) >= 7 + nloc_dim = torch.export.Dim("nloc", min=1) + nedge_dim = torch.export.Dim("nedge", min=2) + is_nlist_spin = ( + len(sample_inputs) >= 3 and sample_inputs[2] is not None and sample_inputs[2].is_floating_point() ) - nall_dim = torch.export.Dim("nall", min=4 if has_spin else 1) - nloc_dim = torch.export.Dim("nloc", min=1) - nedge_dim = torch.export.Dim("nedge", min=2) - if has_spin: + if is_nlist_spin: + nall_dim = torch.export.Dim("nall", min=4) fparam = sample_inputs[5] aparam = sample_inputs[6] charge_spin = sample_inputs[7] if len(sample_inputs) == 8 else None @@ -596,16 +693,36 @@ def _build_dynamic_shapes( if len(sample_inputs) == 8: shapes = (*shapes, {0: nframes_dim} if charge_spin is not None else None) return shapes - fparam = sample_inputs[6] - aparam = sample_inputs[7] - charge_spin = sample_inputs[8] if len(sample_inputs) == 9 else None - shapes = ( + + nall_dim = torch.export.Dim("nall", min=1) + edge_shapes = ( {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) {0: nframes_dim, 1: nloc_dim}, # atype {1: nedge_dim}, # edge_index {0: nedge_dim}, # edge_vec {1: nedge_dim}, # edge_scatter_index {0: nedge_dim}, # edge_mask + ) + # Native-spin edge contract: extra per-local-atom spin leaf at index 6. + is_native_spin = len(sample_inputs) == 10 + if is_native_spin: + fparam, aparam, charge_spin = ( + sample_inputs[7], + sample_inputs[8], + sample_inputs[9], + ) + return ( + *edge_shapes, + {0: nframes_dim, 1: nloc_dim}, # spin: (nframes, nloc, 3) + {0: nframes_dim} if fparam is not None else None, + {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, + {0: nframes_dim} if charge_spin is not None else None, + ) + fparam = sample_inputs[6] + aparam = sample_inputs[7] + charge_spin = sample_inputs[8] if len(sample_inputs) == 9 else None + shapes = ( + *edge_shapes, {0: nframes_dim} if fparam is not None else None, {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, ) @@ -622,15 +739,14 @@ def _build_with_comm_dynamic_shapes( The frame axis is fixed at one (LAMMPS single-frame inference), so only ``nall``, ``nloc`` and ``nedge`` vary. The eight communication tensors are static: ``nswap`` is fixed at LAMMPS init and the graph carries no variation - across its value (``border_op`` is opaque to the exported program). + across its value (``border_op`` is opaque to the exported program). The + native spin contract inserts the extended (nall) spin after ``edge_mask``, + giving 19 positional entries against the energy contract's 18. """ nall_dim = torch.export.Dim("nall", min=1) nloc_dim = torch.export.Dim("nloc", min=1) nedge_dim = torch.export.Dim("nedge", min=2) - fparam = sample_inputs[7] - aparam = sample_inputs[8] - charge_spin = sample_inputs[9] - base = ( + edge_base = ( {1: nall_dim}, # coord: (1, nall, 3) {1: nloc_dim}, # atype: (1, nloc) {1: nall_dim}, # extended_atype: (1, nall) @@ -638,6 +754,27 @@ def _build_with_comm_dynamic_shapes( {0: nedge_dim}, # edge_vec: (nedge, 3) {1: nedge_dim}, # edge_scatter_index: (2, nedge) {0: nedge_dim}, # edge_mask: (nedge,) + ) + is_native_spin = len(sample_inputs) == 19 + if is_native_spin: + fparam, aparam, charge_spin = ( + sample_inputs[8], + sample_inputs[9], + sample_inputs[10], + ) + base = ( + *edge_base, + {1: nall_dim}, # spin: (1, nall, 3) + None if fparam is None else {}, # fparam: (1, ndf) static + None if aparam is None else {1: nloc_dim}, # aparam: (1, nloc, nda) + None if charge_spin is None else {}, # charge_spin: (1, nchg) static + ) + return (*base, *((None,) * 8)) + fparam = sample_inputs[7] + aparam = sample_inputs[8] + charge_spin = sample_inputs[9] + base = ( + *edge_base, None if fparam is None else {}, # fparam: (1, ndf) static None if aparam is None else {1: nloc_dim}, # aparam: (1, nloc, nda) None if charge_spin is None else {}, # charge_spin: (1, nchg) static @@ -750,6 +887,10 @@ def freeze_sezm_to_pt2( if isinstance(module, SO2Linear): module._force_block_diag_matmul = force_block_diag + # Sweep any Triton launch-table keys this checkpoint needs that are not + # covered for the local GPU, so the traced graph bakes tuned launches. + _tune_triton_configs(model, target_device) + _, sample_inputs_cpu = _resolve_nframes( model, nloc=7, @@ -757,50 +898,11 @@ def freeze_sezm_to_pt2( has_spin=is_spin, ) - if is_spin: - ( - ext_coord, - ext_atype, - ext_spin, - nlist_t, - mapping_t, - fparam, - aparam, - charge_spin, - ) = sample_inputs_cpu - traced = model.forward_common_lower_exportable( - ext_coord, - ext_atype, - ext_spin, - nlist_t, - mapping_t, - fparam=fparam, - aparam=aparam, - charge_spin=charge_spin, - ) - else: - ( - coord, - atype, - edge_index, - edge_vec, - edge_scatter_index, - edge_mask, - fparam, - aparam, - charge_spin, - ) = sample_inputs_cpu - traced = model.forward_common_lower_exportable( - coord, - atype, - edge_index, - edge_vec, - edge_scatter_index, - edge_mask, - fparam=fparam, - aparam=aparam, - charge_spin=charge_spin, - ) + # Each model's exportable signature matches its sample tuple positionally + # (energy / native-spin edge ABI, or virtual-spin nlist ABI), so a single + # splat covers all three contracts. + log.info("Tracing the lower graph on CPU (make_fx)...") + traced = model.forward_common_lower_exportable(*sample_inputs_cpu) # Output key order is taken from a concrete run; Python dict order # is stable and matches what DeepPotPTExpt::extract_outputs zips @@ -809,6 +911,7 @@ def freeze_sezm_to_pt2( sample_out = traced(*sample_inputs_cpu) output_keys = list(sample_out.keys()) + log.info("Exporting the traced graph (torch.export)...") exported = torch.export.export( traced, sample_inputs_cpu, @@ -828,21 +931,33 @@ def freeze_sezm_to_pt2( exported = move_to_device_pass(exported, target_device) out_path_str = str(out_path) - compile_options = build_inductor_compile_options() + compile_options = build_inductor_compile_options(inference=True) # Keep AOTInductor aligned with the eval compile path. ``triton.max_tiles=1`` # keeps data-dependent edge axes on Triton's x grid, whose bound is large # enough for production-scale neighbor lists. + log.info( + "Compiling the AOTInductor package for %s (the slowest freeze stage; " + "typically several minutes)...", + target_device, + ) with inductor_config.patch({**compile_options, "triton.max_tiles": 1}): aoti_compile_and_package(exported, package_path=out_path_str) # Second artifact: the LAMMPS multi-rank with-comm graph. It threads the # eight border_op communication tensors so cross-rank ghost features are - # exchanged between interaction blocks. Excluded for spin (nlist lower - # interface) and bridging models (Source Freeze Propagation is not - # rank-decomposable); those fall back to single-rank inference. - with_comm = (not is_spin) and model.supports_edge_parallel() + # exchanged between interaction blocks. Gated on the edge_vec lower contract + # (energy and native spin), so virtual spin (nlist interface) is excluded; + # bridging models report supports_edge_parallel()=False (Source Freeze + # Propagation is not rank-decomposable). Both fall back to single-rank. + with_comm = ( + model.export_lower_input_kind() == "edge_vec" and model.supports_edge_parallel() + ) with_comm_bytes: bytes | None = None if with_comm: + log.info( + "Compiling the parallel with-comm artifact (second AOTInductor " + "compilation)..." + ) with_comm_bytes = _export_with_comm_artifact( model, target_device=target_device, diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 2128328b16..dffdbee5e1 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -284,7 +284,7 @@ def _setup_nlist_backend(self, nlist_backend: str) -> None: if callable(self_built) and self_built(): # The model builds its own neighbor list and runs the native path; # an external strategy would bypass it, so always use native. - log.info( + log.debug( "Ignoring nlist_backend=%r: %s uses its own built-in neighbor list.", nlist_backend, type(inner).__name__, diff --git a/deepmd/pt/loss/ener_spin.py b/deepmd/pt/loss/ener_spin.py index 07e3403fdb..8645a1a3b0 100644 --- a/deepmd/pt/loss/ener_spin.py +++ b/deepmd/pt/loss/ener_spin.py @@ -23,6 +23,75 @@ ) +def _masked_force_mag_tensors( + label: dict[str, torch.Tensor], + model_pred: dict[str, torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Collect magnetic-force labels and predictions on spin-active atoms. + + Parameters + ---------- + label : dict[str, torch.Tensor] + Batch labels containing ``force_mag``. + model_pred : dict[str, torch.Tensor] + Model outputs containing ``force_mag`` and ``mask_mag``. + + Returns + ------- + label_fm : torch.Tensor + Reference magnetic forces with shape ``(n_mag, 3)``. + pred_fm : torch.Tensor + Predicted magnetic forces with shape ``(n_mag, 3)``. + mag_counts : torch.Tensor + Number of spin-active atoms in each frame, with shape ``(nframes,)``. + """ + atomic_mask = model_pred["mask_mag"].expand(-1, -1, 3) + label_fm = label["force_mag"][atomic_mask].reshape(-1, 3) + pred_fm = model_pred["force_mag"][atomic_mask].reshape(-1, 3) + mag_counts = model_pred["mask_mag"].sum(dim=(1, 2)).to(torch.int64) + return label_fm, pred_fm, mag_counts + + +def _mean_within_segments( + values: torch.Tensor, + segment_lengths: torch.Tensor, +) -> torch.Tensor: + """Reduce ``values`` to a per-segment mean over contiguous frame blocks. + + Parameters + ---------- + values : torch.Tensor + Values laid out in frame order, with shape ``(n_values,)``. + segment_lengths : torch.Tensor + Length of each segment, with shape ``(n_segments,)``. + + Returns + ------- + torch.Tensor + Per-segment means with shape ``(n_segments,)``. Empty segments return + zero so frame-weighted reductions remain finite. + """ + nsegments = segment_lengths.shape[0] + if values.numel() == 0: + return torch.zeros( + nsegments, + dtype=values.dtype, + device=values.device, + ) + + segment_ids = torch.repeat_interleave( + torch.arange(nsegments, device=values.device, dtype=torch.long), + segment_lengths, + ) + totals = torch.zeros(nsegments, dtype=values.dtype, device=values.device) + totals.scatter_add_(0, segment_ids, values) + + means = torch.zeros(nsegments, dtype=values.dtype, device=values.device) + nonempty = segment_lengths > 0 + means[nonempty] = totals[nonempty] / segment_lengths[nonempty].to(values.dtype) + return means + + class EnergySpinLoss(TaskLoss): def __init__( self, @@ -254,14 +323,9 @@ def forward( if self.has_fm and "force_mag" in model_pred and "force_mag" in label: find_force_m = label.get("find_force_mag", 0.0) pref_fm = pref_fm * find_force_m - nframes = model_pred["force_mag"].shape[0] - atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3]) - label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3) - model_pred_force_mag = model_pred["force_mag"][atomic_mask].view( - nframes, -1, 3 - ) + label_fm, pred_fm, mag_counts = _masked_force_mag_tensors(label, model_pred) if self.loss_func == "mse": - diff_fm = label_force_mag - model_pred_force_mag + diff_fm = label_fm - pred_fm l2_force_mag_loss = torch.mean(torch.square(diff_fm)) if not self.inference: more_loss["l2_force_m_loss"] = self.display_if_exist( @@ -284,13 +348,11 @@ def forward( mae_fm.detach(), find_force_m ) elif self.loss_func == "mae": - l1_force_mag_loss = F.l1_loss( - label_force_mag, model_pred_force_mag, reduction="none" - ) + per_atom_l1 = F.l1_loss(label_fm, pred_fm, reduction="none").sum(-1) more_loss["mae_fm"] = self.display_if_exist( - l1_force_mag_loss.mean().detach(), find_force_m + per_atom_l1.mean().detach(), find_force_m ) - l1_force_mag_loss = l1_force_mag_loss.sum(-1).mean(-1).sum() + l1_force_mag_loss = _mean_within_segments(per_atom_l1, mag_counts).sum() loss += (pref_fm * torch.nan_to_num(l1_force_mag_loss)).to( GLOBAL_PT_FLOAT_PRECISION ) diff --git a/deepmd/pt/model/atomic_model/sezm_atomic_model.py b/deepmd/pt/model/atomic_model/sezm_atomic_model.py index 087d1e63f3..201f1a7715 100644 --- a/deepmd/pt/model/atomic_model/sezm_atomic_model.py +++ b/deepmd/pt/model/atomic_model/sezm_atomic_model.py @@ -733,9 +733,8 @@ def _build_dens_fitting_kwargs(self) -> dict[str, Any]: """Reconstruct SeZM `dens`-head kwargs from energy head and descriptor.""" descriptor = self.descriptor kwargs = self._build_ener_fitting_kwargs() - node_l_schedule = getattr(descriptor, "node_l_schedule", descriptor.l_schedule) - kwargs["condition_lmax"] = int(node_l_schedule[0]) - kwargs["latent_lmax"] = int(node_l_schedule[-1]) + kwargs["condition_lmax"] = int(descriptor.node_init_lmax) + kwargs["latent_lmax"] = int(descriptor.node_readout_lmax) kwargs["channels"] = int(descriptor.channels) return kwargs diff --git a/deepmd/pt/model/descriptor/sezm.py b/deepmd/pt/model/descriptor/sezm.py index df4ade51bc..5ec9a1da19 100644 --- a/deepmd/pt/model/descriptor/sezm.py +++ b/deepmd/pt/model/descriptor/sezm.py @@ -51,6 +51,9 @@ from deepmd.dpmodel.utils.seed import ( child_seed, ) +from deepmd.kernels.utils import ( + use_amp_infer, +) from deepmd.pt.utils import ( env, ) @@ -87,6 +90,7 @@ ScalarRMSNorm, SeZMInteractionBlock, SeZMTypeEmbedding, + SpinEmbedding, WignerDCalculator, build_edge_cache, build_edge_cache_from_edges, @@ -212,7 +216,12 @@ class DescrptSeZM(BaseDescriptor, nn.Module): The node degree of block `i` is `l_schedule[i] + extra_node_l`, while SO(2) message passing still uses `l_schedule[i]`. n_blocks - Number of blocks (only used when `l_schedule` is None). + Number of blocks (only used when `l_schedule` is None). ``0`` disables + the interaction blocks and builds the zero-block descriptor: type + embedding, optional env FiLM and geometric initial embedding, then the + final SO(3) read-out. The backbone degree is taken from `lmax` + (plus `extra_node_l`). Geometry then enters only through the GIE, which + is active when `use_env_seed=True` and `lmax + extra_node_l > 0`. so2_norm If True, apply intermediate ReducedEquivariantRMSNorm between SO(2) mixing layers. When False (default), no normalization is applied between layers. @@ -352,8 +361,15 @@ class DescrptSeZM(BaseDescriptor, nn.Module): interaction block, driven by the SO(3) Wigner-D grid, so ``l>0`` geometry is folded into ``l=0`` before the scalar is extracted. The value selects the quadratic grid product (``"glu"``) or the polynomial point-wise grid - MLP (``"mlp"``). The Wigner-D frame order follows ``kmax``. The residual - stays on the ``l=0`` channel. + MLP (``"mlp"``). The Wigner-D frame order follows ``kmax``. + readout_layers + Number of stacked equivariant residual read-out FFNs (default ``1``). + Every layer is an ``x + FFN(x)`` residual block sharing the read-out + degree; intermediate layers keep the full SO(3) tensor so high-degree + geometry is folded into ``l=0`` repeatedly, and only the final layer + slices the ``l=0`` channel from its residual sum. With ``so3_readout`` of + ``"none"`` the stack is a degree-0 scalar residual MLP on the ``l=0`` + slice. lebedev_quadrature Either one boolean applied to both S2 branches, or two booleans ``[so2_enabled, ffn_enabled]`` aligned with ``s2_activation``. If @@ -370,7 +386,8 @@ class DescrptSeZM(BaseDescriptor, nn.Module): FFN always keeps this user-provided value. use_amp If True, use automatic mixed precision (AMP) with bfloat16 on CUDA - during training. This can improve speed and reduce memory usage. + during training. In eval/inference, AMP is opt-in through + ``DP_AMP_INFER``. This can improve speed and reduce memory usage. Enabling this option is recommended on GPUs with native bfloat16 support. Disable it on GPUs without native bfloat16 support to avoid runtime errors or additional conversion overhead. @@ -463,6 +480,7 @@ def __init__( message_node_s2: bool = False, message_node_so3: bool = False, so3_readout: str = "none", + readout_layers: int = 1, lebedev_quadrature: bool | list[bool] | None = True, activation_function: str = "silu", glu_activation: bool = True, @@ -477,6 +495,7 @@ def __init__( inner_clamp_r_outer: float | None = None, add_chg_spin_ebd: bool = False, default_chg_spin: list[float] | None = None, + use_spin: list[bool] | None = None, **kwargs: Any, ) -> None: super().__init__() @@ -553,6 +572,9 @@ def __init__( self.so3_readout = str(so3_readout).lower() if self.so3_readout not in {"none", "glu", "mlp"}: raise ValueError("`so3_readout` must be one of 'none', 'glu', or 'mlp'") + self.readout_layers = int(readout_layers) + if self.readout_layers < 1: + raise ValueError("`readout_layers` must be >= 1") if lebedev_quadrature is None: lebedev_quadrature = [True, True] elif isinstance(lebedev_quadrature, bool): @@ -593,7 +615,8 @@ def __init__( self.compute_dtype = get_promoted_dtype(self.dtype) self.mlp_bias = bool(mlp_bias) self.layer_scale = bool(layer_scale) - self.use_amp = bool(use_amp) # and self.training + self.use_amp = bool(use_amp) + self.use_amp_infer = use_amp_infer() self.trainable = bool(trainable) self.seed = seed self.random_gamma = bool(random_gamma) @@ -606,6 +629,12 @@ def __init__( None if default_chg_spin is None else [float(x) for x in default_chg_spin] ) + # === Native per-atom spin embedding === + # The spin vector enters the descriptor as an l=0 magnitude scalar plus + # an l=1 direction feature (see ``SpinEmbedding``). Providing per-type + # ``use_spin`` flags enables the native spin embedding. + self.use_spin = None if use_spin is None else [bool(x) for x in use_spin] + # === Zone bridging: InnerClamp + Source Freeze Propagation Gate === # Both the geometry clamp (``InnerClamp``) and the message-passing # switch (``BridgingSwitch``) are activated together on the same @@ -656,6 +685,7 @@ def __init__( seed_full_attn = child_seed(self.seed, 5) seed_block_attn = child_seed(self.seed, 6) seed_charge_spin = child_seed(self.seed, 7) + seed_spin_embedding = child_seed(self.seed, 8) # === L/M schedules === self._init_lm_schedules(lmax, n_blocks, l_schedule, mmax, m_schedule) @@ -664,7 +694,6 @@ def __init__( raise ValueError("`kmax` must be non-negative") if self.kmax > self.lmax: raise ValueError("`kmax` must be <= `lmax`") - self.ebed_dims = [get_so3_dim_of_lmax(l) for l in self.l_schedule] self._init_node_l_schedules(extra_node_l) self.rad_sizes_per_block = [l + 1 for l in self.l_schedule] @@ -772,6 +801,30 @@ def __init__( else: self.charge_spin_embedding = None + if self.use_spin is not None: + if self.node_init_lmax < 1: + raise ValueError( + "`use_spin` requires a node degree >= 1 " + "(lmax + extra_node_l) to host the l=1 spin feature." + ) + self.spin_embedding: SpinEmbedding | None = SpinEmbedding( + ntypes=self.ntypes, + channels=self.channels, + use_spin=self.use_spin, + activation_function=self.activation_function, + dtype=self.compute_dtype, # force fp32+ + seed=seed_spin_embedding, + trainable=self.trainable, + ) + # Packed rows hosting the l=1 spin coefficients (m = -1, 0, +1). + self.register_buffer( + "_spin_l1_rows", + torch.arange(1, 4, dtype=torch.long, device=self.device), + persistent=False, + ) + else: + self.spin_embedding = None + # === Env FiLM embedding (optional) === if self.use_env_seed: self.env_seed_embedding: EnvironmentInitialEmbedding | None = ( @@ -786,6 +839,7 @@ def __init__( mlp_bias=self.mlp_bias, activation_function=self.activation_function, eps=self.eps, + use_spin=self.use_spin, dtype=self.compute_dtype, # force fp32+ trainable=self.trainable, seed=seed_env_seed, @@ -845,7 +899,7 @@ def __init__( # GIE and truncated for each SO2Conv block. # radial_mlp specifies hidden layer sizes; input/output layers are prepended/appended. # Use fp32+ precision (same as RBF output) for numerical stability. - radial_out_dim = (self.node_l_schedule[0] + 1) * self.channels + radial_out_dim = (self.node_init_lmax + 1) * self.channels radial_mlp_layers = [self.n_radial, *self.radial_mlp, radial_out_dim] self.radial_embedding = RadialMLP( radial_mlp_layers, @@ -870,22 +924,22 @@ def __init__( ] self._need_full_wigner = not all(block_edge_cartesian) self.wigner_calc = WignerDCalculator( - lmax=self.l_schedule[0], + lmax=self.mp_init_lmax, eps=self.eps, dtype=self.compute_dtype, # force fp32+ ) - self.use_gie = self.use_env_seed and self.node_l_schedule[0] > 0 + self.use_gie = self.use_env_seed and self.node_init_lmax > 0 if self.use_gie: self.gie = GeometricInitialEmbedding( - lmax=self.node_l_schedule[0], + lmax=self.node_init_lmax, channels=self.channels, dtype=self.compute_dtype, # force fp32+ ) if self.extra_node_l > 0: self.gie_zonal_wigner_calc: WignerDCalculator | None = ( WignerDCalculator( - lmax=self.node_l_schedule[0], + lmax=self.node_init_lmax, eps=self.eps, dtype=self.compute_dtype, ) @@ -987,28 +1041,33 @@ def __init__( seed=child_seed(seed_block_attn, 2000), ) - # === Final FFN for l=0 output mixing === - # ``so3_readout="none"`` runs a degree-0 scalar FFN on the l=0 slice. - # ``"glu"``/``"mlp"`` run a full FFN at the last block's node degree whose - # SO(3) Wigner-D grid folds l>0 geometry into l=0; the value selects the - # quadratic grid product or the point-wise grid MLP. - readout_lmax = self.node_l_schedule[-1] - self.output_ffn = EquivariantFFN( - lmax=0 if self.so3_readout == "none" else readout_lmax, - channels=self.channels, - hidden_channels=self.out_ffn_neurons, - kmax=min(self.kmax, readout_lmax), - grid_mlp=self.so3_readout == "mlp", - grid_branch=0, - dtype=self.compute_dtype, - s2_activation=False, - ffn_so3_grid=self.so3_readout != "none", - activation_function=self.out_activation_function, - glu_activation=self.out_glu_activation, - mlp_bias=self.mlp_bias, - trainable=self.trainable, - seed=seed_out, + # === Final FFN stack for l=0 output mixing === + # ``readout_layers`` residual blocks run in sequence (see + # ``_apply_readout``): ``readout_pre_layers`` keep the full SO(3) tensor + # and only the final ``output_ffn`` slices l=0. The final layer keeps the + # ``output_ffn`` name and ``seed_out`` so a single-layer read-out matches + # the single-module checkpoint layout. + readout_lmax = self.node_readout_lmax + readout_ffn_kwargs = { + "lmax": 0 if self.so3_readout == "none" else readout_lmax, + "channels": self.channels, + "hidden_channels": self.out_ffn_neurons, + "kmax": min(self.kmax, readout_lmax), + "grid_mlp": self.so3_readout == "mlp", + "grid_branch": 0, + "dtype": self.compute_dtype, + "s2_activation": False, + "ffn_so3_grid": self.so3_readout != "none", + "activation_function": self.out_activation_function, + "glu_activation": self.out_glu_activation, + "mlp_bias": self.mlp_bias, + "trainable": self.trainable, + } + self.readout_pre_layers = nn.ModuleList( + EquivariantFFN(**readout_ffn_kwargs, seed=child_seed(seed_out, layer_index)) + for layer_index in range(self.readout_layers - 1) ) + self.output_ffn = EquivariantFFN(**readout_ffn_kwargs, seed=seed_out) for p in self.parameters(): p.requires_grad = self.trainable @@ -1046,6 +1105,7 @@ def forward( fparam: torch.Tensor | None = None, force_embedding: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, + spin: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor, @@ -1083,7 +1143,7 @@ def forward( force_embedding Optional precomputed equivariant force embedding with shape ``(nf * nloc, D, 1, channels)``, where - ``D = (node_l_schedule[0] + 1) ** 2``. This tensor is added to the + ``D = (node_init_lmax + 1) ** 2``. This tensor is added to the initial SO(3) backbone state before the interaction blocks. charge_spin Frame-level charge and spin conditions with shape (nf, 2). @@ -1124,6 +1184,7 @@ def forward( edge_mask=edge_mask, force_embedding=force_embedding, charge_spin=charge_spin, + spin=spin, ) return ( descriptor, @@ -1166,6 +1227,14 @@ def forward( nloc=nloc, ) + # Native spin: condition the l=0 type features on the spin magnitude + # and hold the l=1 direction coefficients for the backbone seed. + spin_vec = None + if self.spin_embedding is not None and spin is not None: + type_ebed, spin_vec = self._apply_spin_embedding( + type_ebed, spin, atype_loc.reshape(-1), n_nodes=n_nodes + ) + # === Step 4. Build edge cache once (geometry + RBF + Wigner-D) === # Zone bridging (InnerClamp + SFPG + ZBL) is not routed through the # standard DeePMD path: bridging only makes physical sense when @@ -1193,21 +1262,21 @@ def forward( build_wigner=self._need_full_wigner, ) - ebed_dim_0 = self.node_ebed_dims[0] # (node_lmax+1)^2 + ebed_dim_0 = self.node_init_dim # (node_init_lmax+1)^2 x0 = type_ebed # (N, C) x0_out = x0 # (N, C) # === Step 5. Compute radial features once (fp32+) === - # Shape: (E, (node_lmax+1)*C) -> (E, node_lmax+1, C) + # Shape: (E, (node_init_lmax+1)*C) -> (E, node_init_lmax+1, C) radial_feat = None with nvtx_range("radial_embedding"): if edge_cache.src.numel() > 0: radial_feat = rearrange( self.radial_embedding(edge_cache.edge_rbf), "E (L C) -> E L C", - L=self.node_l_schedule[0] + 1, + L=self.node_init_lmax + 1, C=self.channels, - ) # (E, lmax+1, C) + ) # (E, node_init_lmax+1, C) if self.version >= 1.1: radial_feat = radial_feat * edge_cache.edge_env.reshape(-1, 1, 1) @@ -1215,10 +1284,16 @@ def forward( with nvtx_range("env_film"): if self.use_env_seed and edge_cache.src.numel() > 0: atype_flat = atype_loc.reshape(-1) # (N,) + spin_flat = ( + spin.reshape(n_nodes, 3) + if (self.spin_embedding is not None and spin is not None) + else None + ) film = self.env_seed_embedding( edge_cache=edge_cache, atype_flat=atype_flat, n_nodes=n_nodes, + spin=spin_flat, ) # (N, 2*C) scale_logits = film[:, : self.channels] # (N, C) shift_logits = film[:, self.channels :] # (N, C) @@ -1234,19 +1309,35 @@ def forward( x = type_ebed.new_zeros(n_nodes, ebed_dim_0, 1, self.channels) # (N, D, 1, C) x[:, 0, 0, :] = x0_out - # === Step 8. Geometric Initial Embedding (fp32+) === + # === Step 8. Geometric Initial Embedding (+ neighbor spin l=1) === with nvtx_range("gie"): if self.use_gie and radial_feat is not None: # GIE only needs l>=1, slice radial_feat[:, 1:, :] zonal_coupling = self._build_gie_zonal_coupling(edge_cache) + spin_l1_message = ( + self.spin_embedding.edge_l1( + spin.reshape(n_nodes, 3), + atype_loc.reshape(-1), + edge_cache, + ) + if (self.spin_embedding is not None and spin is not None) + else None + ) x = x + self.gie( n_nodes=n_nodes, edge_cache=edge_cache, radial_feat=radial_feat[:, 1:, :], zonal_coupling=zonal_coupling, + spin_l1_message=spin_l1_message, ).unsqueeze(2) - # === Step 9. Fuse edge type features into radial features (fp32+) === + # === Step 9. Add the on-site native spin l=1 to the backbone === + # The neighbor-spin l=1 is aggregated inside GIE (degree-normalized like + # the geometry); the atom's own spin direction is added here, un-normalized. + if spin_vec is not None: + x = x.index_add(1, self._spin_l1_rows, spin_vec.unsqueeze(2)) + + # === Step 10. Fuse edge type features into radial features (fp32+) === with nvtx_range("radial_fuse"): if radial_feat is not None: radial_feat = radial_feat + rearrange( @@ -1259,36 +1350,24 @@ def forward( else: rad_feat_per_block = [] - # === Step 10. Convert to self.dtype and run blocks === + # === Step 11. Convert to self.dtype and run blocks === + # The block stage is skipped entirely when there are no interaction + # blocks (zero-block descriptor) or no valid edges, sparing the working + # edge-cache dtype cast that only the blocks consume. with nvtx_range("blocks"): x = x.to(dtype=self.dtype) # (N, D, 1, C) if force_embedding is not None: x = x + force_embedding.to(dtype=self.dtype) - if edge_cache.src.numel() > 0: + if self.blocks and edge_cache.src.numel() > 0: edge_cache = edge_cache_to_dtype(edge_cache, self.dtype) with self._compute_mode_ctx(extended_coord.device): x = self._forward_blocks(x, edge_cache, rad_feat_per_block) - # === Step 11. Final l=0 output mixing === - # ``none`` feeds the l=0 slice only; ``glu``/``mlp`` feed the full - # (N, D, 1, C) node tensor so the SO(3) grid folds l>0 into l=0. The - # residual is added on the full coefficient tensor before extracting - # l=0: slicing the summed tensor rather than the FFN output keeps the - # saved degree-axis stride static under torch.compile dynamic shapes. + # === Step 12. Final l=0 output mixing === with nvtx_range("output_ffn"): - ffn_in = ( - x[:, 0:1, :, :] - .reshape(n_nodes, 1, 1, self.channels) - .to(dtype=self.compute_dtype) - if self.so3_readout == "none" - # truncate to the final node degree: the empty-edge path - # skips the blocks, leaving x at node_ebed_dims[0]; output_ffn - # is built for node_ebed_dims[-1]. No-op when blocks ran. - else x[:, : self.node_ebed_dims[-1], :, :].to(dtype=self.compute_dtype) - ) - x_scalar = (ffn_in + self.output_ffn(ffn_in))[:, 0:1, :, :] + x_scalar = self._apply_readout(x, n_nodes) - # === Step 12. Reshape to (nf, nloc, channels) and return === + # === Step 13. Reshape to (nf, nloc, channels) and return === descriptor = rearrange( x_scalar, "(nf nloc) 1 1 C -> nf nloc C", nf=nf, nloc=nloc ) # (nf, nloc, C) @@ -1310,6 +1389,7 @@ def forward_with_edges( edge_mask: torch.Tensor, force_embedding: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, + spin: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, nloc: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -1343,7 +1423,7 @@ def forward_with_edges( force_embedding Optional precomputed equivariant force embedding with shape ``(nf * nloc, D, 1, channels)``, where - ``D = (node_l_schedule[0] + 1) ** 2``. This tensor is added to the + ``D = (node_init_lmax + 1) ** 2``. This tensor is added to the initial SO(3) backbone state before the interaction blocks. charge_spin Frame-level charge and spin conditions with shape (nf, 2). @@ -1399,6 +1479,14 @@ def forward_with_edges( ) n_nodes = type_ebed.shape[0] + # Native spin: condition the l=0 type features on the spin magnitude + # and hold the l=1 direction coefficients for the backbone seed. + spin_vec = None + if self.spin_embedding is not None and spin is not None: + type_ebed, spin_vec = self._apply_spin_embedding( + type_ebed, spin, atype_flat, n_nodes=n_nodes + ) + # === Step 3. Build edge cache once (sparse edges) === with nvtx_range("build_edge_cache"): edge_cache = build_edge_cache_from_edges( @@ -1425,7 +1513,7 @@ def forward_with_edges( build_wigner=self._need_full_wigner, ) - ebed_dim_0 = self.node_ebed_dims[0] # (node_lmax+1)^2 + ebed_dim_0 = self.node_init_dim # (node_init_lmax+1)^2 x0 = type_ebed # (N, C) x0_out = x0 # (N, C) @@ -1434,19 +1522,25 @@ def forward_with_edges( radial_feat_flat = self.radial_embedding(edge_cache.edge_rbf) radial_feat = radial_feat_flat.reshape( radial_feat_flat.shape[0], - self.node_l_schedule[0] + 1, + self.node_init_lmax + 1, self.channels, - ) # (E, lmax+1, C) + ) # (E, node_init_lmax+1, C) if self.version >= 1.1: radial_feat = radial_feat * edge_cache.edge_env.reshape(-1, 1, 1) # === Step 5. Env FiLM conditioning (optional, fp32+) === with nvtx_range("env_film"): if self.use_env_seed: + spin_flat = ( + spin.reshape(n_nodes, 3) + if (self.spin_embedding is not None and spin is not None) + else None + ) film = self.env_seed_embedding( edge_cache=edge_cache, atype_flat=atype_flat, n_nodes=n_nodes, + spin=spin_flat, ) # (N, 2*C) scale_logits = film[:, : self.channels] # (N, C) shift_logits = film[:, self.channels :] # (N, C) @@ -1462,18 +1556,32 @@ def forward_with_edges( x = type_ebed.new_zeros(n_nodes, ebed_dim_0, 1, self.channels) # (N, D, 1, C) x[:, 0, 0, :] = x0_out - # === Step 7. Geometric Initial Embedding (fp32+) === + # === Step 7. Geometric Initial Embedding (+ neighbor spin l=1) === with nvtx_range("gie"): if self.use_gie: zonal_coupling = self._build_gie_zonal_coupling(edge_cache) + spin_l1_message = ( + self.spin_embedding.edge_l1( + spin.reshape(n_nodes, 3), atype_flat, edge_cache + ) + if (self.spin_embedding is not None and spin is not None) + else None + ) x = x + self.gie( n_nodes=n_nodes, edge_cache=edge_cache, radial_feat=radial_feat[:, 1:, :], zonal_coupling=zonal_coupling, + spin_l1_message=spin_l1_message, ).unsqueeze(2) - # === Step 8. Fuse edge type features into radial features (fp32+) === + # === Step 8. Add the on-site native spin l=1 to the backbone === + # The neighbor-spin l=1 is aggregated inside GIE; the + # atom's own spin direction is added here, un-normalized. + if spin_vec is not None: + x = x.index_add(1, self._spin_l1_rows, spin_vec.unsqueeze(2)) + + # === Step 9. Fuse edge type features into radial features (fp32+) === with nvtx_range("radial_fuse"): radial_feat = radial_feat.to(dtype=self.dtype) radial_feat = radial_feat + rearrange( @@ -1483,18 +1591,21 @@ def forward_with_edges( radial_feat[:, :rad_len, :] for rad_len in self.rad_sizes_per_block ] - # === Step 9. Convert to self.dtype and run blocks === + # === Step 10. Convert to self.dtype and run blocks === + # The block stage is skipped entirely for the zero-block descriptor, + # sparing the working edge-cache dtype cast that only the blocks consume. with nvtx_range("blocks"): x = x.to(dtype=self.dtype) # (N, D, 1, C) if force_embedding is not None: x = x + force_embedding.to(dtype=self.dtype) - edge_cache = edge_cache_to_dtype(edge_cache, self.dtype) - with self._compute_mode_ctx(extended_coord.device): - x = self._forward_blocks( - x, edge_cache, rad_feat_per_block, comm_dict=comm_dict - ) + if self.blocks: + edge_cache = edge_cache_to_dtype(edge_cache, self.dtype) + with self._compute_mode_ctx(extended_coord.device): + x = self._forward_blocks( + x, edge_cache, rad_feat_per_block, comm_dict=comm_dict + ) - # === Step 10. Keep the owned-atom rows for the read-out === + # === Step 11. Keep the owned-atom rows for the read-out === # ``n_out_nodes`` is the owned-node count in the flattened layout # (``nf * nloc``). Single-domain: ``out_nloc == n_per_frame``, so this # equals the whole node set and the slice is a no-op. Parallel @@ -1503,26 +1614,11 @@ def forward_with_edges( n_out_nodes = nf * out_nloc x = x[:n_out_nodes] - # === Step 11. Final l=0 output mixing === - # ``none`` feeds the l=0 slice only; ``glu``/``mlp`` feed the full - # (N, D, 1, C) node tensor so the SO(3) grid folds l>0 into l=0. The - # residual is added on the full coefficient tensor before extracting - # l=0: slicing the summed tensor rather than the FFN output keeps the - # saved degree-axis stride static under torch.compile dynamic shapes. + # === Step 12. Final l=0 output mixing === with nvtx_range("output_ffn"): - ffn_in = ( - x[:, 0:1, :, :] - .reshape(n_out_nodes, 1, 1, self.channels) - .to(dtype=self.compute_dtype) - if self.so3_readout == "none" - # truncate to the final node degree: the empty-edge path - # skips the blocks, leaving x at node_ebed_dims[0]; output_ffn - # is built for node_ebed_dims[-1]. No-op when blocks ran. - else x[:, : self.node_ebed_dims[-1], :, :].to(dtype=self.compute_dtype) - ) - x_scalar = (ffn_in + self.output_ffn(ffn_in))[:, 0:1, :, :] + x_scalar = self._apply_readout(x, n_out_nodes) - # === Step 12. Reshape to (nf, nloc, channels) and return === + # === Step 13. Reshape to (nf, nloc, channels) and return === descriptor = x_scalar.reshape(nf, out_nloc, self.channels) # (nf, nloc, C) return descriptor.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), x.contiguous() @@ -1601,7 +1697,7 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: x = block_output # === Step 3. Final aggregation over all completed unit representations === - final_dim = self.node_ebed_dims[-1] + final_dim = self.node_readout_dim final_sources = [source[:, :final_dim, :, :] for source in unit_history] x = self.final_full_attn_res( sources=final_sources, @@ -1633,7 +1729,7 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: x = block_output # === Step 3. Final aggregation over all completed block summaries === - final_dim = self.node_ebed_dims[-1] + final_dim = self.node_readout_dim final_sources = [source[:, :final_dim, :, :] for source in block_history] x = self.final_block_attn_res( sources=final_sources, @@ -1642,6 +1738,46 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: ).to(dtype=self.dtype) return x + def _apply_readout(self, x: torch.Tensor, n_rows: int) -> torch.Tensor: + """Fold the node tensor into the scalar (``l=0``) descriptor. + + Runs the ``readout_layers`` stack of equivariant residual read-out FFNs. + ``so3_readout="none"`` feeds only the ``l=0`` slice; ``"glu"``/``"mlp"`` + feed the full ``(N, D, 1, C)`` node tensor so the SO(3) grid folds + ``l>0`` geometry into ``l=0``. Each layer is an ``x + FFN(x)`` residual: + the ``readout_pre_layers`` keep the full tensor so the geometry keeps + folding, while the final ``output_ffn`` slices the ``l=0`` channel from + its residual sum. Slicing the summed tensor rather than the FFN output + keeps the saved degree-axis stride static under ``torch.compile`` dynamic + shapes. + + Parameters + ---------- + x + Node features with shape ``(n_rows, D, 1, channels)``. With the + blocks skipped (zero-block or empty-edge path) ``D`` is the initial + degree; otherwise the pyramid has shrunk it, so the read-out slice to + ``node_readout_dim`` is a no-op there. + n_rows + Number of node rows fed to the read-out. + + Returns + ------- + torch.Tensor + Scalar descriptor with shape ``(n_rows, 1, 1, channels)``. + """ + if self.so3_readout == "none": + x_ro = ( + x[:, 0:1, :, :] + .reshape(n_rows, 1, 1, self.channels) + .to(dtype=self.compute_dtype) + ) + else: + x_ro = x[:, : self.node_readout_dim, :, :].to(dtype=self.compute_dtype) + for layer in self.readout_pre_layers: + x_ro = x_ro + layer(x_ro) + return (x_ro + self.output_ffn(x_ro))[:, 0:1, :, :] + def _edge_quaternion(self, edge_cache: EdgeFeatureCache) -> torch.Tensor: """ Return the cached global->local edge quaternion, rebuilding if absent. @@ -1688,7 +1824,7 @@ def _build_gie_zonal_coupling( return calc.forward_zonal(self._edge_quaternion(edge_cache), lmin=1) if self.gie_zonal_wigner_calc is None: return None - mp_row_count = self.ebed_dims[0] - 1 + mp_row_count = self.mp_init_dim - 1 mp_row_index = self.gie.non_scalar_row_index[:mp_row_count] mp_m0_col_index = self.gie.zonal_m0_col_index_for_row[:mp_row_count] mp_coupling = edge_cache.Dt_full[ @@ -1733,6 +1869,44 @@ def _apply_charge_spin_embedding( condition = condition[:, None, :].expand(nf, nloc, self.channels) return type_ebed + condition.reshape_as(type_ebed) + def _apply_spin_embedding( + self, + type_ebed: torch.Tensor, + spin: torch.Tensor, + atype_flat: torch.Tensor, + *, + n_nodes: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Inject the per-atom spin embedding into the node features. + + The l=0 magnitude scalar is added to the flattened type embedding so it + propagates into the scalar backbone, the per-edge type features, and + every block's radial features (exactly like the type embedding). The l=1 + direction coefficients are returned for the caller to add to the + equivariant backbone after the geometric initial embedding. + + Parameters + ---------- + type_ebed + Flattened type embedding with shape (N, channels). + spin + Per-atom spin vectors with shape (nf, nloc, 3) or (N, 3). + atype_flat + Flattened local atom types with shape (N,). + n_nodes + Number of local nodes ``N = nf * nloc``. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + The l=0-conditioned type embedding with shape (N, channels) and the + packed l=1 direction coefficients with shape (N, 3, channels). + """ + scalar, vector = self.spin_embedding(spin.reshape(n_nodes, 3), atype_flat) + type_ebed = type_ebed + scalar.to(dtype=type_ebed.dtype) + return type_ebed, vector + def _edge_type_keep_mask( self, atype_flat: torch.Tensor, @@ -1819,14 +1993,19 @@ def _init_lm_schedules( mmax: int | None, m_schedule: list[int] | None, ) -> None: - """Parse and validate L/M schedules, setting self.l_schedule/m_schedule/lmax/mmax.""" + """Parse and validate L/M schedules, setting self.l_schedule/m_schedule/lmax/mmax. + + An empty schedule (``n_blocks=0`` or ``l_schedule=[]``) is valid and + selects the zero-block descriptor: no interaction blocks are built, only + the initial SO(3) backbone (type embedding, optional env FiLM and GIE) + followed by the final read-out. The backbone degree then derives from + the configured ``lmax``/``mmax`` instead of the schedule endpoints. + """ # === L schedule === if l_schedule is None: self.l_schedule = [int(lmax)] * int(n_blocks) else: self.l_schedule = [int(x) for x in l_schedule] - if len(self.l_schedule) == 0: - raise ValueError("`l_schedule` must be non-empty") if any(x < 0 for x in self.l_schedule): raise ValueError("`l_schedule` entries must be non-negative") if any( @@ -1835,7 +2014,9 @@ def _init_lm_schedules( ): raise ValueError("`l_schedule` must be non-increasing (pyramid schedule)") - self.lmax = int(self.l_schedule[0]) + # The first entry sets the maximum degree; with zero blocks the backbone + # degree falls back to the configured ``lmax``. + self.lmax = int(self.l_schedule[0]) if self.l_schedule else int(lmax) self.n_blocks = len(self.l_schedule) # === M schedule === @@ -1849,8 +2030,6 @@ def _init_lm_schedules( self.m_schedule = [min(mmax_i, int(l)) for l in self.l_schedule] else: self.m_schedule = [int(x) for x in m_schedule] - if len(self.m_schedule) == 0: - raise ValueError("`m_schedule` must be non-empty") if len(self.m_schedule) != len(self.l_schedule): raise ValueError("`m_schedule` must have the same length as `l_schedule`") if any(x < 0 for x in self.m_schedule): @@ -1860,10 +2039,30 @@ def _init_lm_schedules( "`m_schedule` entries must satisfy `m_schedule[i] <= l_schedule[i]`" ) - self.mmax = int(self.m_schedule[0]) + self.mmax = ( + int(self.m_schedule[0]) + if self.m_schedule + else (int(mmax) if mmax is not None else int(self.lmax)) + ) def _init_node_l_schedules(self, extra_node_l: int) -> None: - """Parse node degree schedules derived from message-passing schedules.""" + """Parse node degree schedules and resolve the canonical backbone degrees. + + The descriptor references three backbone degrees that must stay valid + even with zero interaction blocks, so they are resolved here into + scalars rather than indexed off the (possibly empty) schedules: + + - ``mp_init_lmax`` : message-passing degree at initialization, driving + the Wigner-D calculator and the GIE message-passing coupling rows. + - ``node_init_lmax`` : node backbone degree at initialization, driving + the radial-embedding width, the initial state dimension, and GIE. + - ``node_readout_lmax`` : node backbone degree fed to the read-out FFN. + + With blocks these equal ``l_schedule[0]``, ``node_l_schedule[0]`` and + ``node_l_schedule[-1]``; with zero blocks all three collapse onto the + configured ``lmax`` (plus ``extra_node_l`` on the node side), so the + pyramid endpoints are never read from an empty list. + """ self.extra_node_l = int(extra_node_l) if self.extra_node_l < 0: raise ValueError("`extra_node_l` must be non-negative") @@ -1873,8 +2072,16 @@ def _init_node_l_schedules(self, extra_node_l: int) -> None: self.node_ebed_dims = [ get_so3_dim_of_lmax(l_value) for l_value in self.node_l_schedule ] - self.node_lmax = int(self.node_l_schedule[0]) - self.node_ebed_dim = int(self.node_ebed_dims[0]) + + # === Canonical backbone degrees (valid for any block count) === + self.mp_init_lmax = int(self.lmax) + self.node_init_lmax = int(self.lmax) + self.extra_node_l + self.node_readout_lmax = ( + int(self.node_l_schedule[-1]) if self.n_blocks > 0 else self.node_init_lmax + ) + self.mp_init_dim = get_so3_dim_of_lmax(self.mp_init_lmax) + self.node_init_dim = get_so3_dim_of_lmax(self.node_init_lmax) + self.node_readout_dim = get_so3_dim_of_lmax(self.node_readout_lmax) def _canonicalize_charge_spin( self, @@ -1964,21 +2171,28 @@ def _compute_mode_ctx(self, device: torch.device) -> Generator[None, None, None] Notes ----- - - When `use_amp=True` and the model is in training mode, enables - torch.autocast with bfloat16 on CUDA. This can improve speed and - reduce memory usage on GPUs with native bfloat16 support. + - When `use_amp=True`, enables torch.autocast with bfloat16 on CUDA + during training. Eval/inference enables the same autocast region only + when ``DP_AMP_INFER`` was truthy at construction time. + This can improve speed and reduce memory usage on GPUs with native + bfloat16 support. Disable AMP on GPUs without native bfloat16 support to avoid runtime errors or additional conversion overhead. - Only affects autocast-eligible operations. - - Does nothing during inference (`self.training=False`), on non-CUDA - devices, or when `use_amp=False`. + - Does nothing during inference (`self.training=False`) unless + ``DP_AMP_INFER`` is enabled, on non-CUDA devices, or when + `use_amp=False`. Yields ------ None Runs the wrapped region under the configured AMP setting. """ - if not self.use_amp or device.type != "cuda" or not self.training: + if ( + not self.use_amp + or device.type != "cuda" + or (not self.training and not self.use_amp_infer) + ): yield return @@ -2255,6 +2469,7 @@ def serialize(self) -> dict[str, Any]: "message_node_s2": self.message_node_s2, "message_node_so3": self.message_node_so3, "so3_readout": self.so3_readout, + "readout_layers": self.readout_layers, "lebedev_quadrature": self.lebedev_quadrature, "activation_function": self.activation_function, "glu_activation": self.glu_activation, @@ -2268,6 +2483,7 @@ def serialize(self) -> dict[str, Any]: "inner_clamp_r_outer": self.inner_clamp_r_outer, "add_chg_spin_ebd": self.add_chg_spin_ebd, "default_chg_spin": self.default_chg_spin, + "use_spin": self.use_spin, }, "@variables": {key: np_safe(value) for key, value in state.items()}, "env_mat": DPEnvMat(self.rcut, self.rcut, self.eps).serialize(), diff --git a/deepmd/pt/model/descriptor/sezm_nn/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/__init__.py index ddbf9c959b..aae88f5f86 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/__init__.py +++ b/deepmd/pt/model/descriptor/sezm_nn/__init__.py @@ -44,6 +44,7 @@ EnvironmentInitialEmbedding, GeometricInitialEmbedding, SeZMTypeEmbedding, + SpinEmbedding, ) from .ffn import ( EquivariantFFN, @@ -173,6 +174,7 @@ "SeZMDirectForceHead", "SeZMInteractionBlock", "SeZMTypeEmbedding", + "SpinEmbedding", "SwiGLU", "WignerDCalculator", "apply_lora_to_sezm", diff --git a/deepmd/pt/model/descriptor/sezm_nn/activation.py b/deepmd/pt/model/descriptor/sezm_nn/activation.py index d732680777..a19eb4c498 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/activation.py +++ b/deepmd/pt/model/descriptor/sezm_nn/activation.py @@ -87,7 +87,8 @@ class GatedActivation(nn.Module): Whether to use bias in the gate linear layer. layout Tensor layout convention. ``"nfdc"`` means input shape (N, F, D, C); - ``"ndfc"`` means input shape (N, D, F, C). + ``"ndfc"`` means input shape (N, D, F, C); ``"fndc"`` means input shape + (F, N, D, C), the focus-major layout used by the SO(2) mixing stack. trainable Whether parameters are trainable. seed @@ -123,8 +124,8 @@ def __init__( self.precision = RESERVED_PRECISION_DICT[dtype] self.mlp_bias = bool(mlp_bias) self.layout = str(layout).lower() - if self.layout not in {"nfdc", "ndfc"}: - raise ValueError("`layout` must be either 'nfdc' or 'ndfc'") + if self.layout not in {"nfdc", "ndfc", "fndc"}: + raise ValueError("`layout` must be one of 'nfdc', 'ndfc', or 'fndc'") self.scalar_act = ActivationFn(activation_function) @@ -169,7 +170,8 @@ def forward( ---------- x Value features. Shape is (N, F, D, C) when ``layout='nfdc'``, - or (N, D, F, C) when ``layout='ndfc'``. + (N, D, F, C) when ``layout='ndfc'``, or (F, N, D, C) when + ``layout='fndc'``. gate Optional gate features with the same layout as ``x``. When provided, enables GLU mode: @@ -182,6 +184,10 @@ def forward( torch.Tensor Gated features with the same layout as ``x``. """ + # ``ndfc`` carries the degree axis at position 1; ``nfdc`` and the + # focus-major ``fndc`` carry it at position 2. Every select/narrow/reshape + # below is expressed against this single degree axis, so the three layouts + # share one code path apart from the per-focus gate projection. degree_axis = 1 if self.layout == "ndfc" else 2 if gate is not None: @@ -200,9 +206,15 @@ def forward( return x0 input_dtype = gate_scalar_source.dtype - gating_scalars = torch.sigmoid( - self.gate_linear(gate_scalar_source.to(dtype=self.dtype)) - ).to(dtype=input_dtype) + gate_src = gate_scalar_source.to(dtype=self.dtype) + if self.layout == "fndc": + # The scalar source is focus-major (F, N, C). ``FocusLinear`` mixes + # channels with the focus stream on axis 1, so present it in the shared + # (N, F, C) convention and restore the focus-major orientation. + gate_logits = self.gate_linear(gate_src.transpose(0, 1)).transpose(0, 1) + else: + gate_logits = self.gate_linear(gate_src) + gating_scalars = torch.sigmoid(gate_logits).to(dtype=input_dtype) gating_scalars = gating_scalars.reshape( x.shape[0], gate_scalar_source.shape[1], self.lmax, self.channels ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py deleted file mode 100644 index 63f146cbd9..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -""" -CuTe-DSL accelerated SO(2) rotation operators for SeZM / DPA4. - -This package provides a self-contained, ``torch.compile``-friendly implementation -of the two fused gather + batched-GEMM operators used by the SeZM SO(2) edge -convolution: - -* ``rotate_to_local`` : ``out[e] = wigner[e][coeff_index] @ x[src[e]]`` -* ``rotate_back`` : ``out[e] = wigner[e][:, coeff_index] @ x_local[e]`` - -The kernels are written with the NVIDIA CuTe DSL (``cutlass.cute``) and fuse the -Wigner-row/column gather and the source-node gather directly into the matmul, so -the large ``D_to_m`` / ``x_src`` intermediates are never materialized. They are -exposed through the modern ``torch.library.custom_op`` API (functional, with -``register_fake`` + ``register_autograd``) so that they compose correctly with -``torch.compile`` and autograd. - -The top-level entry points are re-exported here for convenience. -""" - -from __future__ import ( - annotations, -) - -from .so2_rotation import ( - SEZM_CUTE_AVAILABLE, - rotate_back_cute, - rotate_to_local_cute, -) - -__all__ = [ - "SEZM_CUTE_AVAILABLE", - "rotate_back_cute", - "rotate_to_local_cute", -] diff --git a/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py b/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py deleted file mode 100644 index 9af65aaaac..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py +++ /dev/null @@ -1,918 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -# pyright: reportMissingImports=false -# ruff: noqa: ANN001 -""" -CuTe-DSL fused SO(2) rotation kernels for SeZM / DPA4. - -Status and benchmark conclusion -------------------------------- -This implementation is experimental and is **not** wired into the production -SO(2) convolution. The shipping accelerated rotation path uses the Triton -block-diagonal kernels in ``sezm_nn/triton/so2_rotation.py`` (enabled by -``DP_TRITON_INFER``); this module is retained for reference and further -experiments. - -In head-to-head benchmarks against the compiled dense ``bmm`` and the Triton -kernels, the CuTe path had the best peak memory (roughly 2-4x lower than the -compiled dense path, lower than Triton) and won the forward pass, but its -``rotate_back`` backward -- and the forward+backward at large ``lmax`` (~10) -- -were slower than cuBLAS. The Triton block-diagonal kernels were chosen for -production because their speed (2-8x over the dense baseline) and native -``torch.compile`` composability outweigh the CuTe memory advantage in the target -``lmax`` 2-5, ``mmax == 1`` regime. - -Operator definitions (ground truth, fp32) ------------------------------------------ -Let ``x`` be packed node features, ``src`` the per-edge source-node indices, -``wigner`` the per-edge block-diagonal Wigner-D matrices, ``coeff_index`` the -``m``-major reduced-layout indices and ``dim_full = D`` the full packed SO(3) -dimension (``D <= Dw`` where ``wigner`` is ``(E, Dw, Dw)``). - -``rotate_to_local`` lifts global node features into the per-edge local frame and -truncates to the reduced layout in one fused step:: - - out[e, i, c] = sum_j wigner[e, coeff_index[i], j] * x[src[e], j, c] - # i in [0, Dm), j in [0, D), c in [0, C) - -``rotate_back`` is the (column-selected) inverse rotation:: - - out[e, i, c] = sum_j wigner[e, i, coeff_index[j]] * x_local[e, j, c] - # i in [0, D), j in [0, Dm), c in [0, C) - -Both operators are batched (one tiny GEMM per edge) with two gathers fused in: -the Wigner row/column selection by ``coeff_index`` and the source-node gather by -``src``. Fusing the gathers means the large ``D_to_m`` ``(E, Dm, D)`` and -``x_src`` ``(E, D, C)`` intermediates produced by the eager ``index_select`` + -``bmm`` reference are never written to or read from global memory, which is the -main source of the speed/peak-memory advantage. - -Backward (both feature *and* ``wigner`` gradients, required for forces) ----------------------------------------------------------------------- -``rotate_to_local``:: - - grad_edge[e, j, c] = sum_i wigner[e, coeff_index[i], j] * grad_out[e, i, c] - grad_x = scatter_add(grad_edge, dim=0, index=src) # (N, D, C) - grad_wigner[e, coeff_index[i], j] = sum_c grad_out[e, i, c] * x[src[e], j, c] - -``rotate_back``:: - - grad_x_local[e, j, c] = sum_i wigner[e, i, coeff_index[j]] * grad_out[e, i, c] - grad_wigner[e, i, coeff_index[j]] = sum_c grad_out[e, i, c] * x_local[e, j, c] - -(all other entries of ``grad_wigner`` are zero). - -Kernel design -------------- -Every kernel computes, per edge, a small matrix product ``out = A @ B`` (with -one operand gathered) using a **2D register-blocked GEMM**: - -* one CUDA block per edge; -* the operand whose layout is ``(K, C)`` (the source-node / local / grad_out - tile) is staged once into shared memory; -* each thread owns a ``TM x TN`` register tile of the output and sweeps the - contraction dimension ``K``, loading ``TM`` values of ``A`` and ``TN`` values - of ``B`` per step and issuing ``TM*TN`` FFMAs. This pushes the load:FFMA ratio - to ``(TM+TN)/(TM*TN)`` so the kernel is compute-bound rather than - load/store-unit bound; -* the per-output-row Wigner index gather (``coeff_index``) is hoisted out of the - contraction loop into registers. - -The two ``grad_wigner`` kernels are batched outer products (contraction over the -channel axis ``C``) and use the same register-blocked skeleton with a 2D tile -sweep over the ``(Dm, D)`` output. When both per-edge operands fit in shared -memory (small/medium ``lmax``) both are staged there; otherwise only -``grad_out`` is staged and the other operand streams from global memory through -L1. The ``rotate_to_local`` ``grad_x`` contribution is fused with its -source-node scatter via atomic adds, so neither a ``grad_edge`` intermediate nor -a separate ``index_add`` is materialized. - -All accumulation is fp32 (no TF32), keeping the potential-energy surface smooth. - -Composability -------------- -The kernels are wrapped with ``torch.library.custom_op`` (functional, -``mutates_args=()``) plus ``register_fake`` and ``register_autograd``. The -backward is itself a custom op, so ``torch.compile`` can include and -differentiate the whole thing as an opaque, side-effect-free operator. Kernels -are launched on torch's current CUDA stream so they order correctly with the -surrounding eager / compiled graph. -""" - -from __future__ import ( - annotations, -) - -import threading -from typing import ( - TYPE_CHECKING, - Any, -) - -import torch -from torch import ( - Tensor, -) - -if TYPE_CHECKING: - from collections.abc import ( - Callable, - ) - - from cuda.bindings import driver as _cuda_driver - -try: - import cutlass - import cutlass.cute as cute - import cutlass.torch as cutlass_torch - from cutlass.cute.runtime import ( - from_dlpack, - ) - - SEZM_CUTE_AVAILABLE = True -except Exception: # pragma: no cover - import guard for non-CuTe environments - SEZM_CUTE_AVAILABLE = False - - -# === Kernel tuning constants ================================================= -# Register-tile dimensions (TM output rows x TN output cols per thread) and the -# block thread geometry. ``C`` (= 64) is the channel axis and is the N dimension -# for the matmul-like kernels; ``TN`` divides ``C``. -_TM = 4 -_TN = 4 -_BLOCK_ROWS = 16 # block.y for matmul-like kernels (block.x = C // TN) - -# grad_wigner: budget (bytes) below which both operands are staged in shared -# memory (the fast path); above it (e.g. lmax=10) only grad_out is staged and -# the other operand streams from global memory through L1. -_GW_SMEM_BUDGET = 46000 - - -def _gw_tile(D: int, Dm: int, C: int) -> tuple[int, int, int, int, bool]: - """Pick (TM, TN, BX, BY, both_in_smem) for a grad_wigner output of (M, N). - - The register tile and block geometry are chosen so the block is well - occupied for the given output size, and both operands are staged in shared - memory when the per-edge tiles fit inside ``_GW_SMEM_BUDGET``. - """ - both = (Dm + D) * C * 4 <= _GW_SMEM_BUDGET - if D <= 20: # small output (e.g. lmax=3): keep the tile/block small - return 2, 2, 16, 16, both - if D <= 50: # medium output (e.g. lmax=5) - return 8, 8, 8, 8, both - if both: # large output that still fits both operands (e.g. lmax=7) - return 8, 4, 8, 16, both - return 8, 8, 8, 8, both # large output, only grad_out staged (e.g. lmax=10) - - -# === Eager reference (ground truth, also used as fallback) =================== -def _rotate_to_local_eager( - x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int -) -> Tensor: - """Reference ``D_to_m @ x[src]`` used for fallback and validation.""" - d_to_m = wigner[:, :dim_full, :dim_full].index_select(1, coeff_index) - return torch.bmm(d_to_m, x.index_select(0, src)) - - -def _rotate_back_eager( - x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int -) -> Tensor: - """Reference ``Dt_from_m @ x_local`` used for fallback and validation.""" - dt_from_m = wigner[:, :dim_full, :dim_full].index_select(2, coeff_index) - return torch.bmm(dt_from_m, x_local) - - -if SEZM_CUTE_AVAILABLE: - _F32 = cutlass.Float32 - _I64 = cutlass.Int64 - - # ------------------------------------------------------------------ - # Family 1: out(M, C) = A(M, K) @ S(K, C), with S staged in shared - # memory and A read from the Wigner tensor with a per-element gather. - # Specialized by how A[m, k] maps into the (Dw, Dw) Wigner block. - # ------------------------------------------------------------------ - def _build_rotate_to_local_fwd(D: int, Dm: int, C: int) -> Callable: - """``out[m=i, n=c] = sum_{k=j} wigner[e, idx[m], k] * x[src[e], k, n]``.""" - M, K = Dm, D - TM, TN, BY = _TM, _TN, _BLOCK_ROWS - BX = C // TN - T = BX * BY - - @cute.kernel - def kernel(m_x, m_src, m_w, m_idx, m_out) -> None: - e, _, _ = cute.arch.block_idx() - cx, ry, _ = cute.arch.thread_idx() - smem = cute.arch.alloc_smem(_F32, K * C) - s_s = cute.make_tensor(smem, cute.make_layout((K, C), stride=(C, 1))) - src_e = m_src[e] - x_node = m_x[src_e, None, None] - tid = ry * BX + cx - for kk in cutlass.range(tid, K * C, T): - s_s[kk // C, kk % C] = x_node[kk // C, kk % C] - cute.arch.sync_threads() - - w_e = m_w[e, None, None] - out_e = m_out[e, None, None] - for rt0 in cutlass.range(ry * TM, M, BY * TM): - acc = cute.make_fragment((TM, TN), _F32) - wi = cute.make_fragment((TM,), _I64) - bf = cute.make_fragment((TN,), _F32) - for t in range(TM): - wi[t] = m_idx[(rt0 + t) % M] # gathered Wigner row - for n in range(TN): - acc[t, n] = _F32(0.0) - for k in cutlass.range(K): - for n in range(TN): - bf[n] = s_s[k, cx * TN + n] - for t in range(TM): - a = w_e[wi[t], k] - for n in range(TN): - acc[t, n] = acc[t, n] + a * bf[n] - for t in range(TM): - m = rt0 + t - if m < M: - for n in range(TN): - out_e[m, cx * TN + n] = acc[t, n] - - @cute.jit - def host(m_x, m_src, m_w, m_idx, m_out, stream: _cuda_driver.CUstream) -> None: - e = m_out.shape[0] - kernel(m_x, m_src, m_w, m_idx, m_out).launch( - grid=[e, 1, 1], block=[BX, BY, 1], stream=stream - ) - - return host - - def _build_rotate_back_fwd(D: int, Dm: int, C: int) -> Callable: - """``out[m=i, n=c] = sum_{k=j} wigner[e, m, idx[k]] * x_local[e, k, n]``.""" - M, K = D, Dm - TM, TN, BY = _TM, _TN, _BLOCK_ROWS - BX = C // TN - T = BX * BY - - @cute.kernel - def kernel(m_xl, m_w, m_idx, m_out) -> None: - e, _, _ = cute.arch.block_idx() - cx, ry, _ = cute.arch.thread_idx() - smem = cute.arch.alloc_smem(_F32, K * C) - s_s = cute.make_tensor(smem, cute.make_layout((K, C), stride=(C, 1))) - xl_e = m_xl[e, None, None] - tid = ry * BX + cx - for kk in cutlass.range(tid, K * C, T): - s_s[kk // C, kk % C] = xl_e[kk // C, kk % C] - cute.arch.sync_threads() - - w_e = m_w[e, None, None] - out_e = m_out[e, None, None] - for rt0 in cutlass.range(ry * TM, M, BY * TM): - acc = cute.make_fragment((TM, TN), _F32) - wr = cute.make_fragment((TM,), _I64) - bf = cute.make_fragment((TN,), _F32) - for t in range(TM): - wr[t] = (rt0 + t) % M # direct Wigner row - for n in range(TN): - acc[t, n] = _F32(0.0) - for k in cutlass.range(K): - kk = m_idx[k] # gathered Wigner column - for n in range(TN): - bf[n] = s_s[k, cx * TN + n] - for t in range(TM): - a = w_e[wr[t], kk] - for n in range(TN): - acc[t, n] = acc[t, n] + a * bf[n] - for t in range(TM): - m = rt0 + t - if m < M: - for n in range(TN): - out_e[m, cx * TN + n] = acc[t, n] - - @cute.jit - def host(m_xl, m_w, m_idx, m_out, stream: _cuda_driver.CUstream) -> None: - e = m_out.shape[0] - kernel(m_xl, m_w, m_idx, m_out).launch( - grid=[e, 1, 1], block=[BX, BY, 1], stream=stream - ) - - return host - - def _build_rotate_to_local_bwd_dx(D: int, Dm: int, C: int) -> Callable: - """``grad_x[src[e], m=j, n=c] += sum_{k=i} wigner[e, idx[k], m] * grad_out[e, k, n]``. - - The per-edge gradient and the scatter-add into ``grad_x`` (indexed by - ``src``) are fused: each block accumulates its tile and atomically adds it - into the destination node. This avoids a materialized ``grad_edge`` tensor - and a separate ``index_add`` pass. - """ - M, K = D, Dm - TM, TN, BY = _TM, _TN, _BLOCK_ROWS - BX = C // TN - T = BX * BY - - @cute.kernel - def kernel(m_go, m_w, m_src, m_idx, m_gx) -> None: - e, _, _ = cute.arch.block_idx() - cx, ry, _ = cute.arch.thread_idx() - smem = cute.arch.alloc_smem(_F32, K * C) - s_s = cute.make_tensor(smem, cute.make_layout((K, C), stride=(C, 1))) - go_e = m_go[e, None, None] - tid = ry * BX + cx - for kk in cutlass.range(tid, K * C, T): - s_s[kk // C, kk % C] = go_e[kk // C, kk % C] - cute.arch.sync_threads() - - w_e = m_w[e, None, None] - gx_node = m_gx[m_src[e], None, None] # (D, C) view into grad_x[src] - gx_base = gx_node.iterator # contiguous (C, 1): element (m, c) -> m*C + c - for rt0 in cutlass.range(ry * TM, M, BY * TM): - acc = cute.make_fragment((TM, TN), _F32) - wc = cute.make_fragment((TM,), _I64) - bf = cute.make_fragment((TN,), _F32) - for t in range(TM): - wc[t] = (rt0 + t) % M # direct Wigner column (= output row m) - for n in range(TN): - acc[t, n] = _F32(0.0) - for k in cutlass.range(K): - kk = m_idx[k] # gathered Wigner row - for n in range(TN): - bf[n] = s_s[k, cx * TN + n] - for t in range(TM): - a = w_e[kk, wc[t]] - for n in range(TN): - acc[t, n] = acc[t, n] + a * bf[n] - for t in range(TM): - m = rt0 + t - if m < M: - for n in range(TN): - cute.arch.atomic_add( - gx_base + (m * C + cx * TN + n), acc[t, n] - ) - - @cute.jit - def host(m_go, m_w, m_src, m_idx, m_gx, stream: _cuda_driver.CUstream) -> None: - e = m_go.shape[0] - kernel(m_go, m_w, m_src, m_idx, m_gx).launch( - grid=[e, 1, 1], block=[BX, BY, 1], stream=stream - ) - - return host - - def _build_rotate_back_bwd_dx(D: int, Dm: int, C: int) -> Callable: - """``grad_x_local[m=j, n=c] = sum_{k=i} wigner[e, k, idx[m]] * grad_out[e, k, n]``.""" - M, K = Dm, D - TM, TN, BY = _TM, _TN, _BLOCK_ROWS - BX = C // TN - T = BX * BY - - @cute.kernel - def kernel(m_go, m_w, m_idx, m_gxl) -> None: - e, _, _ = cute.arch.block_idx() - cx, ry, _ = cute.arch.thread_idx() - smem = cute.arch.alloc_smem(_F32, K * C) - s_s = cute.make_tensor(smem, cute.make_layout((K, C), stride=(C, 1))) - go_e = m_go[e, None, None] - tid = ry * BX + cx - for kk in cutlass.range(tid, K * C, T): - s_s[kk // C, kk % C] = go_e[kk // C, kk % C] - cute.arch.sync_threads() - - w_e = m_w[e, None, None] - gxl_e = m_gxl[e, None, None] - for rt0 in cutlass.range(ry * TM, M, BY * TM): - acc = cute.make_fragment((TM, TN), _F32) - wc = cute.make_fragment((TM,), _I64) - bf = cute.make_fragment((TN,), _F32) - for t in range(TM): - wc[t] = m_idx[(rt0 + t) % M] # gathered Wigner column - for n in range(TN): - acc[t, n] = _F32(0.0) - for k in cutlass.range(K): - for n in range(TN): - bf[n] = s_s[k, cx * TN + n] - for t in range(TM): - a = w_e[k, wc[t]] - for n in range(TN): - acc[t, n] = acc[t, n] + a * bf[n] - for t in range(TM): - m = rt0 + t - if m < M: - for n in range(TN): - gxl_e[m, cx * TN + n] = acc[t, n] - - @cute.jit - def host(m_go, m_w, m_idx, m_gxl, stream: _cuda_driver.CUstream) -> None: - e = m_go.shape[0] - kernel(m_go, m_w, m_idx, m_gxl).launch( - grid=[e, 1, 1], block=[BX, BY, 1], stream=stream - ) - - return host - - # ------------------------------------------------------------------ - # Family 2: grad_wigner = grad_out @ other^T (contraction over the - # channel axis C). 2D register-blocked sweep over the (M, N) output, - # grad_out staged in shared memory, other read from global memory. - # ------------------------------------------------------------------ - def _build_rotate_to_local_bwd_dw(D: int, Dm: int, C: int) -> Callable: - """``grad_wigner[e, idx[m=i], n=j] = sum_{k=c} grad_out[e, m, k] * x[src[e], n, k]``.""" - M, N, K = Dm, D, C - TM, TN, BX, BY, both = _gw_tile(D, Dm, C) - T = BX * BY - - @cute.kernel - def kernel(m_go, m_x, m_src, m_idx, m_gw) -> None: - e, _, _ = cute.arch.block_idx() - cx, ry, _ = cute.arch.thread_idx() - sgo = cute.arch.alloc_smem(_F32, M * C) - s_go = cute.make_tensor(sgo, cute.make_layout((M, C), stride=(C, 1))) - go_e = m_go[e, None, None] - src_e = m_src[e] - x_node = m_x[src_e, None, None] # (N=D, C) - tid = ry * BX + cx - for kk in cutlass.range(tid, M * C, T): - s_go[kk // C, kk % C] = go_e[kk // C, kk % C] - # Optionally stage the second operand in shared memory too. - sx = cute.arch.alloc_smem(_F32, (N * C) if both else 1) - s_x = cute.make_tensor( - sx, cute.make_layout(((N, C) if both else (1, 1)), stride=(C, 1)) - ) - if cutlass.const_expr(both): - for kk in cutlass.range(tid, N * C, T): - s_x[kk // C, kk % C] = x_node[kk // C, kk % C] - cute.arch.sync_threads() - - gw_e = m_gw[e, None, None] - for mt0 in cutlass.range(ry * TM, M, BY * TM): - orow = cute.make_fragment((TM,), _I64) - rt = cute.make_fragment((TM,), cutlass.Int32) - for t in range(TM): - rt[t] = (mt0 + t) % M # clamped smem row (hoisted out of K loop) - orow[t] = m_idx[rt[t]] # gathered output row - for nt0 in cutlass.range(cx * TN, N, BX * TN): - acc = cute.make_fragment((TM, TN), _F32) - af = cute.make_fragment((TM,), _F32) - bf = cute.make_fragment((TN,), _F32) - ct = cute.make_fragment((TN,), cutlass.Int32) - for n in range(TN): - ct[n] = (nt0 + n) % N # clamped col (hoisted) - for t in range(TM): - for n in range(TN): - acc[t, n] = _F32(0.0) - for k in cutlass.range(K): - for t in range(TM): - af[t] = s_go[rt[t], k] - if cutlass.const_expr(both): - for n in range(TN): - bf[n] = s_x[ct[n], k] - else: - for n in range(TN): - bf[n] = x_node[ct[n], k] - for t in range(TM): - for n in range(TN): - acc[t, n] = acc[t, n] + af[t] * bf[n] - for t in range(TM): - if mt0 + t < M: - for n in range(TN): - if nt0 + n < N: - gw_e[orow[t], nt0 + n] = acc[t, n] - - @cute.jit - def host(m_go, m_x, m_src, m_idx, m_gw, stream: _cuda_driver.CUstream) -> None: - e = m_go.shape[0] - kernel(m_go, m_x, m_src, m_idx, m_gw).launch( - grid=[e, 1, 1], block=[BX, BY, 1], stream=stream - ) - - return host - - def _build_rotate_back_bwd_dw(D: int, Dm: int, C: int) -> Callable: - """``grad_wigner[e, m=i, idx[n=j]] = sum_{k=c} grad_out[e, m, k] * x_local[e, n, k]``.""" - M, N, K = D, Dm, C - TM, TN, BX, BY, both = _gw_tile(D, Dm, C) - T = BX * BY - - @cute.kernel - def kernel(m_go, m_xl, m_idx, m_gw) -> None: - e, _, _ = cute.arch.block_idx() - cx, ry, _ = cute.arch.thread_idx() - sgo = cute.arch.alloc_smem(_F32, M * C) - s_go = cute.make_tensor(sgo, cute.make_layout((M, C), stride=(C, 1))) - go_e = m_go[e, None, None] - xl_e = m_xl[e, None, None] # (N=Dm, C) - tid = ry * BX + cx - for kk in cutlass.range(tid, M * C, T): - s_go[kk // C, kk % C] = go_e[kk // C, kk % C] - sx = cute.arch.alloc_smem(_F32, (N * C) if both else 1) - s_x = cute.make_tensor( - sx, cute.make_layout(((N, C) if both else (1, 1)), stride=(C, 1)) - ) - if cutlass.const_expr(both): - for kk in cutlass.range(tid, N * C, T): - s_x[kk // C, kk % C] = xl_e[kk // C, kk % C] - cute.arch.sync_threads() - - gw_e = m_gw[e, None, None] - for mt0 in cutlass.range(ry * TM, M, BY * TM): - rt = cute.make_fragment((TM,), cutlass.Int32) - for t in range(TM): - rt[t] = (mt0 + t) % M # clamped smem row (hoisted out of K loop) - for nt0 in cutlass.range(cx * TN, N, BX * TN): - acc = cute.make_fragment((TM, TN), _F32) - ocol = cute.make_fragment((TN,), _I64) - ct = cute.make_fragment((TN,), cutlass.Int32) - af = cute.make_fragment((TM,), _F32) - bf = cute.make_fragment((TN,), _F32) - for n in range(TN): - ct[n] = (nt0 + n) % N # clamped col (hoisted) - ocol[n] = m_idx[ct[n]] # gathered output column - for t in range(TM): - for n in range(TN): - acc[t, n] = _F32(0.0) - for k in cutlass.range(K): - for t in range(TM): - af[t] = s_go[rt[t], k] - if cutlass.const_expr(both): - for n in range(TN): - bf[n] = s_x[ct[n], k] - else: - for n in range(TN): - bf[n] = xl_e[ct[n], k] - for t in range(TM): - for n in range(TN): - acc[t, n] = acc[t, n] + af[t] * bf[n] - for t in range(TM): - i = mt0 + t - if i < M: - for n in range(TN): - if nt0 + n < N: - gw_e[i, ocol[n]] = acc[t, n] - - @cute.jit - def host(m_go, m_xl, m_idx, m_gw, stream: _cuda_driver.CUstream) -> None: - e = m_go.shape[0] - kernel(m_go, m_xl, m_idx, m_gw).launch( - grid=[e, 1, 1], block=[BX, BY, 1], stream=stream - ) - - return host - - # === Compiled-kernel cache ============================================== - _compiled_cache: dict[tuple, Any] = {} - _cache_lock = threading.Lock() - - def _get_compiled(key: tuple, builder: Callable, example_args: tuple) -> Any: - """Return a JIT-compiled host function, compiling and caching on miss.""" - comp = _compiled_cache.get(key) - if comp is not None: - return comp - with _cache_lock: - comp = _compiled_cache.get(key) - if comp is None: - host = builder(*key[1:]) - comp = cute.compile(host, *example_args) - _compiled_cache[key] = comp - return comp - - def _cute_f(t: Tensor) -> Any: - """Wrap a contiguous (>=2D) fp32 tensor as a CuTe tensor (last dim leading).""" - return from_dlpack(t).mark_layout_dynamic(leading_dim=t.dim() - 1) - - def _cute_i(t: Tensor) -> Any: - """Wrap a contiguous 1D int64 tensor as a CuTe tensor.""" - return from_dlpack(t).mark_layout_dynamic() - - # === Low-level kernel dispatch (operate on plain, detached tensors) ====== - def _launch_rotate_to_local_fwd( - x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int - ) -> Tensor: - e = src.shape[0] - d, dm, c = dim_full, coeff_index.shape[0], x.shape[2] - out = torch.empty((e, dm, c), dtype=x.dtype, device=x.device) - m_x, m_src, m_w = _cute_f(x), _cute_i(src), _cute_f(wigner) - m_idx, m_out = _cute_i(coeff_index), _cute_f(out) - stream = cutlass_torch.current_stream() - comp = _get_compiled( - ("rtl_fwd", d, dm, c), - _build_rotate_to_local_fwd, - (m_x, m_src, m_w, m_idx, m_out, stream), - ) - comp(m_x, m_src, m_w, m_idx, m_out, stream) - return out - - def _launch_rotate_back_fwd( - x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int - ) -> Tensor: - e = x_local.shape[0] - d, dm, c = dim_full, coeff_index.shape[0], x_local.shape[2] - out = torch.empty((e, d, c), dtype=x_local.dtype, device=x_local.device) - m_xl, m_w = _cute_f(x_local), _cute_f(wigner) - m_idx, m_out = _cute_i(coeff_index), _cute_f(out) - stream = cutlass_torch.current_stream() - comp = _get_compiled( - ("rb_fwd", d, dm, c), - _build_rotate_back_fwd, - (m_xl, m_w, m_idx, m_out, stream), - ) - comp(m_xl, m_w, m_idx, m_out, stream) - return out - - def _launch_rotate_to_local_bwd( - grad_out: Tensor, - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> tuple[Tensor, Tensor]: - n, e = x.shape[0], src.shape[0] - d, dm, c = dim_full, coeff_index.shape[0], x.shape[2] - stream = cutlass_torch.current_stream() - - # grad_x: per-edge gradient fused with the scatter-add into the source - # node via atomic adds (no materialized grad_edge, no separate index_add). - grad_x = torch.zeros((n, d, c), dtype=x.dtype, device=x.device) - m_go, m_w = _cute_f(grad_out), _cute_f(wigner) - m_src, m_idx, m_gx = _cute_i(src), _cute_i(coeff_index), _cute_f(grad_x) - comp_dx = _get_compiled( - ("rtl_bwd_dx", d, dm, c), - _build_rotate_to_local_bwd_dx, - (m_go, m_w, m_src, m_idx, m_gx, stream), - ) - comp_dx(m_go, m_w, m_src, m_idx, m_gx, stream) - - # grad_wigner: per-edge outer product written into the gathered rows. - grad_wigner = torch.zeros_like(wigner) - m_x, m_gw = _cute_f(x), _cute_f(grad_wigner) - comp_dw = _get_compiled( - ("rtl_bwd_dw", d, dm, c), - _build_rotate_to_local_bwd_dw, - (m_go, m_x, m_src, m_idx, m_gw, stream), - ) - comp_dw(m_go, m_x, m_src, m_idx, m_gw, stream) - return grad_x, grad_wigner - - def _launch_rotate_back_bwd( - grad_out: Tensor, - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> tuple[Tensor, Tensor]: - e = x_local.shape[0] - d, dm, c = dim_full, coeff_index.shape[0], x_local.shape[2] - stream = cutlass_torch.current_stream() - - grad_x_local = torch.empty( - (e, dm, c), dtype=x_local.dtype, device=x_local.device - ) - m_go, m_w = _cute_f(grad_out), _cute_f(wigner) - m_idx, m_gxl = _cute_i(coeff_index), _cute_f(grad_x_local) - comp_dx = _get_compiled( - ("rb_bwd_dx", d, dm, c), - _build_rotate_back_bwd_dx, - (m_go, m_w, m_idx, m_gxl, stream), - ) - comp_dx(m_go, m_w, m_idx, m_gxl, stream) - - grad_wigner = torch.zeros_like(wigner) - m_xl, m_gw = _cute_f(x_local), _cute_f(grad_wigner) - comp_dw = _get_compiled( - ("rb_bwd_dw", d, dm, c), - _build_rotate_back_bwd_dw, - (m_go, m_xl, m_idx, m_gw, stream), - ) - comp_dw(m_go, m_xl, m_idx, m_gw, stream) - return grad_x_local, grad_wigner - - # === torch.library custom ops =========================================== - # Forward + backward are registered as functional custom ops so the whole - # operator is opaque to torch.compile yet correctly differentiable. - - @torch.library.custom_op( - "sezm_cute::rotate_to_local", mutates_args=(), device_types="cuda" - ) - def _op_rotate_to_local( - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> Tensor: - return _launch_rotate_to_local_fwd( - x.detach().contiguous(), - src.detach().contiguous(), - wigner.detach().contiguous(), - coeff_index.detach().contiguous(), - int(dim_full), - ) - - @_op_rotate_to_local.register_fake - def _( - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> Tensor: - return x.new_empty((src.shape[0], coeff_index.shape[0], x.shape[2])) - - @torch.library.custom_op( - "sezm_cute::rotate_to_local_bwd", mutates_args=(), device_types="cuda" - ) - def _op_rotate_to_local_bwd( - grad_out: Tensor, - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> tuple[Tensor, Tensor]: - return _launch_rotate_to_local_bwd( - grad_out.detach().contiguous(), - x.detach().contiguous(), - src.detach().contiguous(), - wigner.detach().contiguous(), - coeff_index.detach().contiguous(), - int(dim_full), - ) - - @_op_rotate_to_local_bwd.register_fake - def _( - grad_out: Tensor, - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> tuple[Tensor, Tensor]: - return torch.empty_like(x), torch.empty_like(wigner) - - def _rtl_setup_context(ctx: Any, inputs: tuple, output: Tensor) -> None: - x, src, wigner, coeff_index, dim_full = inputs - ctx.save_for_backward(x, src, wigner, coeff_index) - ctx.dim_full = int(dim_full) - - def _rtl_backward(ctx: Any, grad_out: Tensor) -> tuple: - x, src, wigner, coeff_index = ctx.saved_tensors - grad_x, grad_wigner = torch.ops.sezm_cute.rotate_to_local_bwd( - grad_out, x, src, wigner, coeff_index, ctx.dim_full - ) - return grad_x, None, grad_wigner, None, None - - _op_rotate_to_local.register_autograd( - _rtl_backward, setup_context=_rtl_setup_context - ) - - @torch.library.custom_op( - "sezm_cute::rotate_back", mutates_args=(), device_types="cuda" - ) - def _op_rotate_back( - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> Tensor: - return _launch_rotate_back_fwd( - x_local.detach().contiguous(), - wigner.detach().contiguous(), - coeff_index.detach().contiguous(), - int(dim_full), - ) - - @_op_rotate_back.register_fake - def _( - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> Tensor: - return x_local.new_empty((x_local.shape[0], dim_full, x_local.shape[2])) - - @torch.library.custom_op( - "sezm_cute::rotate_back_bwd", mutates_args=(), device_types="cuda" - ) - def _op_rotate_back_bwd( - grad_out: Tensor, - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> tuple[Tensor, Tensor]: - return _launch_rotate_back_bwd( - grad_out.detach().contiguous(), - x_local.detach().contiguous(), - wigner.detach().contiguous(), - coeff_index.detach().contiguous(), - int(dim_full), - ) - - @_op_rotate_back_bwd.register_fake - def _( - grad_out: Tensor, - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - ) -> tuple[Tensor, Tensor]: - return torch.empty_like(x_local), torch.empty_like(wigner) - - def _rb_setup_context(ctx: Any, inputs: tuple, output: Tensor) -> None: - x_local, wigner, coeff_index, dim_full = inputs - ctx.save_for_backward(x_local, wigner, coeff_index) - ctx.dim_full = int(dim_full) - - def _rb_backward(ctx: Any, grad_out: Tensor) -> tuple: - x_local, wigner, coeff_index = ctx.saved_tensors - grad_x_local, grad_wigner = torch.ops.sezm_cute.rotate_back_bwd( - grad_out, x_local, wigner, coeff_index, ctx.dim_full - ) - return grad_x_local, grad_wigner, None, None - - _op_rotate_back.register_autograd(_rb_backward, setup_context=_rb_setup_context) - - -# === Public API ============================================================== -def _cute_usable(channels: int, *tensors: Tensor) -> bool: - """Return True when the CuTe fast path is available for these tensors.""" - if not SEZM_CUTE_AVAILABLE: - return False - if int(channels) < _TN or int(channels) % _TN != 0: - return False - return all( - t.is_cuda and t.dtype == torch.float32 for t in tensors if t.is_floating_point() - ) - - -def rotate_to_local_cute( - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - """ - Fused ``global -> local reduced`` rotation (CuTe fast path with eager fallback). - - Parameters - ---------- - x - Node features with shape ``(N, D, C)``. - src - Source-node indices with shape ``(E,)``. - wigner - Packed Wigner-D matrices with shape ``(E, Dw, Dw)`` (``Dw >= dim_full``). - coeff_index - Reduced-layout row indices with shape ``(Dm,)``. - dim_full - Full packed SO(3) dimension ``D``. - - Returns - ------- - Tensor - Rotated reduced-layout edge features with shape ``(E, Dm, C)``. - - Notes - ----- - Experimental path that is not used in production. See the module docstring - for the benchmark conclusion and why the Triton kernels were chosen instead. - """ - if _cute_usable(x.shape[2], x, wigner) and src.numel() > 0: - return torch.ops.sezm_cute.rotate_to_local( - x, src, wigner, coeff_index, int(dim_full) - ) - return _rotate_to_local_eager(x, src, wigner, coeff_index, dim_full) - - -def rotate_back_cute( - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - """ - Fused ``local reduced -> global`` rotation (CuTe fast path with eager fallback). - - Parameters - ---------- - x_local - Reduced-layout edge features with shape ``(E, Dm, C)``. - wigner - Packed Wigner-D matrices with shape ``(E, Dw, Dw)`` (``Dw >= dim_full``). - coeff_index - Reduced-layout column indices with shape ``(Dm,)``. - dim_full - Full packed SO(3) dimension ``D``. - - Returns - ------- - Tensor - Lifted global-layout edge features with shape ``(E, D, C)``. - - Notes - ----- - Experimental path that is not used in production. See the module docstring - for the benchmark conclusion and why the Triton kernels were chosen instead. - """ - if _cute_usable(x_local.shape[2], x_local, wigner) and x_local.shape[0] > 0: - return torch.ops.sezm_cute.rotate_back( - x_local, wigner, coeff_index, int(dim_full) - ) - return _rotate_back_eager(x_local, wigner, coeff_index, dim_full) diff --git a/deepmd/pt/model/descriptor/sezm_nn/embedding.py b/deepmd/pt/model/descriptor/sezm_nn/embedding.py index e31d7e2b65..c01c3060b5 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/embedding.py +++ b/deepmd/pt/model/descriptor/sezm_nn/embedding.py @@ -39,6 +39,9 @@ check_version_compatibility, ) +from .cartesian import ( + build_cartesian_basis, +) from .indexing import ( build_gie_zonal_index, get_so3_dim_of_lmax, @@ -193,6 +196,14 @@ def __init__( node_radial_l_index, persistent=True, ) + # The l=1 coefficients (packed rows 1..3) are the first three entries of + # the non-scalar sequence ``node_row_index = [1, 2, ..., D-1]``, so the + # native neighbor-spin l=1 message folds in at these local positions. + self.register_buffer( + "l1_local_index", + torch.arange(3, device=self.device, dtype=torch.long), + persistent=False, + ) def forward( self, @@ -201,6 +212,7 @@ def forward( edge_cache: EdgeFeatureCache, radial_feat: torch.Tensor, zonal_coupling: torch.Tensor | None = None, + spin_l1_message: torch.Tensor | None = None, ) -> torch.Tensor: """ Parameters @@ -214,6 +226,12 @@ def forward( zonal_coupling Optional precomputed zonal coupling with shape (E, D-1). If None, it is gathered from ``edge_cache.Dt_full``. + spin_l1_message + Optional per-edge neighbor-spin l=1 message with shape (E, 3, C) for + the native spin scheme (built by ``SpinEmbedding.edge_l1``). It is + added to the l=1 rows of the per-edge message, so it shares this + module's source gate, scatter and degree normalization with the + geometric message. Returns ------- @@ -249,6 +267,15 @@ def forward( zonal_coupling.unsqueeze(-1) * radial_value_for_row ) # (E, D-1, C) + # === Step 3b. Fold in the neighbor-spin l=1 message (native spin) === + # The l=1 coefficients occupy the first three packed non-scalar rows, so + # the neighbor-spin message joins the geometric message there and then + # shares the source gate, scatter and degree normalization below. + if spin_l1_message is not None: + non_scalar_message = non_scalar_message.index_add( + 1, self.l1_local_index, spin_l1_message + ) + # === Step 4. Source Freeze Propagation Gate (optional) === # Mute messages emitted by nodes whose local neighborhood enters # the frozen zone. ``edge_src_gate`` is ``None`` outside bridging @@ -335,6 +362,12 @@ class EnvironmentInitialEmbedding(nn.Module): Activation function for G network hidden layer. eps : float Small epsilon for numerical stability. + use_spin : list[bool] | None + Per-type spin flags (native spin scheme). When provided, the neighbor + spin is appended as extra coordinate channels of the environment matrix, + so the inner product ``D = M^T M`` additionally yields the neighbor + spin-spin invariants. A per-type mask gates the channel, so a + non-magnetic neighbor contributes zero and carries zero magnetic force. dtype : torch.dtype Parameter dtype. trainable : bool @@ -356,6 +389,7 @@ def __init__( mlp_bias: bool = False, activation_function: str = "silu", eps: float = 1e-7, + use_spin: list[bool] | None = None, dtype: torch.dtype, trainable: bool, seed: int | list[int] | None = None, @@ -378,9 +412,17 @@ def __init__( self.mlp_bias = bool(mlp_bias) self.activation_function = str(activation_function) self.eps = float(eps) + self.spin_flags = None if use_spin is None else [bool(x) for x in use_spin] + if self.spin_flags is not None and len(self.spin_flags) != int(ntypes): + raise ValueError("`use_spin` length must equal `ntypes`") self.dtype = dtype self.device = env.DEVICE self.precision = RESERVED_PRECISION_DICT[dtype] + # The environment matrix carries the 4 geometric channels ``[s, s*r_hat]`` + # plus, for the native spin scheme, the 3 envelope-gated neighbor-spin + # components, so the inner product ``D = M^T M`` yields the neighbor + # spin-spin invariants alongside the geometric ones. + self.coord_dim = 4 + (3 if self.spin_flags is not None else 0) self.register_buffer( "eps_sq_tensor", torch.tensor(self.eps * self.eps, dtype=self.dtype, device=self.device), @@ -454,6 +496,25 @@ def __init__( seed=seed_out, ) + # === Native spin: per-type mask and isotropic channel scale === + # The mask gates the neighbor-spin channel by source type, so a + # non-magnetic neighbor contributes zero and (critically) carries zero + # magnetic force ``-dE/ds``. The single scalar scale (shared across + # x/y/z) keeps the spin coordinates transforming with the geometry, so + # the env-matrix invariant stays SO(3)-invariant; ``output_proj`` is + # zero-initialized, so the spin contribution starts neutral regardless. + if self.spin_flags is not None: + spin_mask = torch.tensor( + [1.0 if flag else 0.0 for flag in self.spin_flags], + dtype=self.dtype, + device=self.device, + ) + self.register_buffer("spin_mask", spin_mask, persistent=False) + self.spin_scale = nn.Parameter( + torch.ones(1, dtype=self.dtype, device=self.device), + requires_grad=trainable, + ) + for p in self.parameters(): p.requires_grad = trainable @@ -463,6 +524,7 @@ def forward( edge_cache: EdgeFeatureCache, atype_flat: torch.Tensor, n_nodes: int, + spin: torch.Tensor | None = None, ) -> torch.Tensor: """ Compute environment FiLM logits for l=0 conditioning. @@ -475,6 +537,12 @@ def forward( Flattened atom types with shape (N,), where N = nf * nloc. n_nodes : int Number of nodes (N = nf * nloc). + spin : torch.Tensor | None + Per-node spin vectors with shape (N, 3) for the native spin scheme. + Used only when ``use_spin`` is set; the source (neighbor) spin is + appended to the environment matrix as an envelope-gated coordinate + channel. When ``None`` the spin channels are zero-padded so the + coordinate dimension stays fixed. Returns ------- @@ -494,6 +562,24 @@ def forward( r_hat = edge_vec * inv_r # (E, 3) r_tilde = torch.cat([s, s * r_hat], dim=-1) # (E, 4) + # === Step 1b. Append neighbor spin as extra coordinate channels === + # The source (neighbor) spin enters the environment matrix gated by the + # same C^3 envelope as the geometry, so it decays smoothly at rcut and a + # non-magnetic neighbor (s_j = 0) contributes exactly zero. The linear + # form keeps the magnetic force continuous at s = 0. + if self.spin_flags is not None: + if spin is not None: + spin_src = spin.index_select(0, src).to(dtype=r_tilde.dtype) # (E, 3) + # Gate by source type: a non-magnetic neighbor must not enter + # the energy, so its magnetic force ``-dE/ds`` stays exactly zero. + mask = self.spin_mask.index_select( + 0, atype_flat.index_select(0, src) + ).unsqueeze(-1) # (E, 1) + spin_chan = edge_env * self.spin_scale * spin_src * mask # (E, 3) + else: + spin_chan = r_tilde.new_zeros(r_tilde.shape[0], 3) + r_tilde = torch.cat([r_tilde, spin_chan], dim=-1) # (E, coord_dim) + # === Step 2. Compute G network input and output === # Use independent type embeddings (decoupled from main type embedding) atype_src = atype_flat.index_select(0, src) # (E,) @@ -511,17 +597,17 @@ def forward( g = self.g_layer2(self.g_layer1(g_input)) # (E, embed_dim) # === Step 3. Aggregate outer product by destination node === - # outer = r_tilde[:, :, None] * g[:, None, :] # (E, 4, embed_dim) - outer = torch.einsum("ei,ej->eij", r_tilde, g) # (E, 4, embed_dim) - outer_flat = outer.reshape(-1, 4 * self.embed_dim) # (E, 4*embed_dim) + # outer = r_tilde[:, :, None] * g[:, None, :] # (E, coord_dim, embed_dim) + outer = torch.einsum("ei,ej->eij", r_tilde, g) # (E, coord_dim, embed_dim) + outer_flat = outer.reshape(-1, self.coord_dim * self.embed_dim) # Source Freeze Propagation Gate: mute the outer-product contribution # of any edge whose source node has a neighbor in the frozen zone. src_gate = edge_cache.edge_src_gate if src_gate is not None: outer_flat = outer_flat * src_gate.to(dtype=outer_flat.dtype) - env_agg = outer_flat.new_zeros(n_nodes, 4 * self.embed_dim) # (N, 4*embed_dim) + env_agg = outer_flat.new_zeros(n_nodes, self.coord_dim * self.embed_dim) env_agg.index_add_(0, dst, outer_flat) - env_agg = env_agg.reshape(n_nodes, 4, self.embed_dim) # (N, 4, embed_dim) + env_agg = env_agg.reshape(n_nodes, self.coord_dim, self.embed_dim) # === Step 4. Smooth normalization by envelope-squared degree === # Reuse the cache's inverse-sqrt degree so the version-aware @@ -529,8 +615,11 @@ def forward( env_agg = env_agg * edge_cache.inv_sqrt_deg # === Step 5. D matrix construction: D = env_agg^T @ env_agg[:,:,:axis_dim] === - env_agg_t = env_agg.permute(0, 2, 1) # (N, embed_dim, 4) - env_agg_axis = env_agg[:, :, : self.axis_dim] # (N, 4, axis_dim) + # Summing over the coordinate axis makes D invariant to a joint rotation + # of the geometry and the spin channels; with the spin channels present, + # D additionally carries the neighbor spin-spin invariants. + env_agg_t = env_agg.permute(0, 2, 1) # (N, embed_dim, coord_dim) + env_agg_axis = env_agg[:, :, : self.axis_dim] # (N, coord_dim, axis_dim) D = torch.bmm(env_agg_t, env_agg_axis) # (N, embed_dim, axis_dim) # === Step 6. Output projection for FiLM logits === @@ -556,6 +645,7 @@ def serialize(self) -> dict[str, Any]: "mlp_bias": self.mlp_bias, "activation_function": self.activation_function, "eps": self.eps, + "use_spin": self.spin_flags, "precision": self.precision, "trainable": trainable, "seed": None, @@ -670,3 +760,247 @@ def forward(self, charge_spin: torch.Tensor) -> torch.Tensor: charge_embed = self.charge_embedding(charge) spin_embed = self.spin_embedding(spin) return self.mix_layer(torch.cat((charge_embed, spin_embed), dim=-1)) + + +class SpinEmbedding(nn.Module): + """ + Per-atom spin embedding for the native spin scheme. + + The per-atom spin vector ``s`` is injected as an equivariant extension of + the type embedding, producing two additive contributions to the descriptor + node features: + + - **l = 0 (invariant):** a small network of the squared magnitude ``|s|^2`` + yields a per-channel scalar added to the scalar type embedding. The + squared magnitude is used (rather than ``|s|``) so the feature is smooth + at ``s = 0`` and its gradient there vanishes, keeping the magnetic force + continuous as a spin crosses zero. + - **l = 1 (equivariant):** the Cartesian spin vector is mapped to the packed + ``l = 1`` coefficients through the SeZM Wigner-D convention (derived from + :func:`build_cartesian_basis`), then scaled by a per-type per-channel + weight. The map is linear in ``s``, so the contribution vanishes at + ``s = 0`` and rotates as an ``l = 1`` object under SO(3), i.e. + ``cart_to_l1(R s) = D^1(R) cart_to_l1(s)``. + + Both contributions are gated by a per-type spin mask, so atom types without + spin contribute exactly zero regardless of their (nominally zero) input. + + Parameters + ---------- + ntypes + Number of (real) atom types. + channels + Number of channels per (l, m) coefficient. + use_spin + Per-type boolean flags marking which atom types carry spin. + activation_function + Activation used by the magnitude network. + dtype + Parameter dtype. + seed + Random seed for initialization. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + ntypes: int, + channels: int, + use_spin: list[bool], + activation_function: str = "silu", + dtype: torch.dtype, + seed: int | list[int] | None = None, + trainable: bool = True, + ) -> None: + super().__init__() + self.ntypes = int(ntypes) + self.channels = int(channels) + self.activation_function = str(activation_function) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + if self.ntypes <= 0: + raise ValueError("`ntypes` must be positive") + if self.channels <= 0: + raise ValueError("`channels` must be positive") + if len(use_spin) != self.ntypes: + raise ValueError("`use_spin` length must equal `ntypes`") + + # === Per-type spin gate === + # Non-persistent: rebuilt from config on construction and moved with the + # module, so the deterministic mask never enters the serialized state. + spin_mask = torch.tensor( + [1.0 if bool(flag) else 0.0 for flag in use_spin], + dtype=self.dtype, + device=self.device, + ) + self.register_buffer("spin_mask", spin_mask, persistent=False) + + # === Cartesian -> packed l=1 projection === + # Derived from the SeZM packed basis so a spin vector rotates with the + # same Wigner-D block as the geometry. Non-persistent constant. + self.register_buffer( + "cart_to_l1", + self._build_cart_to_l1_matrix(), + persistent=False, + ) + + # === l=0 magnitude network: |s|^2 -> channels === + # The leading ``1 -> channels`` layer carries a singleton input + # dimension that HybridMuon routes to its Adam path automatically. + seed_scalar = child_seed(seed, 0) + self.mag_layer1 = MLPLayer( + 1, + self.channels, + bias=False, + activation_function=self.activation_function, + precision=self.precision, + seed=child_seed(seed_scalar, 0), + trainable=trainable, + ) + self.mag_layer2 = MLPLayer( + self.channels, + self.channels, + bias=False, + activation_function=None, + precision=self.precision, + seed=child_seed(seed_scalar, 1), + trainable=trainable, + ) + + # === l=1 per-type per-channel weight === + # ``adam_`` prefix routes the table to Adam in HybridMuon, matching the + # type-embedding treatment for per-type lookup parameters. + self.adam_spin_vec_weight = nn.Parameter( + torch.empty( + self.ntypes, self.channels, device=self.device, dtype=self.dtype + ) + ) + init_std = 1.0 / math.sqrt(float(self.ntypes + self.channels)) + nn.init.normal_( + self.adam_spin_vec_weight, + mean=0.0, + std=init_std, + generator=get_generator(child_seed(seed, 1)), + ) + + # === l=1 per-source-type per-channel weight for neighbor aggregation === + # Separate from the on-site weight: this scales the neighbor's spin + # direction before it is aggregated into the center node's l=1 seed. + self.adam_spin_nbr_weight = nn.Parameter( + torch.empty( + self.ntypes, self.channels, device=self.device, dtype=self.dtype + ) + ) + nn.init.normal_( + self.adam_spin_nbr_weight, + mean=0.0, + std=init_std, + generator=get_generator(child_seed(seed, 2)), + ) + + for p in self.parameters(): + p.requires_grad = trainable + + def forward( + self, spin: torch.Tensor, atype: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the l=0 and l=1 spin contributions. + + Parameters + ---------- + spin + Per-atom spin vectors with shape (N, 3). + atype + Per-atom types with shape (N,). + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(scalar, vector)`` where ``scalar`` has shape (N, channels) for + the l=0 contribution and ``vector`` has shape (N, 3, channels) for + the packed l=1 contribution (orders m = -1, 0, +1). Both are exactly + zero for atom types without spin. + """ + spin = spin.to(dtype=self.dtype) + mask = self.spin_mask.index_select(0, atype).unsqueeze(-1) # (N, 1) + + # === l=0: smooth invariant magnitude embedding === + mag2 = (spin * spin).sum(dim=-1, keepdim=True) # (N, 1) + scalar: torch.Tensor = self.mag_layer2(self.mag_layer1(mag2)) * mask # (N, C) + + # === l=1: equivariant direction embedding (linear in spin) === + l1 = torch.einsum("dk,nk->nd", self.cart_to_l1, spin) # (N, 3) + weight = self.adam_spin_vec_weight.index_select(0, atype) # (N, C) + vector = ( + l1.unsqueeze(-1) * weight.unsqueeze(1) * mask.unsqueeze(-1) + ) # (N, 3, C) + + return scalar, vector + + def edge_l1( + self, + spin: torch.Tensor, + atype: torch.Tensor, + edge_cache: EdgeFeatureCache, + ) -> torch.Tensor: + """ + Build the per-edge neighbor-spin l=1 message for the GIE aggregation. + + Each edge carries the packed ``l = 1`` coefficients of the source + (neighbor) spin, scaled by a per-source-type per-channel weight and + gated by the C^3 envelope. The message is returned per edge; the + geometric initial embedding folds it into the l=1 rows and applies the + shared source gate, scatter and degree normalization, so a neighbor's + spin direction enters an atom's l=1 backbone before any interaction + block (the spin analogue of the geometric initial embedding). + + Parameters + ---------- + spin + Per-node spin vectors with shape (N, 3). + atype + Per-node types with shape (N,). + edge_cache + Edge cache providing ``src`` and ``edge_env``. + + Returns + ------- + torch.Tensor + Per-edge packed l=1 message with shape (E, 3, channels), exactly + zero for non-magnetic neighbors. + """ + spin = spin.to(dtype=self.dtype) + spin_src = spin.index_select(0, edge_cache.src) # (E, 3) + atype_src = atype.index_select(0, edge_cache.src) # (E,) + + # Packed l=1 of the neighbor spin; the global-frame vector needs no + # Wigner-D rotation (it rotates with the geometry by construction). + l1 = torch.einsum("dk,ek->ed", self.cart_to_l1, spin_src) # (E, 3) + weight = self.adam_spin_nbr_weight.index_select(0, atype_src) # (E, C) + mask = self.spin_mask.index_select(0, atype_src) # (E,) + gate = edge_cache.edge_env * mask.unsqueeze(-1) # (E, 1) + return gate.unsqueeze(-1) * l1.unsqueeze(-1) * weight.unsqueeze(1) # (E, 3, C) + + def _build_cart_to_l1_matrix(self) -> torch.Tensor: + """ + Build the ``(3, 3)`` Cartesian-to-packed-``l=1`` projection. + + The packed ``l = 1`` coefficient of a vector ``v`` is obtained by + projecting the skew-symmetric matrix ``[v]_x`` onto the antisymmetric + ``l = 1`` block of :func:`build_cartesian_basis`. With packed order + ``m = -1, 0, +1``, row ``d`` and Cartesian component ``k`` give + ``M[d, k] = <[e_k]_x, B[1 + d]>_F``, so ``coeff = M @ v`` and + ``M @ (R v) = D^1(R) (M @ v)``. + """ + basis_l1 = build_cartesian_basis(1, dtype=self.dtype, device=self.device)[1:4] + # Skew (cross-product) matrices of the Cartesian unit vectors, following + # ``[v]_x w = v x w`` (matching ``build_edge_cartesian_tensors``). + skew_basis = torch.zeros(3, 3, 3, dtype=self.dtype, device=self.device) + skew_basis[0, 1, 2], skew_basis[0, 2, 1] = -1.0, 1.0 + skew_basis[1, 0, 2], skew_basis[1, 2, 0] = 1.0, -1.0 + skew_basis[2, 0, 1], skew_basis[2, 1, 0] = -1.0, 1.0 + return torch.einsum("kij,dij->dk", skew_basis, basis_l1) diff --git a/deepmd/pt/model/descriptor/sezm_nn/grid_net.py b/deepmd/pt/model/descriptor/sezm_nn/grid_net.py index b0226203fe..7c36e99e7b 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/grid_net.py +++ b/deepmd/pt/model/descriptor/sezm_nn/grid_net.py @@ -60,7 +60,7 @@ Callable, ) -GridNetLayout = Literal["ndfc", "nfdc", "flat"] +GridNetLayout = Literal["ndfc", "nfdc", "fndc", "flat"] GridNetMode = Literal["self", "cross"] GridNetOp = Literal["glu", "mlp", "branch"] @@ -479,8 +479,10 @@ def __init__( raise ValueError("`op_type` must be one of 'glu', 'mlp', or 'branch'") self.dtype = dtype self.layout = str(layout).lower() - if self.layout not in {"ndfc", "nfdc", "flat"}: - raise ValueError("`layout` must be one of 'ndfc', 'nfdc', or 'flat'") + if self.layout not in {"ndfc", "nfdc", "fndc", "flat"}: + raise ValueError( + "`layout` must be one of 'ndfc', 'nfdc', 'fndc', or 'flat'" + ) if self.mode == "self" and self.layout == "flat": raise ValueError("`layout='flat'` is only supported for cross grid nets") self.mlp_bias = bool(mlp_bias) @@ -702,10 +704,16 @@ def _from_grid(self, grid: torch.Tensor) -> torch.Tensor: return coeff.reshape(n_batch, coeff_dim, n_focus, -1) def _to_ndfc(self, value: torch.Tensor) -> tuple[torch.Tensor, tuple[int, ...]]: + # All grid operations run in the canonical ``(N, D, F, C)`` layout; the + # ``fndc`` re-orientation folds the focus-major SO(2) mixing layout into the + # same transpose the ``nfdc`` path performs, so the grid compute below is + # identical regardless of the caller's layout. if self.layout == "ndfc": return value, tuple(value.shape) if self.layout == "nfdc": return value.transpose(1, 2), tuple(value.shape) + if self.layout == "fndc": + return value.permute(1, 2, 0, 3), tuple(value.shape) n_batch, coeff_dim, _ = value.shape return ( value.reshape(n_batch, coeff_dim, self.n_focus, -1), @@ -721,6 +729,8 @@ def _restore_layout( return value if self.layout == "nfdc": return value.transpose(1, 2) + if self.layout == "fndc": + return value.permute(2, 0, 1, 3) n_batch, coeff_dim, _ = shape_info return value.reshape(n_batch, coeff_dim, -1) diff --git a/deepmd/pt/model/descriptor/sezm_nn/norm.py b/deepmd/pt/model/descriptor/sezm_nn/norm.py index 453c6af6af..e6704730f6 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/norm.py +++ b/deepmd/pt/model/descriptor/sezm_nn/norm.py @@ -446,18 +446,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- x - Input tensor with shape (E, F, D_m_trunc, C). + Input tensor with shape (F, E, D_m_trunc, C). Returns ------- torch.Tensor - Normalized tensor with shape `(E, F, D_m_trunc, C)`, same dtype as + Normalized tensor with shape `(F, E, D_m_trunc, C)`, same dtype as input. """ in_dtype = x.dtype x = x.to(dtype=self.dtype) - x0 = x[:, :, :1, :] # (E, F, 1, C) - xt = x[:, :, 1:, :] # (E, F, D_m_trunc-1, C) + x0 = x[:, :, :1, :] # (F, E, 1, C) + xt = x[:, :, 1:, :] # (F, E, D_m_trunc-1, C) # === Step 1. Center the scalar slice === x0 = x0 - x0.mean(dim=-1, keepdim=True) @@ -480,13 +480,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: expanded_scale = torch.index_select( self.adam_scale, dim=1, index=self.degree_index_m ) - expanded_scale = expanded_scale.unsqueeze(0) # (1, F, D_m_trunc, C) + expanded_scale = expanded_scale.unsqueeze(1) # (F, 1, D_m_trunc, C) x0 = x0 * expanded_scale[:, :, :1, :] if xt.numel() > 0: xt = xt * expanded_scale[:, :, 1:, :] # === Step 4. Add scalar bias and restore layout === - bias0 = self.bias0.reshape(1, self.n_focus, 1, -1) # (1, F, 1, C) + bias0 = self.bias0.reshape(self.n_focus, 1, 1, -1) # (F, 1, 1, C) x0 = x0 + bias0 out = x0 if xt.numel() == 0 else torch.cat([x0, xt], dim=2) diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py index 5d6e3298ff..fd194c94d7 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/so2.py +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -22,6 +22,10 @@ from deepmd.dpmodel.utils.seed import ( child_seed, ) +from deepmd.kernels.utils import ( + triton_infer_level, + use_cute_infer, +) from deepmd.pt.utils import ( env, ) @@ -81,7 +85,6 @@ np_safe, nvtx_range, safe_numpy_to_tensor, - use_triton_infer, ) if TYPE_CHECKING: @@ -124,8 +127,12 @@ class SO2Linear(nn.Module): rotates the output by the same angle. The weight is assembled once per forward (training) or cached (eval) - by ``_build_so2_weight()``, then applied via a single batched matmul - over all focus streams: ``einsum("efi,foi->efo")``. + by ``_build_so2_weight()`` in the ``(D_m*Cin, F, D_m*Cout)`` layout, then + applied as a per-``|m|``-block batched matmul over the focus streams. The + activation is carried in the focus-major layout ``(F, E, D_m, Cf)`` so that + the focus stream is the batch axis of the matmul: the assembled weight is + presented as ``(F, D_m*Cin, D_m*Cout)`` (a transient view, never a stored + parameter) and each block contracts with no transpose of the edge axis. Parameters ---------- @@ -340,24 +347,43 @@ def __init__( # Each |m| group occupies a contiguous (in, out) block on the diagonal. self._block_diag_slices = self._build_block_diag_slices() + # Inference fast path (``DP_TRITON_INFER >= 1``): the per-|m|-block + # batched bmm + cat of _block_diagonal_matmul is replaced by a fused + # Triton BN=64 block-diagonal GEMM that consumes the strided operands + # without a contiguity copy. Bound only when Triton is available and every + # block width aligns to BN=64; otherwise the eager path is kept. + self._block_diag_gemm = None + if triton_infer_level() >= 1: + from deepmd.kernels.triton.sezm.so2_block_gemm import ( + SO2_BLOCK_GEMM_TRITON_AVAILABLE, + block_diag_gemm, + slices_supported, + ) + + if SO2_BLOCK_GEMM_TRITON_AVAILABLE and slices_supported( + self._block_diag_slices + ): + self._block_diag_gemm = block_diag_gemm + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Parameters ---------- x - Input with shape (E, F, D_m_trunc, Cin), where D_m_trunc is the + Input with shape (F, E, D_m_trunc, Cin), where F is the focus stream + (the matmul batch axis), E the edge count, and D_m_trunc the coefficient dimension of the m-major layout truncated by `mmax`. Returns ------- torch.Tensor - Output with shape (E, F, D_m_trunc, Cout), where Cout is output channels. + Output with shape (F, E, D_m_trunc, Cout), where Cout is output channels. """ - # === Step 1. Flatten coefficient + channel axes for matmul === - # (E, F, D_m, Cin) -> (E, F, D_m*Cin) - n_edge = x.shape[0] + # === Step 1. Flatten coefficient + channel axes for the matmul === + # (F, E, D_m, Cin) -> (F, E, D_m*Cin); the focus stream stays the batch axis. + n_focus, n_edge = x.shape[0], x.shape[1] in_dim_total = self.reduced_dim * self.in_channels - x_flat = x.reshape(n_edge, self.n_focus, in_dim_total) + x_flat = x.reshape(n_focus, n_edge, in_dim_total) # === Step 2. Get block-diagonal weight (cached in eval+no_grad) === if self._cached_weight is not None: @@ -382,15 +408,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if use_block_diag: out_flat = self._block_diagonal_matmul(x_flat, weight) else: - out_flat = torch.einsum("efi,ifo->efo", x_flat, weight) - out = out_flat.reshape( - n_edge, self.n_focus, self.reduced_dim, self.out_channels - ) + out_flat = torch.einsum("fei,ifo->feo", x_flat, weight) + out = out_flat.reshape(n_focus, n_edge, self.reduced_dim, self.out_channels) # === Step 4. Bias on l=0 scalar index === if self.mlp_bias: bias0 = self.bias0.view(self.n_focus, self.out_channels) - out[:, :, 0, :] = out[:, :, 0, :] + bias0.unsqueeze(0) + out[:, :, 0, :] = out[:, :, 0, :] + bias0.unsqueeze(1) return out def _build_block_diag_slices(self) -> list[tuple[int, int, int, int]]: @@ -495,27 +519,36 @@ def _block_diagonal_matmul( ``weight`` is block-diagonal over ``|m|`` (cross-``|m|`` blocks are exactly zero), so concatenating the per-group matmuls reproduces the - dense ``einsum`` over the full ``(D_m*Cin, D_m*Cout)`` matrix while + dense contraction over the full ``(D_m*Cin, D_m*Cout)`` matrix while skipping the structural zeros. The result is fp32-equivalent to the dense path up to the matmul reduction order. + The focus stream is the batch axis of the per-block ``bmm``: the input + already carries it as ``(F, E, .)`` and the assembled weight is presented + as ``(F, D_m*Cin, D_m*Cout)``, so no edge-axis transpose is needed on + either operand and each block writes directly into the concatenated + output. The weight view is transient (the stored parameters keep their + assembled ``(D_m*Cin, F, D_m*Cout)`` layout). + Parameters ---------- x_flat : torch.Tensor - Flattened input with shape ``(E, F, D_m*Cin)``. + Flattened input with shape ``(F, E, D_m*Cin)``. weight : torch.Tensor Assembled block-diagonal weight with shape ``(D_m*Cin, F, D_m*Cout)``. Returns ------- torch.Tensor - Flattened output with shape ``(E, F, D_m*Cout)``. + Flattened output with shape ``(F, E, D_m*Cout)``. """ + weight = weight.permute(1, 0, 2) # (F, D_m*Cin, D_m*Cout) + if self._block_diag_gemm is not None and not self.training: + return self._block_diag_gemm(x_flat, weight, self._block_diag_slices) blocks = [ - torch.einsum( - "efi,ifo->efo", + torch.bmm( x_flat[:, :, in0:in1], - weight[in0:in1, :, out0:out1], + weight[:, in0:in1, out0:out1], ) for in0, in1, out0, out1 in self._block_diag_slices ] @@ -656,10 +689,10 @@ def __init__( for p in self.parameters(): p.requires_grad = trainable - # Inference fast path (opt-in via ``DP_TRITON_INFER``): a fused Triton + # Inference fast path (``DP_TRITON_INFER >= 1``): a fused Triton # kernel replaces the dense scatter and the tiny batched matmul of the # ``degree_channel`` low-rank branch in the ``mmax == 1`` layout. - self.use_triton_infer = use_triton_infer() + self.use_triton_infer = triton_infer_level() >= 1 self._radial_mix_block = None if ( self.use_triton_infer @@ -667,7 +700,7 @@ def __init__( and self.rank > 0 and self.mmax == 1 ): - from .triton.radial_mix import ( + from deepmd.kernels.triton.sezm.radial_mix import ( radial_mix_block, ) @@ -1138,12 +1171,32 @@ def __init__( self.device = env.DEVICE self.precision = RESERVED_PRECISION_DICT[dtype] self.compute_dtype = get_promoted_dtype(self.dtype) - # Optional Triton inference kernels for the SO(2) convolution, enabled by - # ``DP_TRITON_INFER=1`` (default disabled, in which case the dense - # ``bmm`` rotation is used). The flag is read once at construction so it - # is a compile-time constant in the traced (``make_fx``) graph, and it - # only takes effect during inference. - self.use_triton_infer = use_triton_infer() + # Opt-in inference fast paths, selected by ``DP_TRITON_INFER`` (a + # cumulative level, see :func:`triton_infer_level`) and + # ``DP_CUTE_INFER``. Each is read once at construction so it becomes a + # compile-time constant in the traced (``make_fx``) graph, and each + # only takes effect during inference. Level 1 replaces the dense + # ``bmm`` rotation with universal Triton kernels; level 2 additionally + # binds the table-configured fused value path; level 3 routes the + # mixing stack through the fp16x3 tensor-core operator on swept + # shapes. ``DP_CUTE_INFER`` selects the experimental CuTe value-path + # operator instead; both gates claim the same ``so2_message`` value + # path, so enabling them together has no coherent meaning and is + # rejected at construction. The fused value-path entries are bound at + # the end of construction (once every submodule exists) and stay + # ``None`` when the backend is unavailable or the block layout is + # unsupported. + self.triton_infer_level = triton_infer_level() + self.use_triton_infer = self.triton_infer_level >= 1 + self.use_cute_infer = use_cute_infer() + if self.use_triton_infer and self.use_cute_infer: + raise ValueError( + "DP_TRITON_INFER and DP_CUTE_INFER are mutually exclusive: " + "both select the fused SO(2) value-path backend. Enable " + "exactly one of them." + ) + self._cute_value_path = None + self._triton_value_path = None # === Step 1. Split deterministic seeds at the module top-level === seed_so2_stack = child_seed(seed, 0) @@ -1499,6 +1552,69 @@ def __init__( or self.node_wise_grid_product is not None ) + # === Step 12. Optional fused flash-attention aggregation kernel === + # Folds the entire ``n_atten_head > 0`` value aggregation -- block-diagonal + # rotate-back, inverse-rotation rescale, envelope-gated softmax weighting, + # and the destination scatter -- into a single destination-segmented + # Triton kernel, removing the transient ``x_message`` and weighted-value + # edge tensors and the ``index_add`` round trip. It shares the + # ``DP_TRITON_INFER`` gate with the other SeZM inference kernels and only + # engages for the supported ``mmax == 1`` attention layout without the + # optional focus-mix / value / output projections (the deployed DPA4 + # configuration); the op itself dispatches to an eager reference off the + # CUDA fp32 path. The output-side head gate stays a cheap node-level + # elementwise applied after the kernel. + self.use_flash_atten = ( + self.use_triton_infer + and self.n_atten_head > 0 + and self.mmax == 1 + and self.needs_local_frame + and not self.edge_cartesian + and not self.atten_f_mix + and self.attn_v_proj is None + and self.attn_o_proj is None + and self.attn_focus_mix is None + ) + self._flash_atten_fn = None + self._build_row_ptr_fn = None + if self.use_flash_atten: + from deepmd.kernels.triton.sezm.flash_atten import ( + build_row_ptr, + flash_atten_aggregate, + ) + + self._flash_atten_fn = flash_atten_aggregate + self._build_row_ptr_fn = build_row_ptr + + # === Step 13. Optional fused Triton SO(2) value-path operators === + # Fuses rotate-to-local, the radial degree mixing, the gated mixing + # stack, and the focus competition of ``so2_message`` into the + # ``sezm_triton::so2_rotate_mix`` / ``so2_mixing_stack`` operators. + # The factory validates the block layout (``mmax == 1``, gated stack + # with an identity final layer, supported focus widths) and returns + # ``None`` otherwise, leaving the reference path in charge. The value + # path resolves its launch configurations from the swept tables, so + # it engages at ``DP_TRITON_INFER >= 2``; at level 3 the factory + # additionally routes the mixing stack through the fp16x3 tensor-core + # operator on shapes whose configuration passed the fp64 validation + # sweep. + if self.triton_infer_level >= 2: + from deepmd.kernels.triton.sezm.so2_value_path import ( + make_triton_value_path, + ) + + self._triton_value_path = make_triton_value_path(self) + + # === Step 14. Optional fused CuTe SO(2) value-path operator === + # Experimental alternative backend; mutually exclusive with the Triton + # flag (enforced above). + if self.use_cute_infer: + from deepmd.kernels.cute.sezm import ( + make_cute_value_path, + ) + + self._cute_value_path = make_cute_value_path(self) + def forward( self, x: torch.Tensor, @@ -1533,7 +1649,17 @@ def forward( # === Step 2. Edge message: Cartesian product, SO(2) mixing, or the # rotation-free radial message when no local-frame operation is needed === - if self.edge_cartesian: + # In the fused flash-attention path the SO(2) message returns the + # pre-rotate-back per-focus local features; the rotate-back is folded into + # the aggregation kernel (Step 4). + run_flash = self.use_flash_atten and not self.training + x_local_flash: torch.Tensor | None = None + x_message: torch.Tensor | None = None + if run_flash: + x_local_flash, rad_feat = self.so2_message( + x, edge_cache, radial_feat, return_local=True + ) + elif self.edge_cartesian: x_message, rad_feat = self.cartesian_message(x, edge_cache, radial_feat) elif self.needs_local_frame: x_message, rad_feat = self.so2_message(x, edge_cache, radial_feat) @@ -1617,60 +1743,109 @@ def forward( ), ) # (E, F, H) - # === Step 4.3. Value projection and head-wise aggregation === - value_focus = x_message.reshape( - n_edge, - self.ebed_dim_full, - self.attn_n_focus, - self.attn_focus_dim, - ).to(dtype=compute_dtype) # (E, D, Fa, Ca) - if self.attn_v_proj is not None: - value_focus = self.attn_v_proj(value_focus) - value_heads = value_focus.reshape( - n_edge, - self.ebed_dim_full, - self.attn_n_focus, - self.n_atten_head, - self.head_dim, - ) # (E, D, Fa, H, Ch) - weighted_value = value_heads * attn_alpha.reshape( - n_edge, 1, self.attn_n_focus, self.n_atten_head, 1 - ) - out_heads = torch.zeros( - n_node, - self.ebed_dim_full, - self.attn_n_focus, - self.n_atten_head, - self.head_dim, - device=x.device, - dtype=compute_dtype, - ) # (N, D, Fa, H, Ch) - out_heads.index_add_(0, dst, weighted_value) - - # === Step 4.4. Output-side head gate === - attn_output_gate = torch.sigmoid( - torch.einsum( - "nfi,ifo->nfo", - self.attn_output_gate_norm(x_l0_node.to(dtype=compute_dtype)), - self.adamw_attn_gate_w, + if run_flash: + # === Step 4.3f. Fused rotate-back + envelope-softmax-weighted + # segment scatter. One destination-segmented Triton kernel + # folds the block-diagonal rotate-back, the inverse-rotation + # rescale, the per-edge ``attn_alpha`` weighting, and the + # destination reduction into a single atomic-free pass, + # returning the ungated aggregate ``(N, D, C_wide)``. The + # transient rotate-back message and weighted value tensors are + # never materialized. + row_ptr = self._build_row_ptr_fn(dst, n_node) + pre_gate = self._flash_atten_fn( + x_local_flash, + edge_cache.Dt_full, + self.rotate_inv_rescale_full, + attn_alpha, + row_ptr, + dst, + self.lmax, + self.n_atten_head, + ) # (N, D, C_wide) + + # === Step 4.4f. Output-side head gate (cheap node-level) === + attn_output_gate = torch.sigmoid( + torch.einsum( + "nfi,ifo->nfo", + self.attn_output_gate_norm( + x_l0_node.to(dtype=compute_dtype) + ), + self.adamw_attn_gate_w, + ) + ) # (N, Fa, H) + # Broadcast the per-(focus, head) gate over the head channels + # to the packed hidden width ``c = f * Cf + h * head_dim + ch``. + gate_full = ( + attn_output_gate.reshape( + n_node, self.attn_n_focus, self.n_atten_head, 1 + ) + .expand( + n_node, + self.attn_n_focus, + self.n_atten_head, + self.head_dim, + ) + .reshape(n_node, self.hidden_channels) + ) # (N, C_wide) + out = (pre_gate * gate_full.unsqueeze(1)).to(dtype=self.dtype) + else: + # === Step 4.3. Value projection and head-wise aggregation === + value_focus = x_message.reshape( + n_edge, + self.ebed_dim_full, + self.attn_n_focus, + self.attn_focus_dim, + ).to(dtype=compute_dtype) # (E, D, Fa, Ca) + if self.attn_v_proj is not None: + value_focus = self.attn_v_proj(value_focus) + value_heads = value_focus.reshape( + n_edge, + self.ebed_dim_full, + self.attn_n_focus, + self.n_atten_head, + self.head_dim, + ) # (E, D, Fa, H, Ch) + weighted_value = value_heads * attn_alpha.reshape( + n_edge, 1, self.attn_n_focus, self.n_atten_head, 1 ) - ) # (N, F, H) - out_heads = out_heads * attn_output_gate.reshape( - n_node, 1, self.attn_n_focus, self.n_atten_head, 1 - ) # (N, D, Fa, H, Ch) - - # === Step 4.5. Output projection and merge heads === - out_focus = out_heads.reshape( - n_node, - self.ebed_dim_full, - self.attn_n_focus, - self.attn_focus_dim, - ) # (N, D, Fa, Ca) - if self.attn_o_proj is not None: - out_focus = self.attn_o_proj(out_focus) - out = out_focus.reshape( - n_node, self.ebed_dim_full, self.hidden_channels - ).to(dtype=self.dtype) # (N, D, C_wide) + out_heads = torch.zeros( + n_node, + self.ebed_dim_full, + self.attn_n_focus, + self.n_atten_head, + self.head_dim, + device=x.device, + dtype=compute_dtype, + ) # (N, D, Fa, H, Ch) + out_heads.index_add_(0, dst, weighted_value) + + # === Step 4.4. Output-side head gate === + attn_output_gate = torch.sigmoid( + torch.einsum( + "nfi,ifo->nfo", + self.attn_output_gate_norm( + x_l0_node.to(dtype=compute_dtype) + ), + self.adamw_attn_gate_w, + ) + ) # (N, F, H) + out_heads = out_heads * attn_output_gate.reshape( + n_node, 1, self.attn_n_focus, self.n_atten_head, 1 + ) # (N, D, Fa, H, Ch) + + # === Step 4.5. Output projection and merge heads === + out_focus = out_heads.reshape( + n_node, + self.ebed_dim_full, + self.attn_n_focus, + self.attn_focus_dim, + ) # (N, D, Fa, Ca) + if self.attn_o_proj is not None: + out_focus = self.attn_o_proj(out_focus) + out = out_focus.reshape( + n_node, self.ebed_dim_full, self.hidden_channels + ).to(dtype=self.dtype) # (N, D, C_wide) # === Step 5. Optional message-node grid product === if self.message_node_grid_product is not None: @@ -1753,6 +1928,7 @@ def so2_message( x: torch.Tensor, edge_cache: EdgeFeatureCache, radial_feat: torch.Tensor, + return_local: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Build edge messages by rotate-to-local, SO(2) mixing, and rotate-back. @@ -1765,164 +1941,217 @@ def so2_message( Precomputed edge cache. radial_feat : torch.Tensor Per-edge radial features with shape (E, lmax+1, C). + return_local : bool + If True, return the pre-rotate-back per-focus local features + ``(E, F, D_m, Cf)`` instead of the rotated-back message. Used by the + fused flash-attention aggregation, which folds the rotate-back into + its own kernel. Returns ------- tuple[torch.Tensor, torch.Tensor] ``(x_message, rad_feat)`` with shapes (E, D, C_wide) and - (E, D_m, C_wide). The ``l=0`` slice of ``rad_feat`` is consumed by - the attention aggregation. + (E, D_m, C_wide) by default, or ``(x_local, rad_feat)`` with + ``x_local`` of shape (E, F, D_m, Cf) when ``return_local`` is True. + The ``l=0`` slice of ``rad_feat`` is consumed by the attention + aggregation. """ src, dst = edge_cache.src, edge_cache.dst n_edge = src.numel() - # === Step 1. Rotate to edge-aligned local frame === - with nvtx_range("SO2Conv/rotate_to_local"): - D_full = edge_cache.D_full - x_dst_local: torch.Tensor | None = None - if self.use_triton_infer and not self.training: - # ``self._rotate_to_local_fn`` was bound in ``__init__`` (the - # block kernel for the m-major ``mmax == 1`` layout, dense - # otherwise). - x_local = self._rotate_to_local_fn(x, src, D_full) # (E, D_m, C_wide) - if self.node_wise_grid_product is not None: - x_dst_local = self._rotate_to_local_fn( - x, dst, D_full + if self._triton_value_path is not None and not self.training: + # === Steps 1-5 (fused Triton operators). ``so2_rotate_mix`` folds + # the rotation and the radial degree mixing into one edge-parallel + # kernel writing the focus-major layout; ``so2_mixing_stack`` runs + # the whole gated stack with the competition weight fused into its + # final store, keeping the inter-layer activations off the traced + # graph. === + x_local, rad_feat = self._triton_value_path(x, edge_cache, radial_feat) + elif self._cute_value_path is not None and not self.training: + # === Steps 1-5 (fused CuTe operator). The operator folds + # rotate_to_local, radial degree mixing, the multi-layer gated SO(2) + # stack, and the focus competition into the bucketed kernels; the + # per-edge focus-major intermediates stay resident on chip. === + x_local, rad_feat = self._cute_value_path(x, edge_cache, radial_feat) + else: + # === Step 1. Rotate to edge-aligned local frame === + with nvtx_range("SO2Conv/rotate_to_local"): + D_full = edge_cache.D_full + x_dst_local: torch.Tensor | None = None + if self.use_triton_infer and not self.training: + # ``self._rotate_to_local_fn`` was bound in ``__init__`` (the + # block kernel for the m-major ``mmax == 1`` layout, dense + # otherwise). + x_local = self._rotate_to_local_fn( + x, src, D_full ) # (E, D_m, C_wide) - else: - D_m_prime = project_D_to_m( - D_full=D_full, - coeff_index_m=self.coeff_index_m, - ebed_dim_full=self.ebed_dim_full, - cache=edge_cache.D_to_m_cache, - key_lmax=self.lmax, - key_mmax=self.mmax, - ) - x_src = x.index_select(0, src) # (E, D, C_wide) - x_local = torch.bmm(D_m_prime, x_src) # (E, D_m, C_wide) - if self.node_wise_grid_product is not None: - x_dst = x.index_select(0, dst) # (E, D, C_wide) - x_dst_local = torch.bmm(D_m_prime, x_dst) # (E, D_m, C_wide) - - # === Step 2. Select radial/type features for reduced layout === - with nvtx_range("SO2Conv/radial_fuse"): - rad_feat = radial_feat[:, self.degree_index_m, :] # (E, D_m, C) - if self.radial_hidden_proj is not None: - rad_feat = self.radial_hidden_proj(rad_feat) - if self.radial_degree_mixer is None: - x_local.mul_(rad_feat) - else: - x_local = self.radial_degree_mixer(x_local, rad_feat) - if self.node_wise_grid_product is not None: - x_local = x_local + self.node_wise_grid_product( - x_local, - x_dst_local, - ) - rad_feat_l0_focus = rad_feat[:, 0, :].reshape( - n_edge, self.n_focus, self.so2_focus_dim - ) # (E, F, Cf) - - # === Step 3. Convert to SO(2) internal focus layout === - focus_gate_src: torch.Tensor | None = None - with nvtx_range("SO2Conv/reshape_for_so2"): - x_local = x_local.reshape( - n_edge, self.reduced_dim, self.n_focus, self.so2_focus_dim - ).transpose(1, 2) # (E, F, D_m, Cf), strided - if self.focus_compete and self.n_focus > 1: - focus_gate_src = x_local[:, :, 0, :] - - # === Step 4. Multi-layer SO(2) mixing (pre-norm + residual) === - with nvtx_range("SO2Conv/so2_layers"): - - def so2_l0_extractor(v: torch.Tensor) -> torch.Tensor: - """Extract scalar features from SO(2) reduced layout.""" - return v[:, :, 0, :].reshape(v.shape[0], self.hidden_channels) - - def apply_bias_correction( - x_local: torch.Tensor, - so2_linear: SO2Linear, - layer_idx: int, - ) -> None: - if layer_idx != 0 or so2_linear.bias0 is None: - return - bias0 = so2_linear.bias0.view( - self.n_focus, so2_linear.out_channels - ).unsqueeze(0) - if so2_linear.out_channels == self.so2_focus_dim: - radial_factor = rad_feat_l0_focus - elif so2_linear.out_channels == 2 * self.so2_focus_dim: - radial_factor = torch.cat( - [rad_feat_l0_focus, rad_feat_l0_focus], dim=-1 - ) + if self.node_wise_grid_product is not None: + x_dst_local = self._rotate_to_local_fn( + x, dst, D_full + ) # (E, D_m, C_wide) else: - raise RuntimeError( - "Unexpected SO2Linear output width in bias correction" - ) - bias_correction = bias0 * ( - radial_factor * edge_cache.edge_env.reshape(-1, 1, 1) - 1.0 - ) - x_local[:, :, 0, :].add_(bias_correction) - - if self.use_so2_attn_res: - so2_depth_sources = [x_local] - for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( - zip( - self.so2_linears, - self.so2_inter_norms, - self.non_linearities, - strict=True, + D_m_prime = project_D_to_m( + D_full=D_full, + coeff_index_m=self.coeff_index_m, + ebed_dim_full=self.ebed_dim_full, + cache=edge_cache.D_to_m_cache, + key_lmax=self.lmax, + key_mmax=self.mmax, ) - ): - x_local: torch.Tensor = self.so2_layer_attn_res[layer_idx]( - sources=so2_depth_sources, - scalar_extractor=so2_l0_extractor, - current_x=x_local, + x_src = x.index_select(0, src) # (E, D, C_wide) + x_local = torch.bmm(D_m_prime, x_src) # (E, D_m, C_wide) + if self.node_wise_grid_product is not None: + x_dst = x.index_select(0, dst) # (E, D, C_wide) + x_dst_local = torch.bmm(D_m_prime, x_dst) # (E, D_m, C_wide) + + # === Step 2. Select radial/type features for reduced layout === + with nvtx_range("SO2Conv/radial_fuse"): + rad_feat = radial_feat[:, self.degree_index_m, :] # (E, D_m, C) + if self.radial_hidden_proj is not None: + rad_feat = self.radial_hidden_proj(rad_feat) + if self.radial_degree_mixer is None: + x_local.mul_(rad_feat) + else: + x_local = self.radial_degree_mixer(x_local, rad_feat) + if self.node_wise_grid_product is not None: + x_local = x_local + self.node_wise_grid_product( + x_local, + x_dst_local, ) - residual = x_local - x_local = inter_norm(x_local) - x_local = so2_linear(x_local) - apply_bias_correction(x_local, so2_linear, layer_idx) - - x_local = non_linear(x_local) - - if self.layer_scale: - scale: torch.Tensor = self.adam_so2_layer_scales[ - layer_idx - ].reshape(1, self.n_focus, 1, self.so2_focus_dim) - x_local = residual + scale * x_local + rad_feat_l0_focus = rad_feat[:, 0, :].reshape( + n_edge, self.n_focus, self.so2_focus_dim + ) # (E, F, Cf) + + # === Step 3. Cast to the focus-major SO(2) mixing layout (F, E, D_m, Cf) === + # The mixing stack runs with the focus stream on the batch axis, the native + # layout of the block-diagonal batched matmul: the per-focus linear consumes + # it with no edge-axis transpose and writes each ``|m|`` block with no + # reassembly cost. This is a strided view of the reduced global buffer, + # materialized by the first linear's reshape exactly as any reduced-layout + # view would be. + focus_gate_src: torch.Tensor | None = None + with nvtx_range("SO2Conv/reshape_for_so2"): + x_local = x_local.reshape( + n_edge, self.reduced_dim, self.n_focus, self.so2_focus_dim + ).permute(2, 0, 1, 3) # (F, E, D_m, Cf), strided view + if self.focus_compete and self.n_focus > 1: + focus_gate_src = x_local[:, :, 0, :] # (F, E, Cf) + + # === Step 4. Multi-layer SO(2) mixing (pre-norm + residual) === + with nvtx_range("SO2Conv/so2_layers"): + + def so2_l0_extractor(v: torch.Tensor) -> torch.Tensor: + """Extract scalar features from the edge-major layout (E, F, D_m, Cf).""" + return v[:, :, 0, :].reshape(v.shape[0], self.hidden_channels) + + def apply_bias_correction( + x_local: torch.Tensor, + so2_linear: SO2Linear, + layer_idx: int, + ) -> None: + if layer_idx != 0 or so2_linear.bias0 is None: + return + if so2_linear.out_channels == self.so2_focus_dim: + radial_factor = rad_feat_l0_focus + elif so2_linear.out_channels == 2 * self.so2_focus_dim: + radial_factor = torch.cat( + [rad_feat_l0_focus, rad_feat_l0_focus], dim=-1 + ) else: - x_local = residual + x_local - so2_depth_sources.append(x_local - residual) - else: - for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( - zip( - self.so2_linears, - self.so2_inter_norms, - self.non_linearities, - strict=True, + raise RuntimeError( + "Unexpected SO2Linear output width in bias correction" + ) + # Focus-major broadcast: bias0 (F, Cout), the radial l=0 factor + # (E, F, .) transposed to (F, E, .), the per-edge envelope over the + # edge axis, applied to the l=0 scalar slice (F, E, Cout). + bias0 = so2_linear.bias0.view(self.n_focus, so2_linear.out_channels) + radial_factor = radial_factor.transpose(0, 1) # (F, E, .) + bias_correction = bias0.unsqueeze(1) * ( + radial_factor * edge_cache.edge_env.reshape(1, -1, 1) - 1.0 ) - ): - residual = x_local - x_local = inter_norm(x_local) - x_local = so2_linear(x_local) - apply_bias_correction(x_local, so2_linear, layer_idx) - - x_local = non_linear(x_local) - - if self.layer_scale: - scale = self.adam_so2_layer_scales[layer_idx].reshape( - 1, self.n_focus, 1, self.so2_focus_dim + x_local[:, :, 0, :].add_(bias_correction) + + if self.use_so2_attn_res: + # The depth-attention residual is a per-edge reduction over the + # layer history (``DepthAttnRes`` batches on axis 0), so the history + # is kept in the edge-major orientation and each mixing step + # transposes into the focus-major layout for the linear. + so2_depth_sources = [x_local.transpose(0, 1)] # (E, F, D_m, Cf) + for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( + zip( + self.so2_linears, + self.so2_inter_norms, + self.non_linearities, + strict=True, ) - x_local = residual + scale * x_local - else: - x_local = residual + x_local - - # === Step 5. Cross-focus softmax competition === - if self.focus_compete and self.n_focus > 1: - alpha = self._focus_alpha(focus_gate_src) - x_local = x_local * alpha.to(dtype=x_local.dtype).unsqueeze(-1).unsqueeze( - -1 - ) + ): + x_edge: torch.Tensor = self.so2_layer_attn_res[layer_idx]( + sources=so2_depth_sources, + scalar_extractor=so2_l0_extractor, + current_x=x_local.transpose(0, 1), + ) + x_local = x_edge.transpose(0, 1) # (F, E, D_m, Cf) + residual = x_local + x_local = inter_norm(x_local) + x_local = so2_linear(x_local) + apply_bias_correction(x_local, so2_linear, layer_idx) + + x_local = non_linear(x_local) + + if self.layer_scale: + scale: torch.Tensor = self.adam_so2_layer_scales[ + layer_idx + ].reshape(self.n_focus, 1, 1, self.so2_focus_dim) + x_local = residual + scale * x_local + else: + x_local = residual + x_local + so2_depth_sources.append((x_local - residual).transpose(0, 1)) + else: + for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( + zip( + self.so2_linears, + self.so2_inter_norms, + self.non_linearities, + strict=True, + ) + ): + residual = x_local + x_local = inter_norm(x_local) + x_local = so2_linear(x_local) + apply_bias_correction(x_local, so2_linear, layer_idx) + + x_local = non_linear(x_local) + + if self.layer_scale: + scale = self.adam_so2_layer_scales[layer_idx].reshape( + self.n_focus, 1, 1, self.so2_focus_dim + ) + x_local = residual + scale * x_local + else: + x_local = residual + x_local + + # === Step 5. Cross-focus softmax competition === + if self.focus_compete and self.n_focus > 1: + # ``_focus_alpha`` is shared with the rotation-free radial and Cartesian + # messages in the edge-major (E, F) orientation; feed it the transposed + # view of the focus-major scalar and broadcast the weights back over the + # focus-major activation. + alpha = self._focus_alpha(focus_gate_src.transpose(0, 1)) # (E, F) + x_local = x_local * alpha.transpose(0, 1).to( + dtype=x_local.dtype + ).unsqueeze(-1).unsqueeze(-1) + + # === Exit. Restore the (E, F, D_m, Cf) orientation === + # Both the fused flash-attention aggregation kernel and the rotate-back + # consume this orientation through explicit strides, so the focus-major + # buffer is handed back as a view with no copy. + x_local = x_local.permute(1, 0, 2, 3) # (E, F, D_m, Cf), strided view + + # The fused flash-attention aggregation consumes the per-focus + # ``(E, F, D_m, Cf)`` local layout directly and performs the rotate-back + # inside its kernel, so return before the standalone rotate-back. + if return_local: + return x_local, rad_feat # === Step 6. Rotate back to global frame === with nvtx_range("SO2Conv/rotate_back"): @@ -2094,7 +2323,7 @@ def _build_so2_mixing( self._rotate_to_local_fn = None self._rotate_back_fn = None if self.use_triton_infer: - from .triton.so2_rotation import ( + from deepmd.kernels.triton.sezm.so2_rotation import ( rotate_back_block_so2, rotate_back_dense, rotate_to_local_block, @@ -2161,6 +2390,11 @@ def _build_so2_mixing( self.so2_inter_norms = nn.ModuleList(inter_norms) # === Step 5. Intermediate non-linearity (the last layer stays linear) === + # Both branches run inside the focus-major SO(2) mixing layout, so they are + # built with ``layout="fndc"``: the ``S2GridNet`` activation (an S2 or + # SO(3) grid GLU per its grid configuration) folds the focus-major + # re-orientation into its coefficient/grid transpose, and the coefficient + # ``GatedActivation`` projects its per-focus gate in the same layout. non_linearities: list[nn.Module] = [] for i in range(self.mixing_layers): if i >= self.mixing_layers - 1: @@ -2175,7 +2409,7 @@ def _build_so2_mixing( mode="self", op_type="glu", dtype=self.compute_dtype, - layout="nfdc", + layout="fndc", grid_resolution_list=self.s2_grid_resolution, coefficient_layout="m_major", grid_method=self.s2_grid_method, @@ -2194,7 +2428,7 @@ def _build_so2_mixing( dtype=self.compute_dtype, activation_function=self.activation_function, mlp_bias=self.mlp_bias, - layout="nfdc", + layout="fndc", trainable=trainable, seed=child_seed(seed_non_linearities, i), ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py deleted file mode 100644 index 3cc27f40d4..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -"""Hardware-accelerated SeZM/DPA4 operators. - -This package hosts ``make_fx``-composable Triton implementations of SeZM hot -paths. Kernel entry points are internal implementation details of the SeZM -descriptor; the package-level API only exposes availability. -""" - -from .radial_mix import ( - RADIAL_MIX_TRITON_AVAILABLE, -) -from .so2_rotation import ( - TRITON_ROTATION_AVAILABLE, -) - -# Both kernel modules guard their ``@triton.jit`` definitions behind a ``triton`` -# import, so the two module-level checks are equivalent. Expose a single -# package-level availability flag. -TRITON_AVAILABLE = TRITON_ROTATION_AVAILABLE and RADIAL_MIX_TRITON_AVAILABLE - -__all__ = [ - "TRITON_AVAILABLE", -] diff --git a/deepmd/pt/model/descriptor/sezm_nn/utils.py b/deepmd/pt/model/descriptor/sezm_nn/utils.py index 6bb1933b01..2cfc98d55a 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/utils.py +++ b/deepmd/pt/model/descriptor/sezm_nn/utils.py @@ -11,7 +11,6 @@ ) import math -import os from contextlib import ( contextmanager, ) @@ -35,24 +34,6 @@ ATTN_RES_MODES = ("none", "independent", "dependent") -_TRITON_INFER_TRUE = ("1", "true", "yes", "on") - - -def use_triton_infer() -> bool: - """Return whether the opt-in Triton inference kernels are enabled. - - The flag is controlled by the ``DP_TRITON_INFER`` environment variable and - is read at module construction time so that it becomes a compile-time - constant in the traced (``make_fx``) graph. It only takes effect during - inference; training always uses the dense reference path. - - Returns - ------- - bool - ``True`` when ``DP_TRITON_INFER`` is set to a truthy value. - """ - return os.environ.get("DP_TRITON_INFER", "0").strip().lower() in _TRITON_INFER_TRUE - def init_trunc_normal_fan_in_out( weight: torch.Tensor, diff --git a/deepmd/pt/model/descriptor/sezm_nn/wignerd.py b/deepmd/pt/model/descriptor/sezm_nn/wignerd.py index f668e6b657..064595090f 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/wignerd.py +++ b/deepmd/pt/model/descriptor/sezm_nn/wignerd.py @@ -13,6 +13,7 @@ import math from itertools import ( permutations, + product, ) from typing import ( Any, @@ -22,6 +23,9 @@ import torch import torch.nn as nn +from deepmd.kernels.utils import ( + triton_infer_level, +) from deepmd.pt.utils import ( env, ) @@ -434,6 +438,41 @@ def __init__( dtype=self.dtype, device=self.device, ) + # Flatten the monomial exponent tables to Python constants in + # eager context: the fused monomial operator bakes them into the + # kernel at compile time, and a trace-time ``.tolist()`` would + # create unbacked symbols under ``make_fx`` and abort export. + self._monomial_exponents_flat: dict[str, list[int]] = {} + for exp_name in ("exp_l3", "exp_l4", "exp_l5", "exp_l6"): + exps = getattr(self.small_order_kernels, exp_name, None) + if exps is not None: + self._monomial_exponents_flat[exp_name] = [ + int(v) for v in exps.reshape(-1).tolist() + ] + self._use_triton_monomials = triton_infer_level() >= 1 + # The l = 2 contraction tensor collapsed onto the 35 unique + # degree-4 monomials: column m of the coefficient matrix sums + # C_l2[:, :, p] over the 4^4 index tuples p whose component + # multiplicities equal the monomial exponents. + exp_l2: list[int] = [] + columns: list[torch.Tensor] = [] + index_of: dict[tuple[int, int, int, int], int] = {} + c_l2 = self.small_order_kernels.C_l2 + for p in product(range(4), repeat=4): + counts = (p.count(0), p.count(1), p.count(2), p.count(3)) + if counts not in index_of: + index_of[counts] = len(index_of) + exp_l2.extend(counts) + columns.append(torch.zeros_like(c_l2[:, :, 0, 0, 0, 0])) + columns[index_of[counts]] = ( + columns[index_of[counts]] + c_l2[:, :, p[0], p[1], p[2], p[3]] + ) + self._monomial_exponents_flat["exp_l2"] = exp_l2 + self.register_buffer( + "_l2_monomial_coeff", + torch.stack([c.reshape(-1) for c in columns], dim=0), + persistent=False, + ) if self.lmax >= self.poly_lmin: coeffs = self._precompute_wigner_coefficients( @@ -1030,26 +1069,34 @@ def _precompute_powers( q: torch.Tensor, max_power: int, ) -> torch.Tensor: - """Precompute powers ``q_i^k`` as a dense table with shape ``(4, max_power+1, E)``.""" + """Precompute powers ``q_i^k`` as a dense table with shape ``(4, max_power+1, E)``. + + The table is built by an explicit multiply chain: a ``cumprod`` over + the short power axis lowers to a scan whose forward and leave-one-out + backward cost several milliseconds per model call at typical edge + counts, whereas the unrolled chain stays a fusable pointwise sequence. + """ components = q.transpose(0, 1) + ones = torch.ones_like(components) if max_power == 0: - return torch.ones(4, 1, q.shape[0], dtype=q.dtype, device=q.device) - repeated = components.unsqueeze(1).expand(4, max_power, q.shape[0]) - positive_powers = torch.cumprod(repeated, dim=1) - return torch.cat( - [ - torch.ones(4, 1, q.shape[0], dtype=q.dtype, device=q.device), - positive_powers, - ], - dim=1, - ) + return ones.unsqueeze(1) + powers = [ones, components] + for _ in range(max_power - 1): + powers.append(powers[-1] * components) + return torch.stack(powers, dim=1) @staticmethod def _build_monomial_matrix( powers: torch.Tensor, monomial_exponents: torch.Tensor, ) -> torch.Tensor: - """Assemble the monomial design matrix for one fixed degree by gather/prod.""" + """Assemble the monomial design matrix for one fixed degree. + + The four gathered factor rows are combined by explicit multiplies: + ``prod(dim=0)`` lowers to a ``cumprod`` scan pair (forward plus + leave-one-out backward) on the large ``(4, M, E)`` intermediate, + while two multiply levels keep the chain pointwise and fusable. + """ gather_idx = ( monomial_exponents.transpose(0, 1) .unsqueeze(-1) @@ -1060,7 +1107,38 @@ def _build_monomial_matrix( ) ) selected = torch.gather(powers, 1, gather_idx) - return selected.prod(dim=0).transpose(0, 1).contiguous() + product = (selected[0] * selected[1]) * (selected[2] * selected[3]) + return product.transpose(0, 1).contiguous() + + def _monomial_matrix( + self, + edge_quaternion: torch.Tensor, + exp_name: str, + max_power: int, + ) -> torch.Tensor: + """Evaluate one degree kernel's monomial basis, with the fused fast path. + + On the CUDA inference path the fused operator evaluates the monomials + in registers with the exponent table baked in at compile time (see + :mod:`.triton.wigner_monomials`); construction-time solves and CPU + targets keep the dense power-table chain. + """ + exponents = self._monomial_exponents_flat.get(exp_name) + if ( + self._use_triton_monomials + and exponents is not None + and edge_quaternion.is_cuda + and not self.training + ): + from deepmd.kernels.triton.sezm.wigner_monomials import ( + wigner_monomials, + ) + + return wigner_monomials(edge_quaternion, exponents, max_power) + powers = self._precompute_powers(edge_quaternion, max_power) + return self._build_monomial_matrix( + powers, getattr(self.small_order_kernels, exp_name) + ) def _compute_l1_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: """Compute the vector block directly from the Cartesian rotation matrix.""" @@ -1069,7 +1147,27 @@ def _compute_l1_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: return rot_perm * self.l1_sign_outer def _compute_l2_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: - """Compute the ``l=2`` block from the degree-4 quaternion contraction.""" + """Compute the ``l=2`` block from the degree-4 quaternion contraction. + + The fused inference path collapses the 256 rank-4 index tuples onto + the 35 unique degree-4 monomials, replacing the ``(E, 4, 4, 4, 4)`` + outer product with a monomial evaluation and one ``(E, 35) x (35, 25)`` + product with no large intermediate. + """ + exponents = self._monomial_exponents_flat.get("exp_l2") + if ( + self._use_triton_monomials + and exponents is not None + and edge_quaternion.is_cuda + and not self.training + ): + from deepmd.kernels.triton.sezm.wigner_monomials import ( + wigner_monomials, + ) + + monomials = wigner_monomials(edge_quaternion, exponents, 4) + D_flat = torch.matmul(monomials, self._l2_monomial_coeff) + return D_flat.view(edge_quaternion.shape[0], 5, 5) q2 = edge_quaternion.unsqueeze(-1) * edge_quaternion.unsqueeze(-2) q4 = q2.unsqueeze(-1).unsqueeze(-1) * q2.unsqueeze(-3).unsqueeze(-3) return torch.einsum( @@ -1080,11 +1178,7 @@ def _compute_l2_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: def _compute_l3_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: """Compute the ``l=3`` block from the dedicated degree-6 monomial kernel.""" - powers = self._precompute_powers(edge_quaternion, 6) - monomials = self._build_monomial_matrix( - powers, - self.small_order_kernels.exp_l3, - ) + monomials = self._monomial_matrix(edge_quaternion, "exp_l3", 6) D_flat = torch.matmul( monomials, self.small_order_kernels.C_l3.transpose(0, 1), @@ -1096,11 +1190,7 @@ def _compute_l3l4_blocks( edge_quaternion: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute the ``l=3`` and ``l=4`` blocks from one shared degree-8 kernel.""" - powers = self._precompute_powers(edge_quaternion, 8) - monomials = self._build_monomial_matrix( - powers, - self.small_order_kernels.exp_l4, - ) + monomials = self._monomial_matrix(edge_quaternion, "exp_l4", 8) D_flat = torch.matmul( monomials, self.small_order_kernels.C_combined_l3l4.transpose(0, 1), @@ -1111,11 +1201,7 @@ def _compute_l3l4_blocks( def _compute_l5_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: """Compute the ``l=5`` block from the dedicated degree-10 monomial kernel.""" - powers = self._precompute_powers(edge_quaternion, 10) - monomials = self._build_monomial_matrix( - powers, - self.small_order_kernels.exp_l5, - ) + monomials = self._monomial_matrix(edge_quaternion, "exp_l5", 10) D_flat = torch.matmul( monomials, self.small_order_kernels.C_l5.transpose(0, 1), @@ -1127,11 +1213,7 @@ def _compute_l5l6_blocks( edge_quaternion: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute the ``l=5`` and ``l=6`` blocks from one shared degree-12 kernel.""" - powers = self._precompute_powers(edge_quaternion, 12) - monomials = self._build_monomial_matrix( - powers, - self.small_order_kernels.exp_l6, - ) + monomials = self._monomial_matrix(edge_quaternion, "exp_l6", 12) D_flat = torch.matmul( monomials, self.small_order_kernels.C_combined_l5l6.transpose(0, 1), diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 8f9b63dc80..8671a1e94e 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -84,6 +84,9 @@ from .sezm_model import ( SeZMModel, ) +from .sezm_native_spin_model import ( + SeZMNativeSpinModel, +) from .sezm_property_model import ( SeZMPropertyModel, ) @@ -124,17 +127,43 @@ def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple: def _normalize_spin_use_spin(model_params: dict) -> None: - """Normalize spin.use_spin from type indices to per-type booleans.""" - if not model_params["spin"]["use_spin"] or isinstance( - model_params["spin"]["use_spin"][0], int - ): - use_spin = np.full(len(model_params["type_map"]), False, dtype=bool) - use_spin[model_params["spin"]["use_spin"]] = True - model_params["spin"]["use_spin"] = use_spin.tolist() + """Normalize ``spin.use_spin`` to a per-type boolean list. + + Three equivalent forms are accepted: a per-type boolean list, a list of + magnetic type indices, or a list of magnetic element symbols. The index and + symbol forms are expanded against ``type_map``, so a large type map only needs + its magnetic species named. + + Raises + ------ + ValueError + If a symbol in ``use_spin`` is absent from ``type_map``. + """ + use_spin = model_params["spin"]["use_spin"] + if use_spin and isinstance(use_spin[0], str): + type_index = {name: idx for idx, name in enumerate(model_params["type_map"])} + unknown = [name for name in use_spin if name not in type_index] + if unknown: + raise ValueError( + f"spin.use_spin references element(s) {unknown} absent from type_map." + ) + use_spin = [type_index[name] for name in use_spin] + # ``bool`` is a subclass of ``int``; an already-boolean list is left untouched + # while an index list is scattered into a per-type mask. + if not use_spin or not isinstance(use_spin[0], bool): + mask = np.full(len(model_params["type_map"]), False, dtype=bool) + mask[use_spin] = True + model_params["spin"]["use_spin"] = mask.tolist() def get_spin_model(model_params: dict) -> SpinModel: model_params = copy.deepcopy(model_params) + if model_params["spin"].get("allow_missing_label", False): + raise ValueError( + "spin.allow_missing_label is supported only by the SeZM/DPA4 spin model " + "(model.type='dpa4'/'sezm'), where a zero spin reduces the descriptor to " + "its spin-free form; the virtual-atom expansion here has no such limit." + ) _normalize_spin_use_spin(model_params) # include virtual spin and placeholder types model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]] @@ -444,6 +473,107 @@ def get_sezm_model(model_params: dict) -> BaseModel: def get_sezm_spin_model(model_params: dict) -> BaseModel: + """Dispatch a DPA4/SeZM spin model by ``model.spin.scheme``. + + ``deepspin`` selects the classical virtual-atom representation; ``native`` + injects the spin into the descriptor as an equivariant feature. A missing + ``scheme`` falls back to ``deepspin``. + """ + scheme = str(model_params.get("spin", {}).get("scheme", "deepspin")).lower() + if scheme == "native": + return _get_sezm_native_spin_model(model_params) + if scheme == "deepspin": + return _get_sezm_virtual_spin_model(model_params) + raise ValueError( + f"Unknown spin scheme '{scheme}' for DPA4/SeZM; use 'native' or 'deepspin'." + ) + + +def _get_sezm_native_spin_model(model_params: dict) -> BaseModel: + """Build the native (virtual-atom-free) spin DPA4/SeZM model. + + The spin vector enters the descriptor as an equivariant feature, so the + type map, neighbor selection and type count stay at the real-system sizes; + only the descriptor gains the spin-embedding branch. + """ + model_params_old = model_params + model_params = copy.deepcopy(model_params) + model_params.setdefault("descriptor", {}) + model_params.setdefault("fitting_net", {}) + model_params["descriptor"].setdefault("type", "dpa4") + _normalize_spin_use_spin(model_params) + + use_spin = [bool(flag) for flag in model_params["spin"]["use_spin"]] + # ``virtual_scale`` is a virtual-atom geometric device; the native scheme + # only needs ``use_spin`` for masking, so default it when absent. + spin = Spin( + use_spin=use_spin, + virtual_scale=model_params["spin"].get("virtual_scale", 1.0), + allow_missing_label=model_params["spin"].get("allow_missing_label", False), + ) + + ntypes = len(model_params["type_map"]) + model_params["descriptor"]["ntypes"] = ntypes + model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) + model_params["descriptor"]["use_spin"] = use_spin + + pair_exclude_types = model_params.get("pair_exclude_types", []) + model_params["pair_exclude_types"] = pair_exclude_types + if pair_exclude_types: + model_params["descriptor"]["exclude_types"] = copy.deepcopy(pair_exclude_types) + + # === Bridging parameters (no virtual atoms, so ZBL needs no masking) === + bridging_method = str(model_params.get("bridging_method", "none")).upper() + bridging_r_inner = float(model_params.get("bridging_r_inner", 0.5)) + bridging_r_outer = float(model_params.get("bridging_r_outer", 0.8)) + if bridging_method != "NONE": + model_params["descriptor"]["inner_clamp_r_inner"] = bridging_r_inner + model_params["descriptor"]["inner_clamp_r_outer"] = bridging_r_outer + + descriptor = BaseDescriptor(**model_params["descriptor"]) + + fitting_net = copy.deepcopy(model_params["fitting_net"]) + fitting_net_type = fitting_net.get("type", "dpa4_ener") + if fitting_net_type not in ("dpa4_ener", "sezm_ener"): + raise ValueError( + "Native spin DPA4/SeZM currently supports only `dpa4_ener` or " + f"`sezm_ener` fitting, but got `{fitting_net_type}`." + ) + fitting_net.pop("type", None) + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) + fitting_net["mixed_types"] = descriptor.mixed_types() + fitting_net["dim_descrpt"] = descriptor.get_dim_out() + fitting = SeZMEnergyFittingNet(**fitting_net) + + preset_out_bias = model_params.get("preset_out_bias") + preset_out_bias = _convert_preset_out_bias_to_array( + preset_out_bias, model_params["type_map"] + ) + data_stat_protect = model_params.get("data_stat_protect", 1e-2) + use_compile = bool(model_params.get("use_compile", False)) + enable_tf32 = bool(model_params.get("enable_tf32", True)) + + model = SeZMNativeSpinModel( + descriptor=descriptor, + fitting=fitting, + type_map=model_params["type_map"], + atom_exclude_types=model_params.get("atom_exclude_types", []), + pair_exclude_types=pair_exclude_types, + preset_out_bias=preset_out_bias, + data_stat_protect=data_stat_protect, + use_compile=use_compile, + enable_tf32=enable_tf32, + bridging_method=bridging_method, + bridging_r_inner=bridging_r_inner, + bridging_r_outer=bridging_r_outer, + spin=spin, + ) + model.model_def_script = json.dumps(model_params_old) + return model + + +def _get_sezm_virtual_spin_model(model_params: dict) -> BaseModel: model_params_old = model_params model_params = copy.deepcopy(model_params) model_params.setdefault("descriptor", {}) @@ -458,6 +588,7 @@ def get_sezm_spin_model(model_params: dict) -> BaseModel: spin = Spin( use_spin=model_params["spin"]["use_spin"], virtual_scale=model_params["spin"]["virtual_scale"], + allow_missing_label=model_params["spin"].get("allow_missing_label", False), ) model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]] pair_exclude_types = spin.get_pair_exclude_types( @@ -556,6 +687,7 @@ def get_model(model_params: dict) -> Any: "PolarModel", "PopulationModel", "SeZMModel", + "SeZMNativeSpinModel", "SeZMPropertyModel", "SeZMSpinModel", "SpinEnergyModel", diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index a9ade1c79d..b623222a41 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -502,6 +502,7 @@ get_task_buffer_values, next_safe_prime, rebuild_graph_module, + relax_views_to_reshapes, strip_saved_tensor_detach, trace_pad_dim, ) @@ -645,6 +646,7 @@ def _sezm_structure_key(model: SeZMModel) -> tuple[Any, ...]: bool(descriptor.use_gie), bool(descriptor.random_gamma), descriptor.charge_spin_embedding is not None, + descriptor.spin_embedding is not None, descriptor.inner_clamp is not None, descriptor.bridging_switch is not None, descriptor.inner_clamp_r_inner, @@ -951,6 +953,7 @@ def forward_common( force_input: Float[Tensor, "nf nloc 3"] | None = None, noise_mask: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, + spin: torch.Tensor | None = None, embedding_only: bool = False, ) -> dict[str, torch.Tensor]: """ @@ -997,6 +1000,10 @@ def forward_common( nf, nloc = atype.shape[:2] if cc.ndim == 2: cc = cc.view(nf, nloc, 3) + if spin is not None: + spin = spin.to(device=cc.device, dtype=cc.dtype).reshape( + nf, nloc, 3 + ) # === Step 2. Build geometry schema === with nvtx_range("SeZM/build_neighbor_list"): @@ -1034,6 +1041,7 @@ def forward_common( fparam=fp, aparam=ap, charge_spin=charge_spin, + spin=spin, input_prec=input_prec, embedding_only=embedding_only, ) @@ -1052,6 +1060,7 @@ def forward_common_lower( extended_atype: torch.Tensor | None = None, extended_coord_corr: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, + spin: torch.Tensor | None = None, input_prec: torch.dtype | None = None, use_compile: bool | None = None, embedding_only: bool = False, @@ -1096,6 +1105,8 @@ def forward_common_lower( ) if extended_coord_corr.ndim == 2: extended_coord_corr = extended_coord_corr.reshape(atype.shape[0], -1, 3) + if spin is not None: + spin = spin.to(device=coord.device, dtype=coord.dtype) nf = atype.shape[0] should_compile = ( self.should_use_compile() if use_compile is None else use_compile @@ -1157,6 +1168,7 @@ def forward_common_lower( ap, charge_spin, extended_coord_corr=extended_coord_corr, + spin=spin, ) compiled_core_compute = self.compiled_core_compute_cache[cache_key] task_buf_order = self._task_buf_order_cache[cache_key] @@ -1164,7 +1176,21 @@ def forward_common_lower( task_buf_vals = get_task_buffer_values(self, task_buf_order) grad_ctx: Any = nullcontext() if self.training else torch.no_grad() with nvtx_range("SeZM/core_compute"), grad_ctx: - if extended_coord_corr is None: + if spin is not None: + model_predict = compiled_core_compute( + coord, + atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + fp, + ap, + charge_spin, + spin, + *task_buf_vals, + ) + elif extended_coord_corr is None: model_predict = compiled_core_compute( coord, atype, @@ -1220,6 +1246,7 @@ def forward_common_lower( comm_dict=comm_dict, extended_atype=extended_atype, extended_coord_corr=extended_coord_corr, + spin=spin, embedding_only=embedding_only, ) return self._output_type_cast(model_predict, input_prec) @@ -1363,6 +1390,7 @@ def core_compute( comm_dict: dict[str, torch.Tensor] | None = None, extended_atype: torch.Tensor | None = None, extended_coord_corr: torch.Tensor | None = None, + spin: torch.Tensor | None = None, embedding_only: bool = False, conservative: bool = True, ) -> dict[str, torch.Tensor]: @@ -1405,6 +1433,11 @@ def core_compute( extended_coord_corr Coordinates correction for virial with shape ``(nf, nscatter, 3)`` or ``None``. + spin + Optional per-atom spin vectors with shape ``(nf, nloc, 3)`` for the + native spin scheme. When provided on the conservative path, the + spin tensor becomes a second autograd leaf so the magnetic force + ``-dE/dspin`` is produced by the same backward as the force/virial. embedding_only When ``True``, return only the embedding outputs and skip the force/virial autograd entirely. @@ -1440,6 +1473,12 @@ def core_compute( if conservative and not embedding_only: edge_vec = edge_vec.detach().requires_grad_(True) + # Native spin: the per-atom spin is a second autograd leaf, so the + # magnetic force -dE/dspin is produced by the same backward that + # scatters the edge gradient into force/virial. + if spin is not None and conservative and not embedding_only: + spin = spin.detach().requires_grad_(True) + # === Step 2. Descriptor forward === # ``extended_atype`` spans the extended region on the parallel path and # reduces to ``atype`` (owned atoms) on the single-domain path; the @@ -1455,6 +1494,7 @@ def core_compute( edge_vec=edge_vec, edge_mask=edge_mask, charge_spin=charge_spin, + spin=spin, comm_dict=comm_dict, nloc=nloc, ) @@ -1545,18 +1585,21 @@ def core_compute( energy_redu = torch.sum( energy_atom.to(env.GLOBAL_PT_ENER_FLOAT_PRECISION), dim=1 ) - energy_derv_r, energy_derv_c, energy_derv_c_redu = edge_energy_deriv( - energy_redu, - edge_vec, - edge_scatter_index[0], - edge_scatter_index[1], - edge_mask, - nf, - nscatter, - create_graph=self.training, - extended_coord_corr=extended_coord_corr, + energy_derv_r, energy_derv_c, energy_derv_c_redu, energy_derv_r_mag = ( + edge_energy_deriv( + energy_redu, + edge_vec, + edge_scatter_index[0], + edge_scatter_index[1], + edge_mask, + nf, + nscatter, + create_graph=self.training, + extended_coord_corr=extended_coord_corr, + spin_leaf=spin, + ) ) - return { + model_ret = { "energy": energy_atom, "energy_redu": energy_redu, "energy_derv_r": energy_derv_r, @@ -1564,6 +1607,9 @@ def core_compute( "energy_derv_c_redu": energy_derv_c_redu, "mask": fit_ret["mask"], } + if energy_derv_r_mag is not None: + model_ret["energy_derv_r_mag"] = energy_derv_r_mag + return model_ret def core_compute_dens( self, @@ -1669,7 +1715,6 @@ def core_compute_dens( dim=-1, ) - @torch.jit.export def forward_lower( self, coord: Float[Tensor, "nf nscatter_x3"] | Float[Tensor, "nf nscatter 3"], @@ -1800,6 +1845,7 @@ def trace_and_compile( ap: torch.Tensor, charge_spin: torch.Tensor, extended_coord_corr: torch.Tensor | None = None, + spin: torch.Tensor | None = None, embedding_only: bool = False, ) -> None: """Trace ``core_compute()`` with ``make_fx`` and cache the compiled callable. @@ -1923,7 +1969,45 @@ def _prepare_coord_for_trace(coord: torch.Tensor) -> torch.Tensor: # into _buffers so downstream code (apply_out_stat, fitting_net.forward) # reads the proxies and the ops are recorded in the FX graph. The # finally block restores original state unconditionally. - if extended_coord_corr is None: + if spin is not None: + + def compute_fn( # type: ignore[misc] + coord: torch.Tensor, + atype: torch.Tensor, + edge_index: torch.Tensor, + edge_vec: torch.Tensor, + edge_scatter_index: torch.Tensor, + edge_mask: torch.Tensor, + fp: torch.Tensor, + ap: torch.Tensor, + charge_spin: torch.Tensor, + spin: torch.Tensor, + *task_buf_vals: torch.Tensor, + ) -> dict[str, torch.Tensor]: + # NOTE: Native spin adds the per-atom spin as a second autograd + # endpoint inside ``core_compute``; ``make_fx`` unfolds the same + # single ``autograd.grad(energy, [edge_vec, spin])`` it already + # captures for the edge-vector force, so the magnetic force is + # produced by the compiled graph with no extra backward. + _saved = _patch_task_bufs(task_buf_vals) + try: + return self.core_compute( + _prepare_coord_for_trace(coord), + atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + spin=spin, + embedding_only=embedding_only, + ) + finally: + _restore_task_bufs(_saved) + + elif extended_coord_corr is None: def compute_fn( coord: torch.Tensor, @@ -2056,7 +2140,11 @@ def compute_fn( # type: ignore[misc] ap_for_trace, charge_spin_for_trace, ] - if extended_coord_corr is not None: + if spin is not None: + spin_for_trace = trace_pad_dim(spin[:1], 0, trace_nf) + spin_for_trace = trace_pad_dim(spin_for_trace, 1, trace_nloc) + trace_args.append(spin_for_trace) + elif extended_coord_corr is not None: corr_for_trace = trace_pad_dim(extended_coord_corr[:1], 0, trace_nf) corr_for_trace = trace_pad_dim(corr_for_trace, 1, trace_nscatter) trace_args.append(corr_for_trace) @@ -2126,8 +2214,11 @@ def compute_fn( # type: ignore[misc] # The conservative Inductor option set that keeps the dynamic edge # graph lowerable is centralised in ``deepmd.pt.utils.compile_compat``; - # subclasses may augment it via ``_inductor_compile_options``. - compile_options = self._inductor_compile_options() + # subclasses may augment it via ``_inductor_compile_options``. The + # inference lowering additionally disables the peak-memory reordering + # pass, whose cost model is blind to the hint-less edge symbols of the + # ``make_fx`` graph (see ``build_inductor_compile_options``). + compile_options = self._inductor_compile_options(inference=not self.training) # NOTE: Store the compiled callable inside the plain-``dict`` # cache ``compiled_core_compute_cache``. The dict itself was installed @@ -2310,19 +2401,19 @@ def should_use_compile(self) -> bool: return self.use_compile return bool(self._env_use_compile_infer) - def _inductor_compile_options(self) -> dict[str, Any]: + def _inductor_compile_options(self, *, inference: bool = False) -> dict[str, Any]: """Return the Inductor lowering options for this model's compiled core. Subclasses may override this to augment the shared option set from :func:`build_inductor_compile_options` with model-specific entries. """ - return build_inductor_compile_options() + return build_inductor_compile_options(inference=inference) # ========================================================================= # Export Utilities # ========================================================================= - def _trace_lower_exportable( + def trace_lower_exportable( self, fn: Any, *sample_inputs: torch.Tensor | None, @@ -2333,7 +2424,7 @@ def _trace_lower_exportable( get_decompositions, ) - return make_fx( + traced = make_fx( fn, tracing_mode="symbolic", _allow_non_fake_inputs=True, @@ -2341,6 +2432,10 @@ def _trace_lower_exportable( [torch.ops.aten.silu_backward.default] ), )(*sample_inputs) + # make_fx can lower a reshape to an unsound aten.view when the fake + # stride diverges from the eager stride; relax views back to reshapes. + relax_views_to_reshapes(traced) + return traced def forward_common_lower_exportable( self, @@ -2452,7 +2547,7 @@ def fn( charge_spin, ) - return self._trace_lower_exportable( + return self.trace_lower_exportable( fn, *trace_inputs, ) @@ -2579,7 +2674,7 @@ def fn( nlocal, nghost, ) - return self._trace_lower_exportable(fn, *trace_inputs) + return self.trace_lower_exportable(fn, *trace_inputs) # ========================================================================= # Neighbor List Construction @@ -2938,14 +3033,30 @@ def supports_edge_parallel(self) -> bool: Cross-rank ghost-feature exchange is well-defined only for the conservative non-bridging path: analytical ZBL bridging and its Source Freeze Propagation gate fold each node's full outgoing-edge set, which a - single rank cannot observe for ghost owners. Spin models use the nlist - lower interface and are gated separately by the freeze entry point. + single rank cannot observe for ghost owners. The native spin scheme + reuses the edge_vec interface and therefore participates; only the + deepspin (virtual-atom) scheme uses the nlist interface and is excluded + by the freeze entry point's edge_vec gate. """ if self.inter_potential is not None: return False descriptor = self.atomic_model.descriptor return bool(descriptor.has_message_passing_across_ranks()) + def export_lower_input_kind(self) -> str: + """Return the ABI consumed by the exported ``.pt2`` lower graph. + + ``"edge_vec"`` means the graph receives the compact edge schema + (``coord``, ``atype``, ``edge_index``, ``edge_vec``, + ``edge_scatter_index``, ``edge_mask``) and the neighbor topology is + built on the C++ side, matching :class:`DeepPotPTExpt`. The native + spin scheme reuses this ABI with one extra per-local-atom spin input; + only the deepspin (virtual-atom) scheme overrides it to ``"nlist"`` + because it expands virtual atoms inside the graph from the extended + neighbor list. + """ + return "edge_vec" + # ========================================================================= # Mode Management # ========================================================================= diff --git a/deepmd/pt/model/model/sezm_native_spin_model.py b/deepmd/pt/model/model/sezm_native_spin_model.py new file mode 100644 index 0000000000..aeec0ae74e --- /dev/null +++ b/deepmd/pt/model/model/sezm_native_spin_model.py @@ -0,0 +1,508 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Native-spin SeZM energy model. + +Unlike the virtual-atom :class:`SeZMSpinModel`, the native scheme injects the +per-atom spin vector directly into the descriptor as an equivariant feature +(``l = 0`` magnitude and ``l = 1`` direction) and obtains the magnetic force as +the negative spin gradient of the energy. No virtual atoms are created, so the +neighbor list, type map and selection stay at their real-system sizes, and the +analytical bridging potential needs no real/virtual masking. +""" + +from copy import ( + deepcopy, +) +from typing import ( + Any, +) + +import torch +from einops import ( + rearrange, +) + +from deepmd.dpmodel import ( + ModelOutputDef, +) +from deepmd.pt.model.atomic_model.sezm_atomic_model import ( + SeZMAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) +from deepmd.pt.model.model.sezm_model import ( + SeZMModel, +) +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) +from deepmd.utils.spin import ( + Spin, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +@BaseModel.register("sezm_native_spin") +class SeZMNativeSpinModel(SeZMModel): + """SeZM energy model with native (virtual-atom-free) spin. + + The per-atom spin enters the descriptor through :class:`SpinEmbedding`, and + the magnetic force is the negative spin gradient of the energy, produced by + the same backward as the conservative force and virial. + + Parameters + ---------- + spin + Spin metadata describing which real atom types carry spin. + *args + Positional arguments forwarded to :class:`SeZMModel`. + **kwargs + Keyword arguments forwarded to :class:`SeZMModel`. + """ + + model_type = "sezm_native_spin" + + def __init__( + self, + *args: Any, + spin: Spin, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.spin = spin + self.ntypes_real = self.spin.ntypes_real + # Per-type 0/1 spin gate. + self.register_buffer( + "spin_mask", + to_torch_tensor(self.spin.get_spin_mask()), + persistent=False, + ) + + # ========================================================================= + # Forward Methods + # ========================================================================= + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + spin: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Return native-spin SeZM predictions with public output keys. + + ``mask_mag`` is built from the per-type spin gate on the local + ``atype``; non-magnetic atoms already carry a zero magnetic force (the + descriptor gates the spin embedding by type), so the force itself needs + no re-masking. This is the runtime counterpart of the static schema in + :meth:`translated_output_def`. + """ + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + charge_spin=charge_spin, + spin=spin, + ) + nf, nloc = atype.shape[:2] + model_predict: dict[str, torch.Tensor] = { + "atom_energy": model_ret["energy"], + "energy": model_ret["energy_redu"], + "mask_mag": self.spin_mask.index_select(0, atype.reshape(-1)).reshape( + nf, nloc, 1 + ) + > 0.0, + } + if self.do_grad_r("energy"): + model_predict["force"] = rearrange( + model_ret["energy_derv_r"], "nf n 1 three -> nf n three", three=3 + ) + model_predict["force_mag"] = rearrange( + model_ret["energy_derv_r_mag"], "nf n 1 three -> nf n three", three=3 + ) + if self.do_grad_c("energy"): + model_predict["virial"] = rearrange( + model_ret["energy_derv_c_redu"], "nf 1 nine -> nf nine", nine=9 + ) + if do_atomic_virial: + model_predict["atom_virial"] = rearrange( + model_ret["energy_derv_c"], "nf n 1 nine -> nf n nine", nine=9 + ) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + # ========================================================================= + # Export + # ========================================================================= + + def forward_common_lower_exportable( + self, + coord: torch.Tensor, + atype: torch.Tensor, + edge_index: torch.Tensor, + edge_vec: torch.Tensor, + edge_scatter_index: torch.Tensor, + edge_mask: torch.Tensor, + spin: torch.Tensor, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> torch.nn.Module: + """Trace the native-spin lower interface into an exportable FX graph. + + The native scheme reuses the energy model's edge ABI (``coord``, + ``atype``, ``edge_index``, ``edge_vec``, ``edge_scatter_index``, + ``edge_mask``); the only addition is the per-local-atom ``spin`` leaf, + so ``make_fx`` unfolds the single + ``autograd.grad(energy, [edge_vec, spin])`` into the conservative force + and the magnetic force. The C++ backend builds the edge schema (exactly + like :class:`DeepPotPTExpt`) and feeds the owned-atom spins, so a spin + and a non-spin ``.pt2`` archive share one inference path. + + Parameters + ---------- + coord + Extended coordinates with shape ``(nf, nall, 3)``. + atype + Local atom types with shape ``(nf, nloc)``. + edge_index + Local-folded source/destination indices with shape ``(2, nedge)``. + edge_vec + Per-edge displacement with shape ``(nedge, 3)``. + edge_scatter_index + Extended source/destination indices for the force scatter with + shape ``(2, nedge)``. + edge_mask + Boolean per-edge validity mask with shape ``(nedge,)``. + spin + Per-local-atom spin vectors with shape ``(nf, nloc, 3)``. + fparam, aparam, charge_spin + Optional frame / atomic / charge-spin conditioning inputs. + + Returns + ------- + torch.nn.Module + The traced exportable lower graph. + """ + if self.get_active_mode() == "dens": + raise NotImplementedError( + "SeZM export supports only the conservative `ener` path." + ) + model = self + + def lower_fn( + coord_: torch.Tensor, + atype_: torch.Tensor, + edge_index_: torch.Tensor, + edge_vec_: torch.Tensor, + edge_scatter_index_: torch.Tensor, + edge_mask_: torch.Tensor, + spin_: torch.Tensor, + fparam_: torch.Tensor | None, + aparam_: torch.Tensor | None, + charge_spin_: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + # Detach the leaves inside the traced closure so the exported graph + # owns its own force/magnetic-force autograd endpoints (edge_vec and + # spin) rather than capturing the upstream LAMMPS tensors. + coord_ = coord_.detach() + edge_vec_ = edge_vec_.detach() + model_ret = model.forward_common_lower( + coord_, + atype_, + edge_index_, + edge_vec_, + edge_scatter_index_, + edge_mask_, + fparam=fparam_, + aparam=aparam_, + charge_spin=charge_spin_, + spin=spin_, + use_compile=False, + ) + return model._attach_spin_masks( + model_ret, atype=atype_, nall=coord_.shape[1] + ) + + if self.get_dim_chg_spin() > 0: + charge_spin = self.convert_charge_spin( + charge_spin, + nf=atype.shape[0], + dtype=coord.dtype, + device=coord.device, + ) + return self.trace_lower_exportable( + lower_fn, + coord, + atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + spin, + fparam, + aparam, + charge_spin, + ) + + def forward_common_lower_exportable_with_comm( + self, + coord: torch.Tensor, + atype: torch.Tensor, + extended_atype: torch.Tensor, + edge_index: torch.Tensor, + edge_vec: torch.Tensor, + edge_scatter_index: torch.Tensor, + edge_mask: torch.Tensor, + spin: torch.Tensor, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + charge_spin: torch.Tensor | None, + send_list: torch.Tensor, + send_proc: torch.Tensor, + recv_proc: torch.Tensor, + send_num: torch.Tensor, + recv_num: torch.Tensor, + communicator: torch.Tensor, + nlocal: torch.Tensor, + nghost: torch.Tensor, + ) -> torch.nn.Module: + """Trace the native-spin parallel (with-comm) lower interface. + + Mirrors :meth:`SeZMModel.forward_common_lower_exportable_with_comm` and + adds the extended (nall) per-atom spin leaf. On the parallel path the + spin is a per-extended-node feature -- the LAMMPS spin reverse/forward + comm supplies ghost spins -- so the magnetic force ``-dE/dspin`` is + itself extended (nall): its ghost rows are the cross-rank neighbour + contributions that ``border_op``'s exact-VJP backward routes correctly + and the LAMMPS spin reverse-comm folds onto owners. ``_attach_spin_masks`` + is a no-op pad here (the force is already extended) and attaches the + extended ``mask_mag``. + """ + if self.get_active_mode() == "dens": + raise NotImplementedError( + "SeZM export supports only the conservative `ener` path." + ) + from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, + ) + + ensure_comm_registered() + model = self + + def fn( + coord_: torch.Tensor, + atype_: torch.Tensor, + extended_atype_: torch.Tensor, + edge_index_: torch.Tensor, + edge_vec_: torch.Tensor, + edge_scatter_index_: torch.Tensor, + edge_mask_: torch.Tensor, + spin_: torch.Tensor, + fparam_: torch.Tensor | None, + aparam_: torch.Tensor | None, + charge_spin_: torch.Tensor | None, + send_list_: torch.Tensor, + send_proc_: torch.Tensor, + recv_proc_: torch.Tensor, + send_num_: torch.Tensor, + recv_num_: torch.Tensor, + communicator_: torch.Tensor, + nlocal_: torch.Tensor, + nghost_: torch.Tensor, + ) -> dict[str, torch.Tensor]: + coord_ = coord_.detach() + edge_vec_ = edge_vec_.detach() + comm_dict = { + "send_list": send_list_, + "send_proc": send_proc_, + "recv_proc": recv_proc_, + "send_num": send_num_, + "recv_num": recv_num_, + "communicator": communicator_, + "nlocal": nlocal_, + "nghost": nghost_, + } + model_ret = model.forward_common_lower( + coord_, + atype_, + edge_index_, + edge_vec_, + edge_scatter_index_, + edge_mask_, + fparam=fparam_, + aparam=aparam_, + comm_dict=comm_dict, + extended_atype=extended_atype_, + charge_spin=charge_spin_, + spin=spin_, + use_compile=False, + ) + return model._attach_spin_masks( + model_ret, atype=extended_atype_, nall=coord_.shape[1] + ) + + if self.get_dim_chg_spin() > 0: + charge_spin = self.convert_charge_spin( + charge_spin, + nf=atype.shape[0], + dtype=coord.dtype, + device=coord.device, + ) + return self.trace_lower_exportable( + fn, + coord, + atype, + extended_atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + spin, + fparam, + aparam, + charge_spin, + send_list, + send_proc, + recv_proc, + send_num, + recv_num, + communicator, + nlocal, + nghost, + ) + + def _attach_spin_masks( + self, + model_ret: dict[str, torch.Tensor], + *, + atype: torch.Tensor, + nall: int, + ) -> dict[str, torch.Tensor]: + """Express the magnetic force in the extended layout and attach ``mask_mag``. + + Parameters + ---------- + model_ret + Internal SeZM lower outputs; ``energy_derv_r_mag`` has the + per-local-atom shape ``(nf, nloc, 1, 3)``. + atype + Local atom types with shape ``(nf, nloc)``, used to build + ``mask_mag`` with shape ``(nf, nloc, 1)``. + nall + Extended atom count the magnetic force is padded to. + + Returns + ------- + dict[str, torch.Tensor] + ``model_ret`` with ``energy_derv_r_mag`` padded to ``nall`` and a + ``mask_mag`` entry. + + Notes + ----- + The magnetic force is intrinsically per-local-atom (only owned spins + enter the descriptor); padding the ghost slots with zero lets it share + the extended reduce / fold-back contract that + ``communicate_extended_output`` and the LAMMPS C++ backend apply to the + conservative force, so the native and deepspin schemes share one + downstream path. The padding is unconditional (``nall - nloc`` is zero + for an isolated cluster) so the closure stays free of shape-dependent + branches under ``make_fx`` symbolic tracing. + """ + derv_r_mag = model_ret["energy_derv_r_mag"] # (nf, nloc, 1, 3) + nf, nloc = derv_r_mag.shape[:2] + ghost_pad = derv_r_mag.new_zeros(nf, nall - nloc, *derv_r_mag.shape[2:]) + model_ret["energy_derv_r_mag"] = torch.cat([derv_r_mag, ghost_pad], dim=1) + model_ret["mask_mag"] = ( + self.spin_mask.index_select(0, atype.reshape(-1)).reshape( + atype.shape[0], atype.shape[1], 1 + ) + > 0.0 + ) + return model_ret + + # ========================================================================= + # Mode Selection + # ========================================================================= + + def set_active_mode(self, mode: str) -> None: + """Switch mode, allowing only the conservative energy path.""" + normalized = str(mode).lower() + if normalized != "ener": + raise NotImplementedError("SeZM native spin supports only the `ener` path.") + super().set_active_mode(normalized) + + def set_active_mode_from_loss(self, loss_type: str) -> None: + """Select execution mode from loss type.""" + normalized = str(loss_type).lower() + if normalized == "dens": + raise NotImplementedError("SeZM native spin supports only the `ener` path.") + if normalized in {"ener", "ener_spin"}: + self.set_active_mode("ener") + + # ========================================================================= + # Output Definitions and Metadata + # ========================================================================= + + def has_spin(self) -> bool: + """Return whether this model consumes spin input.""" + return True + + def model_output_def(self) -> ModelOutputDef: + """Return the spin-aware model output definition.""" + atomic_output_def = self.atomic_output_def() + atomic_output_def["energy"].magnetic = True + return ModelOutputDef(atomic_output_def) + + def translated_output_def(self) -> dict[str, Any]: + """Translate internal output definitions to public spin keys.""" + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + "mask_mag": out_def_data["mask_mag"], + } + if self.do_grad_r("energy"): + output_def["force"] = deepcopy(out_def_data["energy_derv_r"]) + output_def["force"].squeeze(-2) + output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"]) + output_def["force_mag"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) + output_def["atom_virial"].squeeze(-2) + return output_def + + # ========================================================================= + # Serialization + # ========================================================================= + + def serialize(self) -> dict[str, Any]: + """Serialize the native-spin SeZM model.""" + data = super().serialize() + data["type"] = self.model_type + data["spin"] = self.spin.serialize() + return data + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> "SeZMNativeSpinModel": + """Deserialize a native-spin SeZM model.""" + data = data.copy() + version = int(data.pop("@version", 1)) + check_version_compatibility(version, 1, 1) + data.pop("@class", None) + data.pop("type", None) + spin = Spin.deserialize(data.pop("spin")) + atomic_model = SeZMAtomicModel.deserialize(data.pop("atomic_model")) + return cls(atomic_model_=atomic_model, spin=spin, **data) diff --git a/deepmd/pt/model/model/sezm_property_model.py b/deepmd/pt/model/model/sezm_property_model.py index e1fcb7b793..43890d4549 100644 --- a/deepmd/pt/model/model/sezm_property_model.py +++ b/deepmd/pt/model/model/sezm_property_model.py @@ -132,6 +132,7 @@ def core_compute( comm_dict: dict[str, torch.Tensor] | None = None, extended_atype: torch.Tensor | None = None, extended_coord_corr: torch.Tensor | None = None, + spin: torch.Tensor | None = None, embedding_only: bool = False, ) -> dict[str, torch.Tensor]: """Compute property outputs through the SeZM forward-only graph.""" @@ -148,11 +149,12 @@ def core_compute( comm_dict=comm_dict, extended_atype=extended_atype, extended_coord_corr=extended_coord_corr, + spin=spin, embedding_only=embedding_only, conservative=False, ) - def _inductor_compile_options(self) -> dict[str, Any]: + def _inductor_compile_options(self, *, inference: bool = False) -> dict[str, Any]: """Augment the shared Inductor options for the property compile path. The non-conservative property backward graph triggers a TorchInductor @@ -163,7 +165,7 @@ def _inductor_compile_options(self) -> dict[str, Any]: backend, so CUDA/Triton lowering -- the actual ``use_compile`` deployment target -- is unchanged. """ - options = super()._inductor_compile_options() + options = super()._inductor_compile_options(inference=inference) options["cpp.simdlen"] = 0 return options diff --git a/deepmd/pt/model/model/sezm_spin_model.py b/deepmd/pt/model/model/sezm_spin_model.py index af973ea529..0be5b39c10 100644 --- a/deepmd/pt/model/model/sezm_spin_model.py +++ b/deepmd/pt/model/model/sezm_spin_model.py @@ -33,6 +33,7 @@ from deepmd.pt.model.model.spin_model import ( SpinModel, _lookup_type_values, + _pack_spin_stat_sample, ) from deepmd.pt.utils.utils import ( to_torch_tensor, @@ -209,7 +210,13 @@ def forward_common( edge_schema.edge_mask, fparam=fp, aparam=ap, - extended_coord_corr=extended_coord_corr[:, : nloc * 2, :], + # Slicing the doubled-extended correction down to the local + # region yields a stride-(2*nall*3, ...) view; the compiled + # virial matmul reshapes it to (N, 1, 3), which ``torch.compile`` + # lowers to ``aten.view`` and rejects on the non-contiguous + # layout. Materialize a contiguous copy, mirroring the local + # coord/atype slices in ``edge_schema_from_extended``. + extended_coord_corr=extended_coord_corr[:, : nloc * 2, :].contiguous(), charge_spin=charge_spin, input_prec=input_prec, ) @@ -395,7 +402,7 @@ def fn( ) trace_inputs = (*trace_inputs, charge_spin) - return self._trace_lower_exportable( + return self.trace_lower_exportable( fn, *trace_inputs, ) @@ -462,6 +469,16 @@ def has_spin(self) -> bool: """Return whether this model consumes spin input.""" return True + def export_lower_input_kind(self) -> str: + """Return the ``.pt2`` lower ABI: the deepspin scheme needs the nlist. + + Virtual atoms are placed and the neighbor list is expanded inside the + traced graph from the extended inputs, so the export contract feeds the + real extended coordinates, spin and neighbor list rather than a + pre-built edge schema. + """ + return "nlist" + def get_type_map(self) -> list[str]: """Return the real atom type map.""" return super().get_type_map()[: self.ntypes_real] @@ -502,9 +519,8 @@ def get_observed_type_list(self) -> list[str]: def model_output_def(self) -> ModelOutputDef: """Return the spin-aware model output definition.""" - var_name = self._get_output_var_name() atomic_output_def = self.atomic_output_def() - atomic_output_def[var_name].magnetic = True + atomic_output_def["energy"].magnetic = True return ModelOutputDef(atomic_output_def) def translated_output_def(self) -> dict[str, Any]: @@ -585,10 +601,6 @@ def _get_inter_potential_real_type_count(self) -> int: """Return the number of real types for real-only ZBL masking.""" return self.ntypes_real - def _get_output_var_name(self) -> str: - """Return the primary atomic output variable name.""" - return "energy" - def _get_spin_sampled_func( self, sampled_func: Callable[[], list[dict[str, Any]]] ) -> Callable[[], list[dict[str, Any]]]: @@ -596,36 +608,7 @@ def _get_spin_sampled_func( @functools.lru_cache def spin_sampled_func() -> list[dict[str, Any]]: - sampled = sampled_func() - spin_sampled = [] - for sys in sampled: - coord_updated, atype_updated, _ = self.process_spin_input( - sys["coord"], sys["atype"], sys["spin"] - ) - tmp_dict = { - "coord": coord_updated, - "atype": atype_updated, - } - if "aparam" in sys: - tmp_dict["aparam"] = self.expand_aparam( - sys["aparam"], atype_updated.shape[1] - ) - if "natoms" in sys: - natoms = sys["natoms"] - tmp_dict["natoms"] = torch.cat( - [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1 - ) - for item_key in sys.keys(): - if item_key not in [ - "coord", - "atype", - "spin", - "natoms", - "aparam", - ]: - tmp_dict[item_key] = sys[item_key] - spin_sampled.append(tmp_dict) - return spin_sampled + return [_pack_spin_stat_sample(self, sys) for sys in sampled_func()] return self.atomic_model._make_wrapped_sampler(spin_sampled_func) @@ -650,22 +633,21 @@ def _split_spin_common_output( nloc: int, ) -> dict[str, torch.Tensor]: """Split full-interface SeZM outputs into real and magnetic parts.""" - var_name = self._get_output_var_name() - model_ret[var_name] = torch.split(model_ret[var_name], [nloc, nloc], dim=1)[0] - if self.do_grad_r(var_name) and model_ret.get(f"{var_name}_derv_r") is not None: + model_ret["energy"] = torch.split(model_ret["energy"], [nloc, nloc], dim=1)[0] + if self.do_grad_r("energy") and model_ret.get("energy_derv_r") is not None: ( - model_ret[f"{var_name}_derv_r"], - model_ret[f"{var_name}_derv_r_mag"], + model_ret["energy_derv_r"], + model_ret["energy_derv_r_mag"], model_ret["mask_mag"], - ) = self.process_spin_output(atype, model_ret[f"{var_name}_derv_r"]) - if self.do_grad_c(var_name) and model_ret.get(f"{var_name}_derv_c") is not None: + ) = self.process_spin_output(atype, model_ret["energy_derv_r"]) + if self.do_grad_c("energy") and model_ret.get("energy_derv_c") is not None: ( - model_ret[f"{var_name}_derv_c"], - model_ret[f"{var_name}_derv_c_mag"], + model_ret["energy_derv_c"], + model_ret["energy_derv_c_mag"], model_ret["mask_mag"], ) = self.process_spin_output( atype, - model_ret[f"{var_name}_derv_c"], + model_ret["energy_derv_c"], add_mag=True, virtual_scale=False, ) @@ -679,24 +661,23 @@ def _split_spin_lower_output( nloc: int, ) -> dict[str, torch.Tensor]: """Split lower-interface SeZM outputs into real and magnetic parts.""" - var_name = self._get_output_var_name() - model_ret[var_name] = torch.split(model_ret[var_name], [nloc, nloc], dim=1)[0] - if self.do_grad_r(var_name) and model_ret.get(f"{var_name}_derv_r") is not None: + model_ret["energy"] = torch.split(model_ret["energy"], [nloc, nloc], dim=1)[0] + if self.do_grad_r("energy") and model_ret.get("energy_derv_r") is not None: ( - model_ret[f"{var_name}_derv_r"], - model_ret[f"{var_name}_derv_r_mag"], + model_ret["energy_derv_r"], + model_ret["energy_derv_r_mag"], model_ret["mask_mag"], ) = self.process_spin_output_lower( - extended_atype, model_ret[f"{var_name}_derv_r"], nloc + extended_atype, model_ret["energy_derv_r"], nloc ) - if self.do_grad_c(var_name) and model_ret.get(f"{var_name}_derv_c") is not None: + if self.do_grad_c("energy") and model_ret.get("energy_derv_c") is not None: ( - model_ret[f"{var_name}_derv_c"], - model_ret[f"{var_name}_derv_c_mag"], + model_ret["energy_derv_c"], + model_ret["energy_derv_c_mag"], model_ret["mask_mag"], ) = self.process_spin_output_lower( extended_atype, - model_ret[f"{var_name}_derv_c"], + model_ret["energy_derv_c"], nloc, add_mag=True, virtual_scale=False, diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index 1909fde36b..fed382f37e 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -35,6 +35,59 @@ make_model, ) +_NATOMS_VEC_KEYS = ("natoms", "real_natoms_vec") +_SPIN_STAT_RESERVED_KEYS = frozenset( + {"coord", "atype", "spin", "natoms", "real_natoms_vec", "aparam"} +) + + +def _expand_natoms_vec_for_virtual_spin(natoms: torch.Tensor) -> torch.Tensor: + """Expand a DeePMD natoms vector for the virtual-atom spin layout. + + The leading two entries count local and extended atoms; they are doubled + to reflect real/virtual atom pairs. Per-type counts are duplicated so that + virtual spin slots inherit the population of their real counterparts. + + Parameters + ---------- + natoms + Natoms vector with shape ``(nframes, ntypes_real + 2)``. + + Returns + ------- + torch.Tensor + Expanded vector with shape ``(nframes, 2 * ntypes_real + 2)``. + """ + return torch.cat( + [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], + dim=-1, + ) + + +def _pack_spin_stat_sample( + spin_model: "SpinModel", + sample: dict[str, Any], +) -> dict[str, Any]: + """Transform one statistics sample into the virtual-atom spin layout.""" + coord_updated, atype_updated, _ = spin_model.process_spin_input( + sample["coord"], sample["atype"], sample["spin"] + ) + packed: dict[str, Any] = { + "coord": coord_updated, + "atype": atype_updated, + } + if "aparam" in sample: + packed["aparam"] = spin_model.expand_aparam( + sample["aparam"], atype_updated.shape[1] + ) + for key in _NATOMS_VEC_KEYS: + if key in sample: + packed[key] = _expand_natoms_vec_for_virtual_spin(sample[key]) + for item_key in sample: + if item_key not in _SPIN_STAT_RESERVED_KEYS: + packed[item_key] = sample[item_key] + return packed + def _lookup_type_values(values: torch.Tensor, atype: torch.Tensor) -> torch.Tensor: """ @@ -44,8 +97,14 @@ def _lookup_type_values(values: torch.Tensor, atype: torch.Tensor) -> torch.Tens that advanced-indexing form to a CUDA ``index.Tensor`` shim even for a CPU ``.pt2`` package. ``index_select`` keeps the exported spin graph device stable while preserving the same lookup semantics. + + Padding ghost slots carry ``atype == -1`` (batched extended regions are + padded to a uniform ``nall``). Unlike advanced indexing, ``index_select`` + rejects negative indices, so the padding entries are clamped to row 0; their + looked-up value is irrelevant because padding atoms carry zero spin and are + dropped from the per-local output downstream. """ - flat_atype = atype.reshape(-1).to(dtype=torch.long) + flat_atype = torch.clamp_min(atype.reshape(-1).to(dtype=torch.long), 0) return torch.index_select(values.to(atype.device), 0, flat_atype).view(atype.shape) @@ -413,26 +472,7 @@ def _get_spin_sampled_func( ) -> Callable[[], list[dict]]: @functools.lru_cache def spin_sampled_func() -> list[dict]: - sampled = sampled_func() - spin_sampled = [] - for sys in sampled: - coord_updated, atype_updated, _ = self.process_spin_input( - sys["coord"], sys["atype"], sys["spin"] - ) - tmp_dict = { - "coord": coord_updated, - "atype": atype_updated, - } - if "natoms" in sys: - natoms = sys["natoms"] - tmp_dict["natoms"] = torch.cat( - [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1 - ) - for item_key in sys.keys(): - if item_key not in ["coord", "atype", "spin", "natoms"]: - tmp_dict[item_key] = sys[item_key] - spin_sampled.append(tmp_dict) - return spin_sampled + return [_pack_spin_stat_sample(self, sys) for sys in sampled_func()] return self.backbone_model.atomic_model._make_wrapped_sampler(spin_sampled_func) @@ -501,36 +541,7 @@ def compute_or_load_stat( @functools.lru_cache def spin_sampled_func() -> list[dict[str, Any]]: - sampled = sampled_func() - spin_sampled = [] - for sys in sampled: - coord_updated, atype_updated, _ = self.process_spin_input( - sys["coord"], sys["atype"], sys["spin"] - ) - tmp_dict = { - "coord": coord_updated, - "atype": atype_updated, - } - if "aparam" in sys: - tmp_dict["aparam"] = self.expand_aparam( - sys["aparam"], atype_updated.shape[1] - ) - if "natoms" in sys: - natoms = sys["natoms"] - tmp_dict["natoms"] = torch.cat( - [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1 - ) - for item_key in sys.keys(): - if item_key not in [ - "coord", - "atype", - "spin", - "natoms", - "aparam", - ]: - tmp_dict[item_key] = sys[item_key] - spin_sampled.append(tmp_dict) - return spin_sampled + return [_pack_spin_stat_sample(self, sys) for sys in sampled_func()] self.backbone_model.compute_or_load_stat( spin_sampled_func, diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index ca3515f650..fb17d762b7 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -9,6 +9,9 @@ get_deriv_name, get_reduce_name, ) +from deepmd.kernels.utils import ( + triton_infer_level, +) from deepmd.pt.utils import ( env, ) @@ -216,7 +219,8 @@ def edge_energy_deriv( nall: int, create_graph: bool, extended_coord_corr: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + spin_leaf: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: """Assemble extended force, virial and atomic virial from edge gradients. The energy depends on coordinates only through the per-edge displacement @@ -261,6 +265,11 @@ def edge_energy_deriv( extended_coord_corr Optional spin virtual-displacement correction with shape ``(nf, nall, 3)``; adds ``force (x) coord_corr`` per extended atom. + spin_leaf + Optional per-atom spin leaf with shape ``(nf, nloc, 3)`` for the native + spin scheme. When provided, the energy is also differentiated with + respect to it in the same backward, so the magnetic force shares the + first-derivative graph used by the force-loss second backward. Returns ------- @@ -271,34 +280,68 @@ def edge_energy_deriv( symmetrically between the two endpoints of each edge. energy_derv_c_redu Reduced global virial with shape ``(nf, 1, 9)``. + energy_derv_r_mag + Magnetic force ``-dE/dspin`` with shape ``(nf, nloc, 1, 3)`` when + ``spin_leaf`` is provided, otherwise ``None``. """ - (g,) = torch.autograd.grad( + grad_inputs = [edge_vec] if spin_leaf is None else [edge_vec, spin_leaf] + grads = torch.autograd.grad( [energy_redu], - [edge_vec], + grad_inputs, grad_outputs=[torch.ones_like(energy_redu)], create_graph=create_graph, retain_graph=True, ) + g = grads[0] # Padded edges carry no energy contribution, so their gradient is zero; # mask defensively before the scatter. g = torch.where(edge_mask.unsqueeze(-1), g, torch.zeros_like(g)) n_ext = nf * nall - # Force: F_k = sum_{dst=k} g_e - sum_{src=k} g_e. - force_flat = torch.zeros(n_ext, 3, dtype=g.dtype, device=g.device) - force_flat = force_flat.index_add(0, dst_ext, g) - force_flat = force_flat.index_add(0, src_ext, -g) - extended_force = force_flat.view(nf, nall, 3) + if triton_infer_level() >= 1 and not create_graph and g.is_cuda: + # Inference: assemble force and per-atom virial with two CSR segment + # reductions instead of four ``index_add`` scatters (which serialize + # on the colliding edges of each atom) and a materialized ``(E, 9)`` + # outer product. The extended indices carry no ordering guarantee, so + # the topology is sorted here; these integer ops trace as ordinary + # aten nodes under ``make_fx``. + from deepmd.kernels.triton.sezm.force_assembly import ( + edge_force_assembly, + ) + + dst_order = torch.argsort(dst_ext) + src_order = torch.argsort(src_ext) + boundaries = torch.arange(n_ext + 1, device=g.device, dtype=dst_ext.dtype) + dst_row_ptr = torch.searchsorted(dst_ext.index_select(0, dst_order), boundaries) + src_row_ptr = torch.searchsorted(src_ext.index_select(0, src_order), boundaries) + force_flat, av_flat = edge_force_assembly( + g.contiguous(), + edge_vec.detach().contiguous(), + dst_order, + dst_row_ptr, + src_order, + src_row_ptr, + ) + extended_force = force_flat.view(nf, nall, 3) + extended_virial = av_flat.view(nf, nall, 9) + else: + # Force: F_k = sum_{dst=k} g_e - sum_{src=k} g_e. + force_flat = torch.zeros(n_ext, 3, dtype=g.dtype, device=g.device) + force_flat = force_flat.index_add(0, dst_ext, g) + force_flat = force_flat.index_add(0, src_ext, -g) + extended_force = force_flat.view(nf, nall, 3) - # Per-edge virial outer product w_e[k, j] = -g_e^k * edge_vec_e^j, flattened - # to 9 with (force component k, coordinate component j) ordering. - w_edge = -torch.einsum("ek,ej->ekj", g, edge_vec).reshape(-1, 9) - # Atomic virial: split each per-edge tensor symmetrically between endpoints. - half_w = 0.5 * w_edge - av_flat = torch.zeros(n_ext, 9, dtype=g.dtype, device=g.device) - av_flat = av_flat.index_add(0, dst_ext, half_w) - av_flat = av_flat.index_add(0, src_ext, half_w) - extended_virial = av_flat.view(nf, nall, 9) + # Per-edge virial outer product w_e[k, j] = -g_e^k * edge_vec_e^j, + # flattened to 9 with (force component k, coordinate component j) + # ordering. + w_edge = -torch.einsum("ek,ej->ekj", g, edge_vec).reshape(-1, 9) + # Atomic virial: split each per-edge tensor symmetrically between + # endpoints. + half_w = 0.5 * w_edge + av_flat = torch.zeros(n_ext, 9, dtype=g.dtype, device=g.device) + av_flat = av_flat.index_add(0, dst_ext, half_w) + av_flat = av_flat.index_add(0, src_ext, half_w) + extended_virial = av_flat.view(nf, nall, 9) if extended_coord_corr is not None: # Spin: the virtual-atom displacement adds force (x) coord_corr per atom. @@ -311,7 +354,14 @@ def edge_energy_deriv( energy_derv_r = extended_force.unsqueeze(-2) energy_derv_c = extended_virial.unsqueeze(-2) energy_derv_c_redu = energy_derv_c.to(env.GLOBAL_PT_ENER_FLOAT_PRECISION).sum(dim=1) - return energy_derv_r, energy_derv_c, energy_derv_c_redu + + # Magnetic force is the negative spin gradient, matching the dataset + # ``force_mag = -dE/dspin`` convention (the virtual-atom scheme reaches the + # same quantity through ``F_virtual * virtual_scale``). + energy_derv_r_mag: torch.Tensor | None = None + if spin_leaf is not None: + energy_derv_r_mag = (-grads[1]).unsqueeze(-2) + return energy_derv_r, energy_derv_c, energy_derv_c_redu, energy_derv_r_mag def communicate_extended_output( diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index df883c8ee6..4b41019201 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -195,6 +195,8 @@ def __init__( infer_env_defaults["DP_COMPILE_INFER"] = "1" if bool(validating_params.get("tf32_infer", False)): infer_env_defaults["DP_TF32_INFER"] = "1" + if bool(validating_params.get("amp_infer", False)): + infer_env_defaults["DP_AMP_INFER"] = "1" self.multi_task = "model_dict" in model_params self.finetune_links = finetune_links self.finetune_update_stat = False @@ -1231,10 +1233,12 @@ def _create_ema_full_validator( validating_params: dict[str, Any], validation_data: DpLoaderSet | None, ) -> FullValidator | None: - """Create the runtime EMA full validator when it is active.""" - if not self._is_validation_requested( - validating_params, "full_validation" - ) or not validating_params.get("ema_full_validation", False): + """Create the runtime EMA full validator when it is active. + + EMA full validation is independent from regular full validation: it + can be enabled on its own to validate only the EMA-smoothed model. + """ + if not self._is_validation_requested(validating_params, "ema_full_validation"): return None self._raise_if_full_validation_unsupported(validation_data) if self.model_ema is None: @@ -1292,18 +1296,10 @@ def _raise_if_full_validation_unsupported( "training; multi-task training is not supported." ) - has_spin = getattr(self.model, "has_spin", False) - if callable(has_spin): - has_spin = has_spin() - if has_spin or isinstance(self.loss, EnergySpinLoss): + if not isinstance(self.loss, (EnergyStdLoss, EnergySpinLoss)): raise ValueError( "validating.full_validation only supports single-task energy " - "training; spin-energy training is not supported." - ) - - if not isinstance(self.loss, EnergyStdLoss): - raise ValueError( - "validating.full_validation only supports single-task energy training." + "or spin-energy training." ) if validation_data is None: @@ -2449,8 +2445,21 @@ def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]: if callable(has_spin): has_spin = has_spin() if has_spin: + # ``model.spin.allow_missing_label`` relaxes the spin label from mandatory to + # optional with a zero default, so a system without a ``spin`` file is filled + # with zeros rather than rejected. The flag is read from the model's spin + # configuration. + allow_missing_spin = getattr( + getattr(_model, "spin", None), "allow_missing_label", False + ) spin_requirement_items = [ - DataRequirementItem("spin", ndof=3, atomic=True, must=True) + DataRequirementItem( + "spin", + ndof=3, + atomic=True, + must=not allow_missing_spin, + default=0.0, + ) ] additional_data_requirement += spin_requirement_items if _model.has_chg_spin_ebd(): diff --git a/deepmd/pt/train/validation.py b/deepmd/pt/train/validation.py index f206df03c3..da3a452cea 100644 --- a/deepmd/pt/train/validation.py +++ b/deepmd/pt/train/validation.py @@ -54,10 +54,9 @@ resolve_full_validation_start_step, ) from deepmd.utils.eval_metrics import ( - FULL_VALIDATION_METRIC_FAMILY_BY_KEY, - FULL_VALIDATION_METRIC_KEY_MAP, - FULL_VALIDATION_WEIGHTED_METRIC_KEYS, - compute_energy_type_metrics, + ENERGY_FULL_VALIDATION_PROFILE, + SPIN_FULL_VALIDATION_PROFILE, + FullValidationMetricProfile, ) from deepmd.utils.weight_avg import ( weighted_average, @@ -75,15 +74,6 @@ DeepmdData, ) -LOG_COLUMN_ORDER = [ - ("E_MAE", "mae_e_per_atom"), - ("E_RMSE", "rmse_e_per_atom"), - ("F_MAE", "mae_f"), - ("F_RMSE", "rmse_f"), - ("V_MAE", "mae_v_per_atom"), - ("V_RMSE", "rmse_v_per_atom"), -] - TOPK_RECORDS_INFO_KEY = "full_validation_topk_records" BEST_METRIC_NAME_INFO_KEY = "full_validation_metric" STALE_FULL_VALIDATION_INFO_KEYS = ( @@ -97,11 +87,6 @@ VAL_LOG_COLUMN_GAP = " " VAL_LOG_HEADER_PREFIX = "# " VAL_LOG_DATA_PREFIX = " " -METRIC_LOG_UNIT_MAP = { - "e": ("meV/atom", 1000.0), - "f": ("meV/Å", 1000.0), - "v": ("meV/atom", 1000.0), -} @dataclass(frozen=True) @@ -140,38 +125,46 @@ def build_best_checkpoint_pattern( ) -def parse_validation_metric(metric: str) -> tuple[str, str]: - """Parse the configured full validation metric.""" +def select_metric_profile(model: torch.nn.Module) -> FullValidationMetricProfile: + """Select the metric profile for a model based on its spin capability.""" + has_spin = getattr(model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + return SPIN_FULL_VALIDATION_PROFILE if has_spin else ENERGY_FULL_VALIDATION_PROFILE + + +def parse_validation_metric( + metric: str, profile: FullValidationMetricProfile +) -> tuple[str, str]: + """Parse the configured full validation metric against a profile.""" normalized_metric = normalize_full_validation_metric(metric) - if normalized_metric not in FULL_VALIDATION_METRIC_KEY_MAP: - supported_metrics = ", ".join( - item.upper() for item in FULL_VALIDATION_METRIC_KEY_MAP - ) + if normalized_metric not in profile.metric_key_map: + supported_metrics = ", ".join(item.upper() for item in profile.metric_key_map) raise ValueError( "validating.validation_metric must be one of " f"{supported_metrics}, got {metric!r}." ) - return normalized_metric, FULL_VALIDATION_METRIC_KEY_MAP[normalized_metric] + return normalized_metric, profile.metric_key_map[normalized_metric] def format_metric_for_log( - metric_name: str, metric_value: float + metric_name: str, metric_value: float, profile: FullValidationMetricProfile ) -> tuple[str, float, str]: """Format a full validation metric for user-facing logging.""" metric_family, metric_kind = metric_name.split(":") - metric_unit, metric_scale = METRIC_LOG_UNIT_MAP[metric_family] + metric_unit, metric_scale = profile.unit_by_family[metric_family] metric_label = f"{metric_family.upper()}:{metric_kind.upper()}" return metric_label, metric_value * metric_scale, metric_unit def format_metric_value_for_table( - metric_key: str, metric_value: float + metric_key: str, metric_value: float, profile: FullValidationMetricProfile ) -> tuple[float, str]: """Format one table metric value and its unit for `val.log`.""" - metric_family = FULL_VALIDATION_METRIC_FAMILY_BY_KEY.get(metric_key) + metric_family = profile.metric_family_by_key.get(metric_key) if metric_family is None: raise ValueError(f"Unknown full validation metric key: {metric_key}") - metric_unit, metric_scale = METRIC_LOG_UNIT_MAP[metric_family] + metric_unit, metric_scale = profile.unit_by_family[metric_family] return metric_value * metric_scale, metric_unit @@ -225,6 +218,7 @@ def __init__( ) -> None: self.validation_data = validation_data self.model = model + self.profile = select_metric_profile(model) self.state_store = state_store self.rank = rank self.zero_stage = zero_stage @@ -251,7 +245,7 @@ def __init__( self.save_best = bool(validating_params.get("save_best", True)) self.max_best_ckpt = int(validating_params.get("max_best_ckpt", 1)) self.metric_name, self.metric_key = parse_validation_metric( - str(validating_params.get("validation_metric", "E:MAE")) + str(validating_params.get("validation_metric", "E:MAE")), self.profile ) resolved_log_file = ( full_val_file @@ -275,8 +269,10 @@ def __init__( ) self.auto_batch_size = AutoBatchSize(silent=True) self.table_column_specs = [] - for column_name, metric_key in LOG_COLUMN_ORDER: - _, metric_unit = format_metric_value_for_table(metric_key, 1.0) + for column_name, metric_key in self.profile.column_order: + _, metric_unit = format_metric_value_for_table( + metric_key, 1.0, self.profile + ) header_label = f"{column_name}({metric_unit})" self.table_column_specs.append( (metric_key, header_label, max(len(header_label), 18)) @@ -422,7 +418,7 @@ def evaluate_all_systems(self) -> dict[str, float]: aggregated = weighted_average([metric for metric in system_metrics if metric]) return { metric_key: float(aggregated[metric_key]) - for _, metric_key in LOG_COLUMN_ORDER + for _, metric_key in self.profile.column_order if metric_key in aggregated } @@ -510,7 +506,14 @@ def _evaluate_system( test_data = data_system.get_test() natoms = int(test_data["type"].shape[1]) nframes = int(test_data["coord"].shape[0]) - include_virial = data_system.pbc and bool(test_data.get("find_virial", 0.0)) + include_virial = ( + not self.profile.needs_spin + and data_system.pbc + and bool(test_data.get("find_virial", 0.0)) + ) + spin = ( + test_data["spin"].reshape(nframes, -1) if self.profile.needs_spin else None + ) prediction = self._predict_outputs( coord=test_data["coord"].reshape(nframes, -1), atom_types=test_data["type"], @@ -520,18 +523,13 @@ def _evaluate_system( and bool(test_data.get("find_fparam", 0.0)) else None, aparam=test_data["aparam"] if self.model.get_dim_aparam() > 0 else None, + spin=spin, include_virial=include_virial, natoms=natoms, nframes=nframes, ) - shared_metrics = compute_energy_type_metrics( - prediction=prediction, - test_data=test_data, - natoms=natoms, - has_pbc=data_system.pbc, - ) - return shared_metrics.as_weighted_average_errors( - FULL_VALIDATION_WEIGHTED_METRIC_KEYS + return self.profile.compute_system_metrics( + prediction, test_data, natoms, data_system.pbc ) def _predict_outputs( @@ -542,11 +540,17 @@ def _predict_outputs( box: np.ndarray | None, fparam: np.ndarray | None, aparam: np.ndarray | None, + spin: np.ndarray | None, include_virial: bool, natoms: int, nframes: int, ) -> dict[str, np.ndarray]: - """Predict energy, force, and virial for the full validation batch.""" + """Predict energy and forces for the full validation batch. + + Energy and real-atom force are always produced. The virial is added + for periodic energy-type systems, while magnetic force and its atom + mask are added for spin systems. + """ def predict_batch( coord_batch: np.ndarray, @@ -554,6 +558,7 @@ def predict_batch( box_batch: np.ndarray | None, fparam_batch: np.ndarray | None, aparam_batch: np.ndarray | None, + spin_batch: np.ndarray | None, ) -> dict[str, np.ndarray]: coord_input = torch.tensor( coord_batch.reshape(-1, natoms, 3).astype( @@ -593,6 +598,20 @@ def predict_batch( ) else: aparam_input = None + if spin_batch is not None: + spin_kwargs = { + "spin": torch.tensor( + spin_batch.reshape(-1, natoms, 3).astype( + NP_PRECISION_DICT[ + RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION] + ] + ), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + } + else: + spin_kwargs = {} # Do not use `torch.no_grad()` here: force/virial predictions rely on # autograd inside the model even during evaluation. @@ -602,6 +621,7 @@ def predict_batch( box=box_input, fparam=fparam_input, aparam=aparam_input, + **spin_kwargs, ) if isinstance(batch_output, tuple): batch_output = batch_output[0] @@ -614,6 +634,17 @@ def predict_batch( .numpy() .reshape(-1, natoms * 3), } + if spin_batch is not None: + prediction["force_mag"] = ( + batch_output["force_mag"] + .detach() + .cpu() + .numpy() + .reshape(-1, natoms * 3) + ) + prediction["mask_mag"] = ( + batch_output["mask_mag"].detach().cpu().numpy().reshape(-1, natoms) + ) if include_virial: if "virial" not in batch_output: raise KeyError( @@ -634,11 +665,15 @@ def predict_batch( box, fparam, aparam, + spin, ) prediction = { "energy": batch_prediction["energy"], "force": batch_prediction["force"], } + if spin is not None: + prediction["force_mag"] = batch_prediction["force_mag"] + prediction["mask_mag"] = batch_prediction["mask_mag"] if include_virial: prediction["virial"] = batch_prediction["virial"] return prediction @@ -812,7 +847,7 @@ def _log_result(self, result: FullValidationResult | None) -> None: self._write_log_file(result) if self.emit_best_save_log and result.saved_best_path is not None: metric_label, metric_value, metric_unit = format_metric_for_log( - self.metric_name, result.selected_metric_value + self.metric_name, result.selected_metric_value, self.profile ) log.info( f"Saved best model to {result.saved_best_path} " @@ -828,10 +863,7 @@ def _write_log_file(self, result: FullValidationResult) -> None: for _, header_label, column_width in self.table_column_specs: header += VAL_LOG_COLUMN_GAP + f"{header_label:^{column_width}s}" header += "\n" - header += ( - "# E uses per-atom energy, F uses component-wise force errors, " - "and V uses virial normalized by natoms.\n" - ) + header += self.profile.log_header_note fout.write(header) self._should_write_header = False self._write_mode = "a" @@ -844,7 +876,7 @@ def _write_log_file(self, result: FullValidationResult) -> None: metric_value = result.metrics.get(metric_key, float("nan")) if not np.isnan(metric_value): metric_value, _ = format_metric_value_for_table( - metric_key, metric_value + metric_key, metric_value, self.profile ) metric_text = format_metric_number_for_log(metric_value) line += VAL_LOG_COLUMN_GAP + f"{metric_text:^{column_width}s}" @@ -852,7 +884,7 @@ def _write_log_file(self, result: FullValidationResult) -> None: fout.write(line) if result.saved_best_path is not None: metric_label, metric_value, metric_unit = format_metric_for_log( - self.metric_name, result.selected_metric_value + self.metric_name, result.selected_metric_value, self.profile ) fout.write( "# saved best checkpoint: " diff --git a/deepmd/pt/utils/compile_compat.py b/deepmd/pt/utils/compile_compat.py index a5dea86e26..0e64f6b040 100644 --- a/deepmd/pt/utils/compile_compat.py +++ b/deepmd/pt/utils/compile_compat.py @@ -45,6 +45,7 @@ "patch_inductor_force_int64_indexing", "patch_inductor_symbolic_divisibility", "rebuild_graph_module", + "relax_views_to_reshapes", "strip_saved_tensor_detach", "trace_pad_dim", ] @@ -300,7 +301,37 @@ def rebuild_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: return new_gm -def build_inductor_compile_options() -> dict[str, Any]: +def relax_views_to_reshapes(gm: torch.fx.GraphModule) -> None: + """Rewrite every ``aten.view`` in a ``make_fx`` graph to ``aten.reshape``. + + ``make_fx`` lowers ``Tensor.reshape`` to ``aten.view`` whenever the traced + ``FakeTensor`` is view-compatible. The lowering is unsound when the fake + stride differs from the eager stride -- a permuted tensor that ``FakeTensor`` + keeps strided while eager materializes contiguous -- since the baked + ``aten.view`` is accepted during tracing yet rejected at runtime for + incompatible size and stride. ``aten.reshape`` coincides with ``aten.view`` + on view-compatible strides (and is elided by Inductor in that case) and + copies only when a view is impossible; the rewrite is therefore + semantics-preserving and free on the fast path. + + Parameters + ---------- + gm : torch.fx.GraphModule + The ``make_fx`` graph to rewrite in place. + """ + view = torch.ops.aten.view.default + reshape = torch.ops.aten.reshape.default + relaxed = False + for node in gm.graph.nodes: + if node.op == "call_function" and node.target is view: + node.target = reshape + relaxed = True + if relaxed: + gm.graph.lint() + gm.recompile() + + +def build_inductor_compile_options(*, inference: bool = False) -> dict[str, Any]: """Return the conservative Inductor options used to lower the dynamic graph. The option set disables every Inductor and Triton feature that has @@ -310,6 +341,22 @@ def build_inductor_compile_options() -> dict[str, Any]: some GPU/Triton combinations. Options absent from the running PyTorch's configuration registry are dropped so the returned dictionary stays valid across releases. + + Parameters + ---------- + inference : bool + Whether the options lower an inference graph (the ``make_fx`` + + ``aot_module_simplified`` path and the AOTInductor freeze) rather + than the ``torch.compile`` training graph. Inference graphs enter + Inductor with hint-less data-dependent symbols, which breaks the + peak-memory reordering pass (see below); training graphs carry real + size hints from the first traced call and benefit from the pass. + + Returns + ------- + dict[str, Any] + Keyword options accepted by ``torch.compile(options=...)`` and by + ``torch._inductor.config.patch``. """ compile_options: dict[str, Any] = { "max_autotune": False, @@ -330,6 +377,19 @@ def build_inductor_compile_options() -> dict[str, Any]: # The option is shared by the training and evaluation graphs. "triton.max_tiles": 1, } + if inference: + # The peak-memory reordering pass sizes buffers through + # ``sizevars.size_hint(numel, fallback=0)``. The inference graph is + # lowered from ``make_fx`` fake placeholders whose edge-count symbols + # carry no hint, so every dynamically shaped buffer is costed as zero + # bytes, the candidate orders become indistinguishable to the cost + # model, and the pass rewrites the schedule into an order that hoists + # the dynamic allocations to the head of the generated ``call()`` -- + # all forward/backward intermediates then coexist, more than doubling + # peak memory on the SeZM inference graph. Training compiles through + # Dynamo with real hints from the first call and measurably benefits + # from the pass, so it keeps the upstream default. + compile_options["reorder_for_peak_memory"] = False try: from torch._inductor import config as inductor_config diff --git a/deepmd/pt/utils/serialization.py b/deepmd/pt/utils/serialization.py index 82274796e8..db23eef4dc 100644 --- a/deepmd/pt/utils/serialization.py +++ b/deepmd/pt/utils/serialization.py @@ -85,6 +85,12 @@ def deserialize_to_file(model_file: str, data: dict) -> None: ) model = SeZMSpinModel.deserialize(model_data) + elif model_data.get("type") == "sezm_native_spin": + from deepmd.pt.model.model.sezm_native_spin_model import ( + SeZMNativeSpinModel, + ) + + model = SeZMNativeSpinModel.deserialize(model_data) else: model = BaseModel.deserialize(model_data) # JIT will happy in this way... diff --git a/deepmd/pt_expt/descriptor/dpa4.py b/deepmd/pt_expt/descriptor/dpa4.py index baca905d64..65a4b94efd 100644 --- a/deepmd/pt_expt/descriptor/dpa4.py +++ b/deepmd/pt_expt/descriptor/dpa4.py @@ -11,8 +11,8 @@ from deepmd.dpmodel.descriptor.dpa4_nn.radial import ( C3CutoffEnvelope as C3CutoffEnvelopeDP, ) -from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( - WignerDCalculator as WignerDCalculatorDP, +from deepmd.kernels.utils import ( + use_amp_infer, ) from deepmd.pt_expt.common import ( register_dpmodel_mapping, @@ -26,20 +26,6 @@ ) -@torch_module -class WignerDCalculator(WignerDCalculatorDP): - def forward(self, *args: Any, **kwargs: Any) -> Any: - return self.call(*args, **kwargs) - - -# WignerDCalculator.deserialize raises NotImplementedError by design (its -# tables are derived constants); rebuild from the stored constructor args. -register_dpmodel_mapping( - WignerDCalculatorDP, - lambda v: WignerDCalculator(v.lmax, eps=v.eps, precision=v.precision), -) - - @torch_module class SwiGLU(SwiGLUDP): def forward(self, *args: Any, **kwargs: Any) -> Any: @@ -117,6 +103,12 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: ), # dpa4_nn.embedding "SeZMTypeEmbedding": ("adam_type_embedding",), + # dpa4_nn.embedding (native spin): these are nn.Parameter in pt but land as + # numpy->buffer in dpmodel; mag_layer1/2 are NativeLayer and auto-promote, + # and _promote_trainable skips a missing buffer, so no-spin configs (where + # spin_scale is absent) stay safe. + "SpinEmbedding": ("adam_spin_vec_weight", "adam_spin_nbr_weight"), + "EnvironmentInitialEmbedding": ("spin_scale",), # dpa4_nn.attn_res "DepthAttnRes": ("adamw_pseudo_query",), # dpa4_nn.grid_net (residual_scale is None when disabled; _promote_trainable @@ -172,6 +164,7 @@ class DescrptDPA4(DescrptDPA4DP): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) + self.use_amp_infer = use_amp_infer() _promote_trainable_tree(self) @classmethod @@ -195,10 +188,16 @@ def _forward_blocks(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Any: config flag and never autocasts (array-API has no autocast), so the real automatic mixed precision lives here. ``x`` is the node-feature tensor entering the blocks; its device equals the working device, so - autocast engages only when ``self.use_amp`` is set, the module is in - training mode, and the inputs live on a CUDA device. + autocast engages when ``self.use_amp`` is set, the inputs live on a + CUDA device, and either the module is training or eval-time AMP was + opted in through ``DP_AMP_INFER`` (captured once at construction as + ``self.use_amp_infer``). """ - if self.use_amp and self.training and x.device.type == "cuda": + if ( + self.use_amp + and x.device.type == "cuda" + and (self.training or self.use_amp_infer) + ): with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): return super()._forward_blocks(x, *args, **kwargs) return super()._forward_blocks(x, *args, **kwargs) diff --git a/deepmd/pt_expt/descriptor/dpa4_nn/__init__.py b/deepmd/pt_expt/descriptor/dpa4_nn/__init__.py index 4b649efaae..4e3585b520 100644 --- a/deepmd/pt_expt/descriptor/dpa4_nn/__init__.py +++ b/deepmd/pt_expt/descriptor/dpa4_nn/__init__.py @@ -11,6 +11,8 @@ weights are trainable parameters (the dpmodel list mixes modules with a bare activation function, which the generic conversion cannot turn into a ``ModuleList``). +- :mod:`wignerd` -- opt-in fused Triton monomial fast path for the Wigner-D + ``l = 2`` contraction and the shared ``l >= 3`` monomial kernels. Importing this package registers the dpmodel -> pt_expt converters (via ``torch_module``), so the auto-wrapped descriptor tree picks up these subclasses @@ -21,4 +23,5 @@ block, radial, so2, + wignerd, ) diff --git a/deepmd/pt_expt/descriptor/dpa4_nn/so2.py b/deepmd/pt_expt/descriptor/dpa4_nn/so2.py index ae9ffc43aa..19e6b6aec9 100644 --- a/deepmd/pt_expt/descriptor/dpa4_nn/so2.py +++ b/deepmd/pt_expt/descriptor/dpa4_nn/so2.py @@ -1,57 +1,94 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""pt_expt SO(2) convolution and radial mixer with opt-in fused Triton kernels. +"""pt_expt SO(2) linear, convolution, and radial mixer with opt-in fused Triton kernels. The dpmodel SO(2) modules are array-API only. These wrappers inject the -reference pt opt-in Triton inference path (``DP_TRITON_INFER``) around the two -rotation hot paths of the SO(2) convolution and the low-rank branch of the -dynamic radial degree mixer, mirroring -``deepmd.pt.model.descriptor.sezm_nn.so2``. The kernels run only during -inference (``not self.training``); training and CPU / fp64 inference fall back to -the dpmodel dense path. +reference pt opt-in Triton inference path around three hot paths, mirroring +``deepmd.pt.model.descriptor.sezm_nn.so2``: + +- the block-diagonal GEMM of :class:`SO2Linear`, +- the two rotation hot paths of :class:`SO2Convolution`, and +- the low-rank branch of :class:`DynamicRadialDegreeMixer`. + +The kernels are sourced from the central :mod:`deepmd.kernels.triton.sezm` +package and gated by the integer inference level ``DP_TRITON_INFER`` (see +:func:`deepmd.kernels.utils.triton_infer_level`); every kernel path requires +level ``>= 1``. The kernels run only during inference (``not self.training``), +and each kernel self-guards Triton availability and falls back to an eager +reference off CUDA / on fp64, so importing this module is safe on CPU-only +environments; training and CPU / fp64 inference use the dpmodel dense path. """ from __future__ import ( annotations, ) -import os from typing import ( TYPE_CHECKING, Any, ) +import torch + +from deepmd.dpmodel.common import ( + get_xp_precision, +) from deepmd.dpmodel.descriptor.dpa4_nn.so2 import ( DynamicRadialDegreeMixer as DynamicRadialDegreeMixerDP, ) from deepmd.dpmodel.descriptor.dpa4_nn.so2 import SO2Convolution as SO2ConvolutionDP +from deepmd.dpmodel.descriptor.dpa4_nn.so2 import SO2Linear as SO2LinearDP +from deepmd.kernels.utils import ( + triton_infer_level, + use_cute_infer, +) from deepmd.pt_expt.common import ( torch_module, ) if TYPE_CHECKING: - import torch - from deepmd.dpmodel.descriptor.dpa4_nn.edge_cache import ( EdgeFeatureCache, ) -_TRITON_INFER_TRUE = ("1", "true", "yes", "on") +@torch_module +class SO2Linear(SO2LinearDP): + """SO(2)-equivariant linear with an opt-in fused block-diagonal Triton GEMM.""" -def use_triton_infer() -> bool: - """Return whether the opt-in Triton inference kernels are enabled. + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + # Inference fast path (``DP_TRITON_INFER >= 1``): the per-|m|-block + # batched bmm + cat of ``_block_diagonal_matmul`` is replaced by a fused + # Triton BN=64 block-diagonal GEMM that consumes the strided operands + # without a contiguity copy. Bound only when Triton is available and every + # block width aligns to BN=64; otherwise the eager path is kept. The gate + # is read once at construction so it is a compile-time constant in the + # traced (``make_fx``) graph, and it only takes effect during inference. + self._block_diag_gemm = None + if triton_infer_level() >= 1: + from deepmd.kernels.triton.sezm.so2_block_gemm import ( + SO2_BLOCK_GEMM_TRITON_AVAILABLE, + block_diag_gemm, + slices_supported, + ) - The flag is controlled by the ``DP_TRITON_INFER`` environment variable and - is read at module construction time so that it becomes a compile-time - constant in the traced (``make_fx``) graph. It only takes effect during - inference; training always uses the dense reference path. + if SO2_BLOCK_GEMM_TRITON_AVAILABLE and slices_supported( + self._block_diag_slices + ): + self._block_diag_gemm = block_diag_gemm - Returns - ------- - bool - ``True`` when ``DP_TRITON_INFER`` is set to a truthy value. - """ - return os.environ.get("DP_TRITON_INFER", "0").strip().lower() in _TRITON_INFER_TRUE + def _block_diagonal_matmul( + self, x_flat: torch.Tensor, weight: torch.Tensor + ) -> torch.Tensor: + if self._block_diag_gemm is not None and not self.training: + # The fused GEMM consumes the ``(F, D_m*Cin, D_m*Cout)`` presentation + # directly from the strided weight, so the permute is applied here and + # the contiguity copy the dpmodel ``bmm`` cat path would need is + # skipped. The eager fallback permutes ``weight`` internally, so it is + # passed the stored ``(D_m*Cin, F, D_m*Cout)`` layout untouched. + weight = weight.permute(1, 0, 2) # (F, D_m*Cin, D_m*Cout) + return self._block_diag_gemm(x_flat, weight, self._block_diag_slices) + return super()._block_diagonal_matmul(x_flat, weight) @torch_module @@ -60,10 +97,12 @@ class DynamicRadialDegreeMixer(DynamicRadialDegreeMixerDP): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - # Inference fast path (opt-in via ``DP_TRITON_INFER``): a fused Triton - # kernel replaces the dense scatter and the tiny batched matmul of the - # ``degree_channel`` low-rank branch in the ``mmax == 1`` layout. - self.use_triton_infer = use_triton_infer() + # Inference fast path (``DP_TRITON_INFER >= 1``): a fused Triton kernel + # replaces the dense scatter and the tiny batched matmul of the + # ``degree_channel`` low-rank branch in the ``mmax == 1`` layout. The gate + # is read once at construction so it is a compile-time constant in the + # traced (``make_fx``) graph, and it only takes effect during inference. + self.use_triton_infer = triton_infer_level() >= 1 self._radial_mix_block = None if ( self.use_triton_infer @@ -71,7 +110,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: and self.rank > 0 and self.mmax == 1 ): - from .triton.radial_mix import ( + from deepmd.kernels.triton.sezm.radial_mix import ( radial_mix_block, ) @@ -96,13 +135,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # ``use_triton_infer`` is read once at construction so it is a # compile-time constant in the traced (``make_fx``) graph, and it only # takes effect during inference. - self.use_triton_infer = use_triton_infer() + self.use_triton_infer = triton_infer_level() >= 1 # === Triton rotation kernels: block for mmax == 1, dense otherwise === self._rotate_to_local_fn = None self._rotate_back_fn = None if self.use_triton_infer: - from .triton.so2_rotation import ( + from deepmd.kernels.triton.sezm.so2_rotation import ( rotate_back_block_so2, rotate_back_dense, rotate_to_local_block, @@ -127,6 +166,73 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: x_local, wigner, self.coeff_index_m, self.ebed_dim_full ) + # === Step 12. Optional fused flash-attention aggregation kernel === + # Folds the entire ``n_atten_head > 0`` value aggregation -- block-diagonal + # rotate-back, inverse-rotation rescale, envelope-gated softmax weighting, + # and the destination scatter -- into a single destination-segmented + # Triton kernel, removing the transient ``x_message`` and weighted-value + # edge tensors and the ``index_add`` round trip. It shares the + # ``DP_TRITON_INFER`` gate with the other SeZM inference kernels and only + # engages for the supported ``mmax == 1`` attention layout without the + # optional focus-mix / value / output projections (the deployed DPA4 + # configuration); the op itself dispatches to an eager reference off the + # CUDA fp32 path. The output-side head gate stays a cheap node-level + # elementwise applied after the kernel. The supported-layout half of the + # predicate is the dpmodel base's ``_flash_atten_layout_ok`` (the base + # leaves ``use_flash_atten=False`` and the hooks ``None``); this re-enables + # flash by ANDing that layout predicate with the Triton-availability gate. + self.use_flash_atten = self.use_triton_infer and self._flash_atten_layout_ok + if self.use_flash_atten: + from deepmd.kernels.triton.sezm.flash_atten import ( + build_row_ptr, + flash_atten_aggregate, + ) + + self._flash_atten_fn = flash_atten_aggregate + self._build_row_ptr_fn = build_row_ptr + + # The rotate/flash gate above exposes only the boolean ``use_triton_infer``; + # the fused value-path operator additionally reads the raw integer level + # (it selects the level-3 fp16x3 mixing stack from ``self.triton_infer_level``), + # so the level is stored here as well. ``DP_TRITON_INFER`` and + # ``DP_CUTE_INFER`` both claim the single ``so2_message`` value path, so + # enabling them together has no coherent meaning and is rejected here. + self.triton_infer_level = triton_infer_level() + if self.triton_infer_level >= 1 and use_cute_infer(): + raise ValueError( + "DP_TRITON_INFER and DP_CUTE_INFER are mutually exclusive: both " + "select the fused SO(2) value-path backend. Enable exactly one " + "of them." + ) + + # === Step 13. Optional fused Triton SO(2) value-path operators === + # Fuses rotate-to-local, the radial degree mixing, the gated mixing + # stack, and the focus competition of ``so2_message`` into the + # ``sezm_triton::so2_rotate_mix`` / ``so2_mixing_stack`` operators. + # The factory validates the block layout (``mmax == 1``, gated stack + # with an identity final layer, supported focus widths) and returns + # ``None`` otherwise, leaving the reference path in charge. The value + # path resolves its launch configurations from the swept tables, so + # it engages at ``DP_TRITON_INFER >= 2``; at level 3 the factory + # additionally routes the mixing stack through the fp16x3 tensor-core + # operator on shapes whose configuration passed the fp64 validation + # sweep. + if self.triton_infer_level >= 2: + from deepmd.kernels.triton.sezm.so2_value_path import ( + make_triton_value_path, + ) + + self._value_path = make_triton_value_path(self) + # === Step 14. Optional fused CuTe SO(2) value-path operator === + # Experimental alternative backend; mutually exclusive with the Triton + # flag (enforced above). + elif use_cute_infer(): + from deepmd.kernels.cute.sezm import ( + make_cute_value_path, + ) + + self._value_path = make_cute_value_path(self) + def _rotate_to_local( self, x: torch.Tensor, edge_cache: EdgeFeatureCache ) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -158,3 +264,61 @@ def _rotate_back( ) return self._rotate_back_fn(x_std, Dt_full) return super()._rotate_back(x_local, edge_cache, n_edge) + + def _flash_aggregate( + self, + x_local_flash: torch.Tensor, + edge_cache: EdgeFeatureCache, + attn_alpha: torch.Tensor, + x_l0_node: torch.Tensor, + n_node: int, + compute_dtype: Any, + ) -> torch.Tensor: + # === Step 4.3f. Fused rotate-back + envelope-softmax-weighted + # segment scatter. One destination-segmented Triton kernel + # folds the block-diagonal rotate-back, the inverse-rotation + # rescale, the per-edge ``attn_alpha`` weighting, and the + # destination reduction into a single atomic-free pass, + # returning the ungated aggregate ``(N, D, C_wide)``. The + # transient rotate-back message and weighted value tensors are + # never materialized. + row_ptr = self._build_row_ptr_fn(edge_cache.dst, n_node) + pre_gate = self._flash_atten_fn( + x_local_flash, + edge_cache.Dt_full, + self.rotate_inv_rescale_full, + attn_alpha, + row_ptr, + edge_cache.dst, + self.lmax, + self.n_atten_head, + ) # (N, D, C_wide) + + # === Step 4.4f. Output-side head gate (cheap node-level) === + attn_output_gate = torch.sigmoid( + torch.einsum( + "nfi,ifo->nfo", + self.attn_output_gate_norm(x_l0_node.to(dtype=compute_dtype)), + self.adamw_attn_gate_w, + ) + ) # (N, Fa, H) + # Broadcast the per-(focus, head) gate over the head channels + # to the packed hidden width ``c = f * Cf + h * head_dim + ch``. + gate_full = ( + attn_output_gate.reshape(n_node, self.attn_n_focus, self.n_atten_head, 1) + .expand( + n_node, + self.attn_n_focus, + self.n_atten_head, + self.head_dim, + ) + .reshape(n_node, self.hidden_channels) + ) # (N, C_wide) + # dpmodel exposes the output precision as the string ``self.precision`` (the + # wrapped conv has no ``self.dtype``); ``get_xp_precision`` resolves it to + # the torch dtype the dpmodel dense branch casts to, so the fused and dense + # aggregates share the same storage precision. + out = (pre_gate * gate_full.unsqueeze(1)).to( + dtype=get_xp_precision(torch, self.precision) + ) + return out # (N, D, C_wide) diff --git a/deepmd/pt_expt/descriptor/dpa4_nn/triton/__init__.py b/deepmd/pt_expt/descriptor/dpa4_nn/triton/__init__.py deleted file mode 100644 index 3cc27f40d4..0000000000 --- a/deepmd/pt_expt/descriptor/dpa4_nn/triton/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -"""Hardware-accelerated SeZM/DPA4 operators. - -This package hosts ``make_fx``-composable Triton implementations of SeZM hot -paths. Kernel entry points are internal implementation details of the SeZM -descriptor; the package-level API only exposes availability. -""" - -from .radial_mix import ( - RADIAL_MIX_TRITON_AVAILABLE, -) -from .so2_rotation import ( - TRITON_ROTATION_AVAILABLE, -) - -# Both kernel modules guard their ``@triton.jit`` definitions behind a ``triton`` -# import, so the two module-level checks are equivalent. Expose a single -# package-level availability flag. -TRITON_AVAILABLE = TRITON_ROTATION_AVAILABLE and RADIAL_MIX_TRITON_AVAILABLE - -__all__ = [ - "TRITON_AVAILABLE", -] diff --git a/deepmd/pt_expt/descriptor/dpa4_nn/triton/radial_mix.py b/deepmd/pt_expt/descriptor/dpa4_nn/triton/radial_mix.py deleted file mode 100644 index c563887834..0000000000 --- a/deepmd/pt_expt/descriptor/dpa4_nn/triton/radial_mix.py +++ /dev/null @@ -1,836 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -# pyright: reportMissingImports=false -# ruff: noqa: ANN001, ANN202 -"""Fused Triton dynamic radial degree mixer for the SeZM/DPA4 descriptor. - -This module provides a clean-room Triton implementation of the -``degree_channel`` branch of :class:`DynamicRadialDegreeMixer` for the -``mmax == 1`` reduced layout. The eager reference applies, per edge ``e`` and -output coefficient ``o``:: - - out[e, o, c] = sum_r channel_basis[r, c] * sum_i K_r[e, o, i] * x[e, i, c] - -where ``K_r`` is the edge-conditioned degree kernel obtained by scattering the -projected radial features ``compact`` into a ``(reduced_dim, reduced_dim)`` -matrix. ``K_r`` is block-diagonal over the ``|m|`` groups, so for -``mmax == 1`` only a ``(lmax+1) x (lmax+1)`` block (orders ``m = 0``) and two -identical ``lmax x lmax`` blocks (orders ``m = -1`` and ``m = +1``) are -non-zero. - -Design goals ------------- -1. **Skip the structural zeros and the dense scratch.** The eager path - materializes the dense kernel ``(E, reduced_dim, reduced_dim, rank)`` via a - scatter and then contracts it with a batched ``einsum``/``bmm`` whose matrices - are tiny (``reduced_dim <= 16``), which is inefficient on cuBLAS and wastes - roughly two thirds of the multiply-adds on off-block zeros. The kernel - instead reads ``compact`` directly and contracts only the structural - non-zeros, with the channel axis vectorized and one program per edge. -2. **Match eager fp32 accuracy.** Accumulation is in fp32, matching the smooth - potential-energy surface contract used throughout the SeZM descriptor. -3. **Compose with the SeZM ``make_fx`` lowering *and* the AOTInductor freeze.** - The forward and backward are functional ``torch.library.triton_op`` instances - (``mutates_args=()``) with registered fake kernels and an autograd formula, so - ``make_fx(tracing_mode="symbolic")`` captures the energy path together with - the force autograd graph used by inference. ``triton_op`` + ``wrap_triton`` - (vs ``custom_op``) lets Inductor see through to the Triton kernel and bake the - cubin into the AOTInductor ``.pt2``, so the frozen package runs the fused - mixer inside the LAMMPS C++ runtime without any Python op registration. - -Inference-only contract ------------------------ -The operator is opt-in through ``DP_TRITON_INFER`` and is only used in -evaluation, where the force is obtained from ``autograd.grad(energy, coord)``. -The backward therefore returns gradients with respect to ``compact`` and -``x_local`` (both of which carry a path to the coordinates) and ``None`` for -``channel_basis``, which is a parameter and never differentiated by the force -computation. -""" - -from __future__ import ( - annotations, -) - -import torch -from torch import ( - Tensor, -) -from torch.library import ( - wrap_triton, -) - -__all__ = [ - "RADIAL_MIX_TRITON_AVAILABLE", - "radial_mix_block", - "radial_mix_reference", -] - -try: - import triton - import triton.language as tl - - RADIAL_MIX_TRITON_AVAILABLE = True -except ImportError: # pragma: no cover - exercised only without triton - RADIAL_MIX_TRITON_AVAILABLE = False - - -# ====================================================================== -# Eager reference / fallback implementation -# ====================================================================== -def _block_layout(lmax: int) -> list[tuple[int, int, int]]: - """Return ``(coeff_start, compact_start, num_l)`` for the ``mmax == 1`` blocks. - - The reduced m-major layout keeps, for each degree ``l``, the orders - ``m = 0`` (the leading ``lmax + 1`` coefficients) followed by ``m = -1`` and - ``m = +1`` (``lmax`` coefficients each). The degree kernel for the two - signed-``m`` blocks is shared, hence the identical ``compact_start``. - """ - num_l0 = lmax + 1 - return [ - (0, 0, num_l0), - (num_l0, num_l0 * num_l0, lmax), - (num_l0 + lmax, num_l0 * num_l0, lmax), - ] - - -def radial_mix_reference( - compact: Tensor, x_local: Tensor, channel_basis: Tensor, lmax: int -) -> Tensor: - """Eager ground truth for :func:`radial_mix_block`. - - Parameters - ---------- - compact : Tensor - Projected radial degree kernel with shape ``(E, degree_kernel_size, R)``. - x_local : Tensor - Edge-local reduced features with shape ``(E, reduced_dim, C)``. - channel_basis : Tensor - Per-rank channel basis with shape ``(R, C)``. - lmax : int - Maximum spherical-harmonic degree. - - Returns - ------- - Tensor - Mixed features with shape ``(E, reduced_dim, C)``. - """ - n_edge, reduced_dim, channels = x_local.shape - out = x_local.new_zeros(n_edge, reduced_dim, channels) - for coeff0, comp0, num_l in _block_layout(int(lmax)): - # K[e, o, i, r] = compact[e, comp0 + i * num_l + o, r] - block = compact[:, comp0 : comp0 + num_l * num_l, :].reshape( - n_edge, num_l, num_l, -1 - ) - block = block.permute(0, 2, 1, 3) # (E, o, i, R) - x_block = x_local[:, coeff0 : coeff0 + num_l, :] # (E, i, C) - inner = torch.einsum("eoir,eic->eocr", block, x_block) # (E, o, C, R) - out[:, coeff0 : coeff0 + num_l, :] = torch.einsum( - "eocr,rc->eoc", inner, channel_basis - ) - return out - - -def _radial_mix_backward_reference( - grad_out: Tensor, compact: Tensor, x_local: Tensor, channel_basis: Tensor, lmax: int -) -> tuple[Tensor, Tensor]: - """Closed-form eager backward of :func:`radial_mix_reference`. - - Gradients are evaluated analytically per diagonal block, mirroring the - contractions of the Triton backward. A closed form is required rather than a - nested ``autograd.grad``: this routine is the CPU backend of the - ``radial_mix_block_bwd`` operator, which carries no autograd formula and is - consequently dispatched under ``_AutoDispatchBelowAutograd`` whenever the - force graph is replayed without grad (the SeZM ``.pt2`` freeze does so under - :func:`torch.no_grad`). That guard excludes the autograd key, so a nested - ``autograd.grad`` would observe an output without a ``grad_fn``. - - Parameters - ---------- - grad_out : Tensor - Upstream gradient with shape ``(E, reduced_dim, C)``. - compact : Tensor - Projected radial degree kernel with shape ``(E, degree_kernel_size, R)``. - x_local : Tensor - Edge-local reduced features with shape ``(E, reduced_dim, C)``. - channel_basis : Tensor - Per-rank channel basis with shape ``(R, C)``. - lmax : int - Maximum spherical-harmonic degree. - - Returns - ------- - tuple[Tensor, Tensor] - Gradients ``(grad_compact, grad_x_local)``, matching ``compact`` and - ``x_local`` in shape respectively. - """ - n_edge, reduced_dim, channels = x_local.shape - grad_x_local = torch.zeros_like(x_local) - grad_compact = torch.zeros_like(compact) - for coeff0, comp0, num_l in _block_layout(int(lmax)): - # Forward of this block (see ``radial_mix_reference``): - # out[e, o, c] = sum_{i, r} K[e, o, i, r] * x[e, i, c] * cb[r, c] - # with K[e, o, i, r] = compact[e, comp0 + i * num_l + o, r]. - k_block = ( - compact[:, comp0 : comp0 + num_l * num_l, :] - .reshape(n_edge, num_l, num_l, -1) - .permute(0, 2, 1, 3) - ) # (E, o, i, R) - x_block = x_local[:, coeff0 : coeff0 + num_l, :] # (E, i, C) - g_block = grad_out[:, coeff0 : coeff0 + num_l, :] # (E, o, C) - - # grad_x[e, i, c] = sum_r cb[r, c] * sum_o K[e, o, i, r] * g[e, o, c]. - gx = torch.einsum("eoir,eoc->eicr", k_block, g_block) # (E, i, C, R) - grad_x_local[:, coeff0 : coeff0 + num_l, :] += torch.einsum( - "eicr,rc->eic", gx, channel_basis - ) - - # grad_K[e, o, i, r] = sum_c cb[r, c] * x[e, i, c] * g[e, o, c], scattered - # back to the compact slot comp0 + i * num_l + o. The shared m = +-1 - # blocks address the same slots, so the in-place add accumulates both. - gk = torch.einsum("eoc,eic,rc->eoir", g_block, x_block, channel_basis) - grad_compact[:, comp0 : comp0 + num_l * num_l, :] += gk.permute( - 0, 2, 1, 3 - ).reshape(n_edge, num_l * num_l, -1) - return grad_compact, grad_x_local - - -# ====================================================================== -# Triton kernels (mmax == 1; LMAX and RANK are constexpr; channels vectorized) -# ====================================================================== -if RADIAL_MIX_TRITON_AVAILABLE: - # The per-edge work is tiny and memory-light, so only the warp count and - # pipeline depth are swept, keyed on the channel width. - _CONFIGS = [ - triton.Config({}, num_warps=1, num_stages=1), - triton.Config({}, num_warps=2, num_stages=1), - triton.Config({}, num_warps=4, num_stages=1), - triton.Config({}, num_warps=2, num_stages=2), - triton.Config({}, num_warps=4, num_stages=2), - ] - _KEY = ["channels"] - - @triton.jit - def _mix_fwd_block( - edge, - chan, - cmask, - x_ptr, - k_ptr, - cb_ptr, - out_ptr, - x_se, - x_sr, - x_sc, - k_se, - k_sk, - k_sr, - cb_sr, - cb_sc, - o_se, - o_sr, - o_sc, - COEFF0: tl.constexpr, - COMPACT0: tl.constexpr, - NUM_L: tl.constexpr, - RANK: tl.constexpr, - ): - """Contract one diagonal block: ``out[o] = sum_r cb[r] sum_i K_r[o,i] x[i]``.""" - for o in tl.static_range(0, NUM_L): - acc = tl.zeros(chan.shape, dtype=tl.float32) - for r in tl.static_range(0, RANK): - partial = tl.zeros(chan.shape, dtype=tl.float32) - for i in tl.static_range(0, NUM_L): - kval = tl.load( - k_ptr - + edge * k_se - + (COMPACT0 + i * NUM_L + o) * k_sk - + r * k_sr - ).to(tl.float32) - x_vec = tl.load( - x_ptr + edge * x_se + (COEFF0 + i) * x_sr + chan * x_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - partial += kval * x_vec - cb_vec = tl.load( - cb_ptr + r * cb_sr + chan * cb_sc, mask=cmask, other=0.0 - ).to(tl.float32) - acc += partial * cb_vec - tl.store( - out_ptr + edge * o_se + (COEFF0 + o) * o_sr + chan * o_sc, - acc.to(out_ptr.dtype.element_ty), - mask=cmask, - ) - - @triton.autotune(configs=_CONFIGS, key=_KEY) - @triton.jit - def _radial_mix_fwd_kernel( - x_ptr, - k_ptr, - cb_ptr, - out_ptr, - n_edge, - channels, - x_se, - x_sr, - x_sc, - k_se, - k_sk, - k_sr, - cb_sr, - cb_sc, - o_se, - o_sr, - o_sc, - LMAX: tl.constexpr, - RANK: tl.constexpr, - BLOCK_C: tl.constexpr, - ): - edge = tl.program_id(0).to(tl.int64) - chan = tl.arange(0, BLOCK_C) - cmask = chan < channels - num_l0: tl.constexpr = LMAX + 1 - strides = ( - x_se, - x_sr, - x_sc, - k_se, - k_sk, - k_sr, - cb_sr, - cb_sc, - o_se, - o_sr, - o_sc, - ) - # m = 0 block, then the shared m = -1 and m = +1 blocks. - _mix_fwd_block( - edge, - chan, - cmask, - x_ptr, - k_ptr, - cb_ptr, - out_ptr, - *strides, - 0, - 0, - num_l0, - RANK, - ) - _mix_fwd_block( - edge, - chan, - cmask, - x_ptr, - k_ptr, - cb_ptr, - out_ptr, - *strides, - num_l0, - num_l0 * num_l0, - LMAX, - RANK, - ) - _mix_fwd_block( - edge, - chan, - cmask, - x_ptr, - k_ptr, - cb_ptr, - out_ptr, - *strides, - num_l0 + LMAX, - num_l0 * num_l0, - LMAX, - RANK, - ) - - @triton.jit - def _mix_bwd_grad_x_block( - edge, - chan, - cmask, - go_ptr, - k_ptr, - cb_ptr, - gx_ptr, - go_se, - go_sr, - go_sc, - k_se, - k_sk, - k_sr, - cb_sr, - cb_sc, - gx_se, - gx_sr, - gx_sc, - COEFF0: tl.constexpr, - COMPACT0: tl.constexpr, - NUM_L: tl.constexpr, - RANK: tl.constexpr, - ): - """Input gradient of one diagonal block. - - Computes ``grad_x[i] = sum_r cb[r] sum_o K_r[o,i] grad_out[o]``. Each edge - owns its rows and the three blocks address disjoint coefficient rows, so - the result is written once with a plain store rather than an atomic add. - """ - for i in tl.static_range(0, NUM_L): - grad_x = tl.zeros(chan.shape, dtype=tl.float32) - for r in tl.static_range(0, RANK): - cb_vec = tl.load( - cb_ptr + r * cb_sr + chan * cb_sc, mask=cmask, other=0.0 - ).to(tl.float32) - partial = tl.zeros(chan.shape, dtype=tl.float32) - for o in tl.static_range(0, NUM_L): - kval = tl.load( - k_ptr - + edge * k_se - + (COMPACT0 + i * NUM_L + o) * k_sk - + r * k_sr - ).to(tl.float32) - go_vec = tl.load( - go_ptr + edge * go_se + (COEFF0 + o) * go_sr + chan * go_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - partial += kval * go_vec - grad_x += cb_vec * partial - tl.store( - gx_ptr + edge * gx_se + (COEFF0 + i) * gx_sr + chan * gx_sc, - grad_x.to(gx_ptr.dtype.element_ty), - mask=cmask, - ) - - @triton.jit - def _mix_bwd_grad_k_block( - edge, - chan, - cmask, - go_ptr, - x_ptr, - cb_ptr, - gk_ptr, - go_se, - go_sr, - go_sc, - x_se, - x_sr, - x_sc, - cb_sr, - cb_sc, - gk_se, - gk_sk, - gk_sr, - COEFF0: tl.constexpr, - COEFF1: tl.constexpr, - COMPACT0: tl.constexpr, - NUM_L: tl.constexpr, - RANK: tl.constexpr, - SHARED: tl.constexpr, - ): - """Kernel gradient of one diagonal block. - - Computes ``grad_K_r[o,i] = sum_c cb[r,c] x[i,c] grad_out[o,c]``. The - ``m = -1`` and ``m = +1`` blocks (``SHARED``) write the same ``compact`` - slots; their contributions are summed in registers and stored once, which - removes the atomic add and the zero-initialization the original required. - """ - for o in tl.static_range(0, NUM_L): - go_vec = tl.load( - go_ptr + edge * go_se + (COEFF0 + o) * go_sr + chan * go_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - if SHARED: - go_vec_sh = tl.load( - go_ptr + edge * go_se + (COEFF1 + o) * go_sr + chan * go_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - for i in tl.static_range(0, NUM_L): - x_vec = tl.load( - x_ptr + edge * x_se + (COEFF0 + i) * x_sr + chan * x_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - prod = go_vec * x_vec - if SHARED: - x_vec_sh = tl.load( - x_ptr + edge * x_se + (COEFF1 + i) * x_sr + chan * x_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - prod += go_vec_sh * x_vec_sh - for r in tl.static_range(0, RANK): - cb_vec = tl.load( - cb_ptr + r * cb_sr + chan * cb_sc, mask=cmask, other=0.0 - ).to(tl.float32) - grad_k = tl.sum(tl.where(cmask, prod * cb_vec, 0.0)) - tl.store( - gk_ptr - + edge * gk_se - + (COMPACT0 + i * NUM_L + o) * gk_sk - + r * gk_sr, - grad_k.to(gk_ptr.dtype.element_ty), - ) - - @triton.autotune(configs=_CONFIGS, key=_KEY) - @triton.jit - def _radial_mix_bwd_kernel( - go_ptr, - x_ptr, - k_ptr, - cb_ptr, - gx_ptr, - gk_ptr, - n_edge, - channels, - go_se, - go_sr, - go_sc, - x_se, - x_sr, - x_sc, - k_se, - k_sk, - k_sr, - cb_sr, - cb_sc, - gx_se, - gx_sr, - gx_sc, - gk_se, - gk_sk, - gk_sr, - LMAX: tl.constexpr, - RANK: tl.constexpr, - BLOCK_C: tl.constexpr, - ): - edge = tl.program_id(0).to(tl.int64) - chan = tl.arange(0, BLOCK_C) - cmask = chan < channels - num_l0: tl.constexpr = LMAX + 1 - - # === Step 1. Input gradient: three disjoint coefficient blocks === - grad_x_strides = ( - go_se, - go_sr, - go_sc, - k_se, - k_sk, - k_sr, - cb_sr, - cb_sc, - gx_se, - gx_sr, - gx_sc, - ) - _mix_bwd_grad_x_block( - edge, - chan, - cmask, - go_ptr, - k_ptr, - cb_ptr, - gx_ptr, - *grad_x_strides, - 0, - 0, - num_l0, - RANK, - ) - _mix_bwd_grad_x_block( - edge, - chan, - cmask, - go_ptr, - k_ptr, - cb_ptr, - gx_ptr, - *grad_x_strides, - num_l0, - num_l0 * num_l0, - LMAX, - RANK, - ) - _mix_bwd_grad_x_block( - edge, - chan, - cmask, - go_ptr, - k_ptr, - cb_ptr, - gx_ptr, - *grad_x_strides, - num_l0 + LMAX, - num_l0 * num_l0, - LMAX, - RANK, - ) - - # === Step 2. Kernel gradient: m=0 block, then summed m=+-1 blocks === - grad_k_strides = ( - go_se, - go_sr, - go_sc, - x_se, - x_sr, - x_sc, - cb_sr, - cb_sc, - gk_se, - gk_sk, - gk_sr, - ) - _mix_bwd_grad_k_block( - edge, - chan, - cmask, - go_ptr, - x_ptr, - cb_ptr, - gk_ptr, - *grad_k_strides, - 0, - 0, - 0, - num_l0, - RANK, - False, - ) - _mix_bwd_grad_k_block( - edge, - chan, - cmask, - go_ptr, - x_ptr, - cb_ptr, - gk_ptr, - *grad_k_strides, - num_l0, - num_l0 + LMAX, - num_l0 * num_l0, - LMAX, - RANK, - True, - ) - - -# ====================================================================== -# Triton launch wrappers -# ====================================================================== -def _tile_channels(channels: int) -> int: - """Smallest power-of-two channel tile of at least 16 covering ``channels``.""" - tile = 16 - while tile < int(channels): - tile *= 2 - return tile - - -def _has_no_edges(n_edge: int) -> bool: - """Return true for eager zero-edge calls without guarding symbolic edges.""" - return type(n_edge) is int and n_edge == 0 - - -def _launch_forward( - x_local: Tensor, compact: Tensor, channel_basis: Tensor, lmax: int -) -> Tensor: - n_edge, reduced_dim, channels = x_local.shape - rank = int(compact.shape[-1]) - out = torch.empty_like(x_local) - if _has_no_edges(n_edge): - return out - wrap_triton(_radial_mix_fwd_kernel)[(n_edge,)]( - x_local, - compact, - channel_basis, - out, - n_edge, - channels, - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - compact.stride(0), - compact.stride(1), - compact.stride(2), - channel_basis.stride(0), - channel_basis.stride(1), - out.stride(0), - out.stride(1), - out.stride(2), - LMAX=int(lmax), - RANK=rank, - BLOCK_C=_tile_channels(channels), - ) - return out - - -def _launch_backward( - grad_out: Tensor, - x_local: Tensor, - compact: Tensor, - channel_basis: Tensor, - lmax: int, -) -> tuple[Tensor, Tensor]: - n_edge, reduced_dim, channels = x_local.shape - rank = int(compact.shape[-1]) - # Every output element is written exactly once (input rows are disjoint and - # the shared m=+-1 kernel slots are summed in-register), so no zero-init. - grad_x = torch.empty_like(x_local) - grad_compact = torch.empty_like(compact) - if _has_no_edges(n_edge): - return grad_compact, grad_x - wrap_triton(_radial_mix_bwd_kernel)[(n_edge,)]( - grad_out.contiguous(), - x_local, - compact, - channel_basis, - grad_x, - grad_compact, - n_edge, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - compact.stride(0), - compact.stride(1), - compact.stride(2), - channel_basis.stride(0), - channel_basis.stride(1), - grad_x.stride(0), - grad_x.stride(1), - grad_x.stride(2), - grad_compact.stride(0), - grad_compact.stride(1), - grad_compact.stride(2), - LMAX=int(lmax), - RANK=rank, - BLOCK_C=_tile_channels(channels), - ) - return grad_compact, grad_x - - -# ====================================================================== -# Dispatch helpers (triton on CUDA float, eager otherwise) -# ====================================================================== -def _use_triton(tensor: Tensor) -> bool: - return ( - RADIAL_MIX_TRITON_AVAILABLE - and tensor.is_cuda - and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32) - ) - - -def _forward_impl( - compact: Tensor, x_local: Tensor, channel_basis: Tensor, lmax: int -) -> Tensor: - if not _use_triton(x_local): - return radial_mix_reference(compact, x_local, channel_basis, lmax) - return _launch_forward( - x_local.contiguous(), - compact.contiguous(), - channel_basis.contiguous(), - int(lmax), - ) - - -def _backward_impl( - grad_out: Tensor, - compact: Tensor, - x_local: Tensor, - channel_basis: Tensor, - lmax: int, -) -> tuple[Tensor, Tensor]: - if not _use_triton(x_local): - return _radial_mix_backward_reference( - grad_out, compact, x_local, channel_basis, lmax - ) - return _launch_backward( - grad_out, - x_local.contiguous(), - compact.contiguous(), - channel_basis.contiguous(), - int(lmax), - ) - - -# ====================================================================== -# Functional triton_op + fake + autograd registration -# ====================================================================== -# ``triton_op`` (not ``custom_op``) so Inductor bakes the Triton cubin into the -# AOTInductor ``.pt2``; the LAMMPS C++ runtime then needs no Python registration. -_radial_mix_op = torch.library.triton_op( - "dpa4_triton::radial_mix_block", mutates_args=() -)(_forward_impl) - -_radial_mix_bwd_op = torch.library.triton_op( - "dpa4_triton::radial_mix_block_bwd", mutates_args=() -)(_backward_impl) - - -@_radial_mix_op.register_fake -def _(compact, x_local, channel_basis, lmax): - return torch.empty_like(x_local) - - -@_radial_mix_bwd_op.register_fake -def _(grad_out, compact, x_local, channel_basis, lmax): - return torch.empty_like(compact), torch.empty_like(x_local) - - -def _radial_mix_setup_context(ctx, inputs, output): - compact, x_local, channel_basis, lmax = inputs - ctx.save_for_backward(compact, x_local, channel_basis) - ctx.lmax = lmax - - -def _radial_mix_backward(ctx, grad_out): - compact, x_local, channel_basis = ctx.saved_tensors - grad_compact, grad_x = _radial_mix_bwd_op( - grad_out, compact, x_local, channel_basis, ctx.lmax - ) - # ``channel_basis`` is a parameter; the inference force differentiates only - # w.r.t. coordinates, so its gradient is intentionally not produced. - return grad_compact, grad_x, None, None - - -_radial_mix_op.register_autograd( - _radial_mix_backward, setup_context=_radial_mix_setup_context -) - - -# ====================================================================== -# Public API -# ====================================================================== -def radial_mix_block( - compact: Tensor, x_local: Tensor, channel_basis: Tensor, lmax: int -) -> Tensor: - """Apply the block-diagonal dynamic radial degree mixer (``mmax == 1``). - - Computes the same operation as :func:`radial_mix_reference` while avoiding - the dense scattered kernel and the tiny batched matmul on CUDA. - - Parameters - ---------- - compact : Tensor - Projected radial degree kernel with shape ``(E, degree_kernel_size, R)``. - x_local : Tensor - Edge-local reduced features with shape ``(E, reduced_dim, C)``. - channel_basis : Tensor - Per-rank channel basis with shape ``(R, C)``. - lmax : int - Maximum spherical-harmonic degree. - - Returns - ------- - Tensor - Mixed features with shape ``(E, reduced_dim, C)``. - """ - return _radial_mix_op(compact, x_local, channel_basis, int(lmax)) diff --git a/deepmd/pt_expt/descriptor/dpa4_nn/triton/so2_rotation.py b/deepmd/pt_expt/descriptor/dpa4_nn/triton/so2_rotation.py deleted file mode 100644 index e34e300259..0000000000 --- a/deepmd/pt_expt/descriptor/dpa4_nn/triton/so2_rotation.py +++ /dev/null @@ -1,2003 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -# pyright: reportMissingImports=false -# ruff: noqa: ANN001, ANN202 -"""Fused Triton SO(2)/Wigner rotation operators for the SeZM/DPA4 descriptor. - -This module provides a *clean room* Triton implementation of the two rotation -hot paths used by the SeZM SO(2) convolution: - -``rotate_to_local`` (global -> edge-local reduced frame) - For every edge ``e`` with source node ``src[e]``:: - - out[e] = Wrows[e] @ x[src[e]] # (Dm, C) - Wrows[e][m, k] = wigner[e, coeff_index[m], k] # (Dm, D), k < D - - i.e. the eager reference ``bmm(D_to_m, x[src])`` where - ``D_to_m = wigner[:, :D, :D].index_select(1, coeff_index)``. - -``rotate_back`` (edge-local reduced frame -> global) - For every edge ``e``:: - - out[e] = Wcols[e] @ x_local[e] # (D, C) - Wcols[e][d, m] = wigner[e, d, coeff_index[m]] # (D, Dm), d < D - - i.e. the eager reference ``bmm(Dt_from_m, x_local)`` where - ``Dt_from_m = wigner[:, :D, :D].index_select(2, coeff_index)``. - -Design goals ------------- -1. **Fuse the gathers into the GEMM.** The eager / ``torch.compile`` path first - materializes ``D_to_m`` (or ``Dt_from_m``), shape ``(E, Dm, D)``, *and* - ``x[src]``, shape ``(E, D, C)``, before calling ``bmm``. For lmax 10 with - E=100k that is ~9 GB of scratch that is written and immediately re-read. - We instead gather the Wigner rows/columns (by ``coeff_index``) and the node - features (by ``src``) *inside* the kernel, so neither scratch tensor is ever - created. Each edge is one tiny GEMM; this also sidesteps the well-known - inefficiency of cuBLAS strided-batched GEMM on very small matrices. - -2. **Match eager FP32 accuracy.** Every ``tl.dot`` uses - ``input_precision="ieee"`` so the contraction runs in true IEEE FP32 (no - TF32). This keeps the potential-energy surface smooth. - -3. **Compose with SeZM's ``make_fx`` lowering *and* the AOTInductor freeze.** - The operators are functional ``torch.library.triton_op`` instances - (``mutates_args=()``) with registered fake kernels and autograd formulas; the - backward is itself a ``triton_op``, so ``make_fx(tracing_mode="symbolic")`` - can capture the energy path together with the force autograd graph used by - inference. Unlike ``torch.library.custom_op`` (opaque to the compiler, hence - emitted as a *runtime dispatcher* call that the C++ ``.pt2`` runtime cannot - resolve), a ``triton_op`` wraps its kernel launch in ``wrap_triton`` so - Inductor sees through to the Triton kernel and **bakes the cubin into the - AOTInductor package**. That is what lets the frozen ``.pt2`` run the Triton - path inside the LAMMPS C++ runtime (``DeepPotPTExpt`` / - ``AOTIModelPackageLoader``), with no Python op registration available. The - ``_use_triton`` device/dtype branch below stays a plain Python ``if``: the - op is opaque under ``make_fx`` (CPU trace), and Inductor resolves the branch - at compile time on the post-``move_to_device`` CUDA tensors, so CUDA fp32 - targets bake the Triton kernel while CPU / fp64 targets bake the eager - reference. - -Shapes / dtypes ---------------- -``x``/``x_local`` and ``wigner`` are float tensors; fp32 is the supported -precision for the smooth potential-energy surface, while fp16/bf16 inputs -accumulate in fp32. ``src`` and ``coeff_index`` are int64 tensors. ``E`` (edges) -may exceed 2**31 elements once multiplied by the per-edge matrix size, so all -kernels use int64 addressing. -""" - -from __future__ import ( - annotations, -) - -import torch -from torch import ( - Tensor, -) -from torch.library import ( - wrap_triton, -) - -from deepmd.dpmodel.descriptor.dpa4_nn.indexing import ( - build_m_major_index as _build_m_major_index_np, -) - -__all__ = [ - "TRITON_ROTATION_AVAILABLE", -] - - -def build_m_major_index(lmax: int, mmax: int, device: torch.device) -> Tensor: - """Torch m-major reduced coefficient index on ``device``. - - The dpmodel index builder is numpy-only; the Triton eager-fallback paths - need it as an int64 tensor on the working device. - """ - return torch.as_tensor(_build_m_major_index_np(int(lmax), int(mmax)), device=device) - - -try: - import triton - import triton.language as tl - - TRITON_ROTATION_AVAILABLE = True -except ImportError: # pragma: no cover - exercised only without triton - TRITON_ROTATION_AVAILABLE = False - - -# ====================================================================== -# Eager reference / fallback implementations -# ====================================================================== -def rotate_to_local_reference( - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - """Eager ground-truth for ``rotate_to_local`` (``bmm(D_to_m, x[src])``).""" - d_to_m = wigner[:, :dim_full, :dim_full].index_select(1, coeff_index) - return torch.bmm(d_to_m, x.index_select(0, src)) - - -def rotate_back_reference( - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - """Eager ground-truth for ``rotate_back`` (``bmm(Dt_from_m, x_local)``).""" - dt_from_m = wigner[:, :dim_full, :dim_full].index_select(2, coeff_index) - return torch.bmm(dt_from_m, x_local) - - -def _rotate_to_local_bwd_eager( - grad_out: Tensor, - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> tuple[Tensor, Tensor]: - """Eager backward of ``rotate_to_local`` returning ``(grad_x, grad_wigner)``.""" - w_rows = wigner[:, :dim_full, :dim_full].index_select(1, coeff_index) # (E,Dm,D) - x_src = x.index_select(0, src) # (E,D,C) - grad_x_src = torch.bmm(w_rows.transpose(1, 2), grad_out) # (E,D,C) - grad_x = torch.zeros_like(x).index_add_(0, src, grad_x_src) - grad_rows = torch.bmm(grad_out, x_src.transpose(1, 2)) # (E,Dm,D) - grad_block = torch.zeros( - grad_out.shape[0], dim_full, dim_full, dtype=wigner.dtype, device=wigner.device - ) - grad_block.index_copy_(1, coeff_index, grad_rows) - grad_wigner = torch.zeros_like(wigner) - grad_wigner[:, :dim_full, :dim_full] = grad_block - return grad_x, grad_wigner - - -def _rotate_back_bwd_eager( - grad_out: Tensor, - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> tuple[Tensor, Tensor]: - """Eager backward of ``rotate_back`` returning ``(grad_x_local, grad_wigner)``.""" - w_cols = wigner[:, :dim_full, :dim_full].index_select(2, coeff_index) # (E,D,Dm) - grad_x_local = torch.bmm(w_cols.transpose(1, 2), grad_out) # (E,Dm,C) - grad_cols = torch.bmm(grad_out, x_local.transpose(1, 2)) # (E,D,Dm) - grad_block = torch.zeros( - grad_out.shape[0], dim_full, dim_full, dtype=wigner.dtype, device=wigner.device - ) - grad_block.index_copy_(2, coeff_index, grad_cols) - grad_wigner = torch.zeros_like(wigner) - grad_wigner[:, :dim_full, :dim_full] = grad_block - return grad_x_local, grad_wigner - - -# ====================================================================== -# Tile-size helpers and autotuning configs -# ====================================================================== -def _tile_dim(value: int) -> int: - """Pick a single-tile edge: the next power of two, at least 16. - - Tiles spanning a whole dimension (the non-tiled ``N`` axis and the static - ``BLOCK_N``) must be a power of two (``tl.arange``) *and* a multiple of 16 - (``tl.dot``); powers of two ``>= 16`` satisfy both. Packed dims map as - ``16 -> 16`` (lmax 3), ``36 -> 64`` (lmax 5), ``64 -> 64`` (lmax 7), - ``121 -> 128`` (lmax 10), ``C=64 -> 64``. - """ - tile = 16 - target = max(int(value), 1) - while tile < target: - tile *= 2 - return tile - - -def _autotune_configs() -> list: - """A small curated set of (BLOCK_M, BLOCK_K, num_warps, num_stages) configs. - - The per-edge GEMMs are tiny (M, K, N <= 128). We tile the output-row axis - ``M`` across the grid and stream the contraction axis ``K`` in a pipelined - loop, so the dominant Wigner load overlaps with the matmul. Autotuning over - a handful of shapes lets one source kernel serve lmax 3..10 well (small - tiles for lmax 3, larger tiles / more warps for lmax 10). - """ - return [ - # Tiny tiles: best for lmax 3 (D=16), where a single 16x16 row tile and a - # one-shot K step behave like a per-edge matvec with minimal overhead. - triton.Config({"BLOCK_M": 16, "BLOCK_K": 16}, num_warps=1, num_stages=2), - triton.Config({"BLOCK_M": 16, "BLOCK_K": 16}, num_warps=2, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_K": 16}, num_warps=2, num_stages=2), - triton.Config({"BLOCK_M": 16, "BLOCK_K": 64}, num_warps=2, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_K": 32}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_M": 32, "BLOCK_K": 64}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_K": 64}, num_warps=4, num_stages=4), - triton.Config({"BLOCK_M": 64, "BLOCK_K": 64}, num_warps=8, num_stages=3), - triton.Config({"BLOCK_M": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3), - ] - - -if TRITON_ROTATION_AVAILABLE: - _CONFIGS = _autotune_configs() - _KEY = ["dim_full", "reduced_dim", "channels"] - - # Block-diagonal kernels are fully unrolled over l (LMAX constexpr) and over - # each l-block, with channels vectorized -- there is no GEMM tile to tune, so - # we only sweep the warp count / pipeline depth, keyed on the channel width. - _BD_CONFIGS = [ - triton.Config({}, num_warps=1, num_stages=1), - triton.Config({}, num_warps=2, num_stages=1), - triton.Config({}, num_warps=4, num_stages=1), - triton.Config({}, num_warps=2, num_stages=2), - triton.Config({}, num_warps=4, num_stages=2), - ] - _BD_KEY = ["channels"] - - # ================================================================== - # Triton kernels - # - # Every kernel is one fused-gather GEMM ``C_out = A @ B`` with: - # * grid = (edge, ceil(M / BLOCK_M)) -- one program per (edge, row-tile), - # * a pipelined K-loop streaming BLOCK_K of the contraction at a time, - # * the Wigner row/column gather (by ``coeff_index``) and the node-feature - # gather (by ``src``) folded into the pointer arithmetic, so neither - # ``D_to_m``/``Dt_from_m`` nor ``x[src]`` is ever materialized. - # All stores overwrite their tile (idempotent), which keeps autotuning safe. - # ================================================================== - @triton.autotune(configs=_CONFIGS, key=_KEY) - @triton.jit - def _to_local_fwd_kernel( - x_ptr, - src_ptr, - w_ptr, - idx_ptr, - out_ptr, - n_edge, - reduced_dim, - dim_full, - channels, - x_sn, - x_sd, - x_sc, - w_se, - w_sr, - w_sk, - o_se, - o_sr, - o_sc, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - """``out[e,m,c] = sum_k W[e, coeff[m], k] * x[src[e], k, c]`` (M=Dm,K=D,N=C).""" - edge = tl.program_id(0).to(tl.int64) - row = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over Dm - chan = tl.arange(0, BLOCK_N) # over C - row_mask = row < reduced_dim - chan_mask = chan < channels - - src_idx = tl.load(src_ptr + edge).to(tl.int64) - coeff_rows = tl.load(idx_ptr + row, mask=row_mask, other=0).to(tl.int64) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k0 in range(0, tl.cdiv(dim_full, BLOCK_K)): - kk = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over D - k_mask = kk < dim_full - w_tile = tl.load( - w_ptr + edge * w_se + coeff_rows[:, None] * w_sr + kk[None, :] * w_sk, - mask=row_mask[:, None] & k_mask[None, :], - other=0.0, - ) # (BLOCK_M, BLOCK_K) = W[coeff[m], k] - x_tile = tl.load( - x_ptr + src_idx * x_sn + kk[:, None] * x_sd + chan[None, :] * x_sc, - mask=k_mask[:, None] & chan_mask[None, :], - other=0.0, - ) # (BLOCK_K, BLOCK_N) = x[src, k, c] - acc = tl.dot(w_tile.to(x_tile.dtype), x_tile, acc, input_precision="ieee") - - tl.store( - out_ptr + edge * o_se + row[:, None] * o_sr + chan[None, :] * o_sc, - acc.to(out_ptr.dtype.element_ty), - mask=row_mask[:, None] & chan_mask[None, :], - ) - - @triton.autotune(configs=_CONFIGS, key=_KEY, reset_to_zero=["gx_ptr"]) - @triton.jit - def _to_local_bwd_dx_kernel( - go_ptr, - src_ptr, - w_ptr, - idx_ptr, - gx_ptr, - n_edge, - reduced_dim, - dim_full, - channels, - go_se, - go_sr, - go_sc, - w_se, - w_sr, - w_sk, - gx_sn, - gx_sd, - gx_sc, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - """``grad_x[src[e],d,c] += sum_m W[e, coeff[m], d] * grad_out[e,m,c]``. - - (M=D, K=Dm, N=C). The per-edge source gradient is atomically scattered - straight into the zero-initialized ``grad_x`` (no ``x[src]``-sized - scratch). ``reset_to_zero`` keeps the autotuner's trial runs from - polluting the accumulator. - """ - edge = tl.program_id(0).to(tl.int64) - drow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over D - chan = tl.arange(0, BLOCK_N) # over C - d_mask = drow < dim_full - chan_mask = chan < channels - - src_idx = tl.load(src_ptr + edge).to(tl.int64) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k0 in range(0, tl.cdiv(reduced_dim, BLOCK_K)): - mm = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over Dm - m_mask = mm < reduced_dim - coeff = tl.load(idx_ptr + mm, mask=m_mask, other=0).to(tl.int64) - w_tile = tl.load( - w_ptr + edge * w_se + coeff[None, :] * w_sr + drow[:, None] * w_sk, - mask=d_mask[:, None] & m_mask[None, :], - other=0.0, - ) # (BLOCK_M(d), BLOCK_K(m)) = W[coeff[m], d] - go_tile = tl.load( - go_ptr + edge * go_se + mm[:, None] * go_sr + chan[None, :] * go_sc, - mask=m_mask[:, None] & chan_mask[None, :], - other=0.0, - ) # (BLOCK_K(m), BLOCK_N(c)) - acc = tl.dot(w_tile.to(go_tile.dtype), go_tile, acc, input_precision="ieee") - - tl.atomic_add( - gx_ptr + src_idx * gx_sn + drow[:, None] * gx_sd + chan[None, :] * gx_sc, - acc, - mask=d_mask[:, None] & chan_mask[None, :], - ) - - @triton.autotune(configs=_CONFIGS, key=_KEY) - @triton.jit - def _to_local_bwd_dw_kernel( - go_ptr, - x_ptr, - src_ptr, - idx_ptr, - gw_ptr, - n_edge, - reduced_dim, - dim_full, - channels, - go_se, - go_sr, - go_sc, - x_sn, - x_sd, - x_sc, - gw_se, - gw_sr, - gw_sk, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - """``grad_W[e, coeff[m], d] = sum_c grad_out[e,m,c] * x[src[e], d, c]``. - - (M=Dm, K=C, N=D). Writes directly into rows ``coeff_index`` of the - zero-initialized ``grad_wigner``. - """ - edge = tl.program_id(0).to(tl.int64) - mrow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over Dm - dcol = tl.arange(0, BLOCK_N) # over D - m_mask = mrow < reduced_dim - d_mask = dcol < dim_full - - coeff = tl.load(idx_ptr + mrow, mask=m_mask, other=0).to(tl.int64) - src_idx = tl.load(src_ptr + edge).to(tl.int64) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k0 in range(0, tl.cdiv(channels, BLOCK_K)): - cc = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over C - c_mask = cc < channels - go_tile = tl.load( - go_ptr + edge * go_se + mrow[:, None] * go_sr + cc[None, :] * go_sc, - mask=m_mask[:, None] & c_mask[None, :], - other=0.0, - ) # (BLOCK_M(m), BLOCK_K(c)) - x_tile = tl.load( - x_ptr + src_idx * x_sn + dcol[None, :] * x_sd + cc[:, None] * x_sc, - mask=c_mask[:, None] & d_mask[None, :], - other=0.0, - ) # (BLOCK_K(c), BLOCK_N(d)) = x[src, d, c] - acc = tl.dot(go_tile.to(x_tile.dtype), x_tile, acc, input_precision="ieee") - - tl.store( - gw_ptr + edge * gw_se + coeff[:, None] * gw_sr + dcol[None, :] * gw_sk, - acc.to(gw_ptr.dtype.element_ty), - mask=m_mask[:, None] & d_mask[None, :], - ) - - # ``rotate_back`` reads the Wigner *columns* selected by ``coeff_index``. - # Gathering columns of a row-major ``(E, D, D)`` tensor is uncoalesced, so - # instead we read *dense* Wigner rows (coalesced last axis) and gather / - # scatter the small ``x_local`` through the inverse permutation - # ``inv[k] = m if coeff[m]==k else -1``. For ``mmax==lmax`` (a full - # permutation) this is the same flop count with far better memory behaviour. - @triton.autotune(configs=_CONFIGS, key=_KEY) - @triton.jit - def _back_fwd_kernel( - xl_ptr, - w_ptr, - inv_ptr, - out_ptr, - n_edge, - reduced_dim, - dim_full, - channels, - xl_se, - xl_sr, - xl_sc, - w_se, - w_sr, - w_sk, - o_se, - o_sd, - o_sc, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - """``out[e,d,c] = sum_k W[e,d,k] * x_local[e, inv[k], c]`` (M=D, K=D, N=C).""" - edge = tl.program_id(0).to(tl.int64) - drow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over D - chan = tl.arange(0, BLOCK_N) # over C - d_mask = drow < dim_full - chan_mask = chan < channels - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k0 in range(0, tl.cdiv(dim_full, BLOCK_K)): - kk = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over D (contraction) - k_mask = kk < dim_full - inv_k = tl.load(inv_ptr + kk, mask=k_mask, other=-1).to(tl.int64) - keep = inv_k >= 0 - w_tile = tl.load( - w_ptr + edge * w_se + drow[:, None] * w_sr + kk[None, :] * w_sk, - mask=d_mask[:, None] & k_mask[None, :], - other=0.0, - ) # (BLOCK_M(d), BLOCK_K(k)) = W[d, k] (k contiguous -> coalesced) - xl_tile = tl.load( - xl_ptr + edge * xl_se + inv_k[:, None] * xl_sr + chan[None, :] * xl_sc, - mask=keep[:, None] & chan_mask[None, :], - other=0.0, - ) # (BLOCK_K(k), BLOCK_N(c)) = x_local[inv[k], c] - acc = tl.dot(w_tile.to(xl_tile.dtype), xl_tile, acc, input_precision="ieee") - - tl.store( - out_ptr + edge * o_se + drow[:, None] * o_sd + chan[None, :] * o_sc, - acc.to(out_ptr.dtype.element_ty), - mask=d_mask[:, None] & chan_mask[None, :], - ) - - @triton.autotune(configs=_CONFIGS, key=_KEY) - @triton.jit - def _back_bwd_dx_kernel( - go_ptr, - w_ptr, - inv_ptr, - gxl_ptr, - n_edge, - reduced_dim, - dim_full, - channels, - go_se, - go_sd, - go_sc, - w_se, - w_sr, - w_sk, - gxl_se, - gxl_sr, - gxl_sc, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - """``grad_x_local[e, inv[k], c] = sum_d W[e,d,k] * grad_out[e,d,c]``. - - (M=D, K=D, N=C). Computes the dense ``k``-indexed gradient with coalesced - Wigner reads, then scatters each full row ``k`` into reduced row - ``inv[k]`` of ``grad_x_local``. - """ - edge = tl.program_id(0).to(tl.int64) - krow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over D - chan = tl.arange(0, BLOCK_N) # over C - k_mask = krow < dim_full - chan_mask = chan < channels - - inv_k = tl.load(inv_ptr + krow, mask=k_mask, other=-1).to(tl.int64) - keep = inv_k >= 0 - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k0 in range(0, tl.cdiv(dim_full, BLOCK_K)): - dd = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over D (contraction) - d_mask = dd < dim_full - w_tile = tl.load( - w_ptr + edge * w_se + dd[None, :] * w_sr + krow[:, None] * w_sk, - mask=k_mask[:, None] & d_mask[None, :], - other=0.0, - ) # (BLOCK_M(k), BLOCK_K(d)) = W[d, k] (k contiguous -> coalesced) - go_tile = tl.load( - go_ptr + edge * go_se + dd[:, None] * go_sd + chan[None, :] * go_sc, - mask=d_mask[:, None] & chan_mask[None, :], - other=0.0, - ) # (BLOCK_K(d), BLOCK_N(c)) - acc = tl.dot(w_tile.to(go_tile.dtype), go_tile, acc, input_precision="ieee") - - tl.store( - gxl_ptr + edge * gxl_se + inv_k[:, None] * gxl_sr + chan[None, :] * gxl_sc, - acc.to(gxl_ptr.dtype.element_ty), - mask=keep[:, None] & chan_mask[None, :], - ) - - @triton.autotune(configs=_CONFIGS, key=_KEY) - @triton.jit - def _back_bwd_dw_kernel( - go_ptr, - xl_ptr, - inv_ptr, - gw_ptr, - n_edge, - reduced_dim, - dim_full, - channels, - go_se, - go_sd, - go_sc, - xl_se, - xl_sr, - xl_sc, - gw_se, - gw_sr, - gw_sk, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - """``grad_W[e,d,k] = sum_c grad_out[e,d,c] * x_local[e, inv[k], c]``. - - (M=D, K=C, N=D). Writes the dense ``(D, D)`` block of ``grad_wigner`` - with a coalesced last axis; columns ``k`` not selected by ``coeff_index`` - receive zero (``inv[k] < 0``), matching the eager column gather. - """ - edge = tl.program_id(0).to(tl.int64) - drow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over D - kcol = tl.arange(0, BLOCK_N) # over D - d_mask = drow < dim_full - k_mask = kcol < dim_full - - inv_k = tl.load(inv_ptr + kcol, mask=k_mask, other=-1).to(tl.int64) - keep = inv_k >= 0 - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k0 in range(0, tl.cdiv(channels, BLOCK_K)): - cc = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over C (contraction) - c_mask = cc < channels - go_tile = tl.load( - go_ptr + edge * go_se + drow[:, None] * go_sd + cc[None, :] * go_sc, - mask=d_mask[:, None] & c_mask[None, :], - other=0.0, - ) # (BLOCK_M(d), BLOCK_K(c)) - xl_tile = tl.load( - xl_ptr + edge * xl_se + inv_k[None, :] * xl_sr + cc[:, None] * xl_sc, - mask=c_mask[:, None] & keep[None, :], - other=0.0, - ) # (BLOCK_K(c), BLOCK_N(k)) = x_local[inv[k], c] - acc = tl.dot( - go_tile.to(xl_tile.dtype), xl_tile, acc, input_precision="ieee" - ) - - tl.store( - gw_ptr + edge * gw_se + drow[:, None] * gw_sr + kcol[None, :] * gw_sk, - acc.to(gw_ptr.dtype.element_ty), - mask=d_mask[:, None] & k_mask[None, :], - ) - - # ================================================================== - # Block-diagonal kernels (mmax == 1, block-diagonal Wigner-D) - # - # The Wigner-D matrix is block-diagonal by degree ``l``: block ``l`` is the - # ``(2l+1) x (2l+1)`` sub-matrix on rows/cols ``[l^2 : (l+1)^2]`` and every - # off-(l-block) entry is exactly 0. With ``mmax == 1`` the reduced layout - # keeps, per degree ``l``, the orders ``m in {0}`` (l == 0) or - # ``{0, -1, +1}`` (l >= 1). Output coefficient ``(l, m)`` therefore contracts - # ONLY over the ``2l+1`` inputs of block ``l`` -- never the full ``D``. - # - # The m-major reduced index and the packed Wigner row/col are pure functions - # of ``(l, m, LMAX)``:: - # - # reduced index: m=0 -> l, m=-1 -> LMAX+l, m=+1 -> 2*LMAX+l - # packed (l, m): l^2 + l + m (so m=0 -> l^2+l, m=-1 -> -1, m=+1 -> +1) - # - # so the kernels need no ``coeff_index`` tensor: with ``LMAX`` a constexpr we - # fully unroll over ``l`` and over each block, contracting exactly the - # structural non-zeros (no padding, no wasted FLOPs). Channels are the - # vectorized axis (``BLOCK_C`` spans the full width ``C``), so the backward - # Wigner gradient is a single in-program ``tl.sum`` over channels. - @triton.autotune(configs=_BD_CONFIGS, key=_BD_KEY) - @triton.jit - def _bd_to_local_fwd_kernel( - x_ptr, - src_ptr, - w_ptr, - out_ptr, - n_edge, - channels, - x_sn, - x_sd, - x_sc, - w_se, - w_sr, - w_sk, - o_se, - o_sr, - o_sc, - LMAX: tl.constexpr, - BLOCK_C: tl.constexpr, - ): - """``out[e,(l,m),c] = sum_{j} W[e, l^2+l+m, l^2+j] * x[src[e], l^2+j, c]``.""" - edge = tl.program_id(0).to(tl.int64) - chan = tl.arange(0, BLOCK_C) - cmask = chan < channels - src_idx = tl.load(src_ptr + edge).to(tl.int64) - - for l in tl.static_range(0, LMAX + 1): - base = l * l - r0 = base + l # packed row of order m=0 - acc0 = tl.zeros((BLOCK_C,), dtype=tl.float32) - acc_m = tl.zeros((BLOCK_C,), dtype=tl.float32) - acc_p = tl.zeros((BLOCK_C,), dtype=tl.float32) - for j in tl.static_range(0, 2 * l + 1): - col = base + j - x_vec = tl.load( - x_ptr + src_idx * x_sn + col * x_sd + chan * x_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - acc0 += tl.load(w_ptr + edge * w_se + r0 * w_sr + col * w_sk) * x_vec - if l >= 1: - acc_m += ( - tl.load(w_ptr + edge * w_se + (r0 - 1) * w_sr + col * w_sk) - * x_vec - ) - acc_p += ( - tl.load(w_ptr + edge * w_se + (r0 + 1) * w_sr + col * w_sk) - * x_vec - ) - tl.store( - out_ptr + edge * o_se + l * o_sr + chan * o_sc, - acc0.to(out_ptr.dtype.element_ty), - mask=cmask, - ) - if l >= 1: - tl.store( - out_ptr + edge * o_se + (LMAX + l) * o_sr + chan * o_sc, - acc_m.to(out_ptr.dtype.element_ty), - mask=cmask, - ) - tl.store( - out_ptr + edge * o_se + (2 * LMAX + l) * o_sr + chan * o_sc, - acc_p.to(out_ptr.dtype.element_ty), - mask=cmask, - ) - - @triton.autotune(configs=_BD_CONFIGS, key=_BD_KEY, reset_to_zero=["gx_ptr"]) - @triton.jit - def _bd_to_local_bwd_kernel( - go_ptr, - x_ptr, - src_ptr, - w_ptr, - gx_ptr, - gw_ptr, - n_edge, - channels, - go_se, - go_sr, - go_sc, - x_sn, - x_sd, - x_sc, - w_se, - w_sr, - w_sk, - gx_sn, - gx_sd, - gx_sc, - gw_se, - gw_sr, - gw_sk, - LMAX: tl.constexpr, - BLOCK_C: tl.constexpr, - ): - """Fused block-diagonal backward of ``rotate_to_local``. - - Per edge (full channel width in one program): scatters - ``grad_x[src, l^2+j, :] += sum_m W[l^2+l+m, l^2+j] * grad_out[(l,m), :]`` - and writes ``grad_W[l^2+l+m, l^2+j] = sum_c grad_out[(l,m),c] * x[l^2+j,c]`` - for the structural non-zeros only. - """ - edge = tl.program_id(0).to(tl.int64) - chan = tl.arange(0, BLOCK_C) - cmask = chan < channels - src_idx = tl.load(src_ptr + edge).to(tl.int64) - - for l in tl.static_range(0, LMAX + 1): - base = l * l - r0 = base + l - go0 = tl.load( - go_ptr + edge * go_se + l * go_sr + chan * go_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - if l >= 1: - go_m = tl.load( - go_ptr + edge * go_se + (LMAX + l) * go_sr + chan * go_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - go_p = tl.load( - go_ptr + edge * go_se + (2 * LMAX + l) * go_sr + chan * go_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - for j in tl.static_range(0, 2 * l + 1): - col = base + j - x_vec = tl.load( - x_ptr + src_idx * x_sn + col * x_sd + chan * x_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - w0 = tl.load(w_ptr + edge * w_se + r0 * w_sr + col * w_sk) - gx_row = w0 * go0 - tl.store( - gw_ptr + edge * gw_se + r0 * gw_sr + col * gw_sk, - tl.sum(go0 * x_vec).to(gw_ptr.dtype.element_ty), - ) - if l >= 1: - wm = tl.load(w_ptr + edge * w_se + (r0 - 1) * w_sr + col * w_sk) - wp = tl.load(w_ptr + edge * w_se + (r0 + 1) * w_sr + col * w_sk) - gx_row += wm * go_m + wp * go_p - tl.store( - gw_ptr + edge * gw_se + (r0 - 1) * gw_sr + col * gw_sk, - tl.sum(go_m * x_vec).to(gw_ptr.dtype.element_ty), - ) - tl.store( - gw_ptr + edge * gw_se + (r0 + 1) * gw_sr + col * gw_sk, - tl.sum(go_p * x_vec).to(gw_ptr.dtype.element_ty), - ) - tl.atomic_add( - gx_ptr + src_idx * gx_sn + col * gx_sd + chan * gx_sc, - gx_row, - mask=cmask, - ) - - @triton.autotune(configs=_BD_CONFIGS, key=_BD_KEY) - @triton.jit - def _bd_back_fwd_kernel( - xl_ptr, - w_ptr, - out_ptr, - n_edge, - channels, - xl_se, - xl_sr, - xl_sc, - w_se, - w_sr, - w_sk, - o_se, - o_sd, - o_sc, - LMAX: tl.constexpr, - BLOCK_C: tl.constexpr, - ): - """``out[e, l^2+j, c] = sum_m W[e, l^2+j, l^2+l+m] * x_local[(l,m), c]``.""" - edge = tl.program_id(0).to(tl.int64) - chan = tl.arange(0, BLOCK_C) - cmask = chan < channels - - for l in tl.static_range(0, LMAX + 1): - base = l * l - r0 = base + l # packed col of order m=0 - xl0 = tl.load( - xl_ptr + edge * xl_se + l * xl_sr + chan * xl_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - if l >= 1: - xl_m = tl.load( - xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + chan * xl_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - xl_p = tl.load( - xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + chan * xl_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - for j in tl.static_range(0, 2 * l + 1): - d = base + j # full packed output row - acc = tl.load(w_ptr + edge * w_se + d * w_sr + r0 * w_sk) * xl0 - if l >= 1: - acc += ( - tl.load(w_ptr + edge * w_se + d * w_sr + (r0 - 1) * w_sk) * xl_m - ) - acc += ( - tl.load(w_ptr + edge * w_se + d * w_sr + (r0 + 1) * w_sk) * xl_p - ) - tl.store( - out_ptr + edge * o_se + d * o_sd + chan * o_sc, - acc.to(out_ptr.dtype.element_ty), - mask=cmask, - ) - - @triton.autotune(configs=_BD_CONFIGS, key=_BD_KEY) - @triton.jit - def _bd_back_bwd_kernel( - go_ptr, - xl_ptr, - w_ptr, - gxl_ptr, - gw_ptr, - n_edge, - channels, - go_se, - go_sd, - go_sc, - xl_se, - xl_sr, - xl_sc, - w_se, - w_sr, - w_sk, - gxl_se, - gxl_sr, - gxl_sc, - gw_se, - gw_sr, - gw_sk, - LMAX: tl.constexpr, - BLOCK_C: tl.constexpr, - ): - """Fused block-diagonal backward of ``rotate_back``. - - Per edge (full channel width in one program): writes - ``grad_x_local[(l,m), :] = sum_j W[l^2+j, l^2+l+m] * grad_out[l^2+j, :]`` - (no scatter -- ``x_local`` is per-edge) and - ``grad_W[l^2+j, l^2+l+m] = sum_c grad_out[l^2+j, c] * x_local[(l,m), c]``. - """ - edge = tl.program_id(0).to(tl.int64) - chan = tl.arange(0, BLOCK_C) - cmask = chan < channels - - for l in tl.static_range(0, LMAX + 1): - base = l * l - r0 = base + l # packed col of order m=0 - xl0 = tl.load( - xl_ptr + edge * xl_se + l * xl_sr + chan * xl_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - gxl0 = tl.zeros((BLOCK_C,), dtype=tl.float32) - if l >= 1: - xl_m = tl.load( - xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + chan * xl_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - xl_p = tl.load( - xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + chan * xl_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - gxl_m = tl.zeros((BLOCK_C,), dtype=tl.float32) - gxl_p = tl.zeros((BLOCK_C,), dtype=tl.float32) - for j in tl.static_range(0, 2 * l + 1): - d = base + j # full packed row (output of forward / grad_out row) - go_d = tl.load( - go_ptr + edge * go_se + d * go_sd + chan * go_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - w0 = tl.load(w_ptr + edge * w_se + d * w_sr + r0 * w_sk) - gxl0 += w0 * go_d - tl.store( - gw_ptr + edge * gw_se + d * gw_sr + r0 * gw_sk, - tl.sum(go_d * xl0).to(gw_ptr.dtype.element_ty), - ) - if l >= 1: - wm = tl.load(w_ptr + edge * w_se + d * w_sr + (r0 - 1) * w_sk) - wp = tl.load(w_ptr + edge * w_se + d * w_sr + (r0 + 1) * w_sk) - gxl_m += wm * go_d - gxl_p += wp * go_d - tl.store( - gw_ptr + edge * gw_se + d * gw_sr + (r0 - 1) * gw_sk, - tl.sum(go_d * xl_m).to(gw_ptr.dtype.element_ty), - ) - tl.store( - gw_ptr + edge * gw_se + d * gw_sr + (r0 + 1) * gw_sk, - tl.sum(go_d * xl_p).to(gw_ptr.dtype.element_ty), - ) - tl.store( - gxl_ptr + edge * gxl_se + l * gxl_sr + chan * gxl_sc, - gxl0.to(gxl_ptr.dtype.element_ty), - mask=cmask, - ) - if l >= 1: - tl.store( - gxl_ptr + edge * gxl_se + (LMAX + l) * gxl_sr + chan * gxl_sc, - gxl_m.to(gxl_ptr.dtype.element_ty), - mask=cmask, - ) - tl.store( - gxl_ptr + edge * gxl_se + (2 * LMAX + l) * gxl_sr + chan * gxl_sc, - gxl_p.to(gxl_ptr.dtype.element_ty), - mask=cmask, - ) - - @triton.autotune(configs=_BD_CONFIGS, key=["channels"]) - @triton.jit - def _bd_back_so2_fwd_kernel( - xl_ptr, - w_ptr, - out_ptr, - n_edge, - channels, - xl_se, - xl_sf, - xl_sr, - xl_sc, - w_se, - w_sr, - w_sk, - o_se, - o_sd, - o_sc, - LMAX: tl.constexpr, - FOCUS_DIM: tl.constexpr, - BLOCK_C: tl.constexpr, - ): - """Block-diagonal rotate_back reading the per-focus layout in place. - - ``out[e, l^2+j, c] = sum_m W[e, l^2+j, l^2+l+m] * x_local[e, f, (l,m), cf]`` - with ``c = f * FOCUS_DIM + cf``. Decoding the channel as ``(f, cf)`` folds - the ``(F, D_m, Cf) -> (D_m, C_wide)`` transpose into the addressing, so the - caller passes the SO(2) focus tensor without an explicit copy. - """ - edge = tl.program_id(0).to(tl.int64) - chan = tl.arange(0, BLOCK_C) - cmask = chan < channels - xl_co = (chan // FOCUS_DIM) * xl_sf + (chan % FOCUS_DIM) * xl_sc - for l in tl.static_range(0, LMAX + 1): - base = l * l - r0 = base + l - xl0 = tl.load( - xl_ptr + edge * xl_se + l * xl_sr + xl_co, mask=cmask, other=0.0 - ).to(tl.float32) - if l >= 1: - xl_m = tl.load( - xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + xl_co, - mask=cmask, - other=0.0, - ).to(tl.float32) - xl_p = tl.load( - xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + xl_co, - mask=cmask, - other=0.0, - ).to(tl.float32) - for j in tl.static_range(0, 2 * l + 1): - d = base + j - acc = tl.load(w_ptr + edge * w_se + d * w_sr + r0 * w_sk) * xl0 - if l >= 1: - acc += ( - tl.load(w_ptr + edge * w_se + d * w_sr + (r0 - 1) * w_sk) * xl_m - ) - acc += ( - tl.load(w_ptr + edge * w_se + d * w_sr + (r0 + 1) * w_sk) * xl_p - ) - tl.store( - out_ptr + edge * o_se + d * o_sd + chan * o_sc, - acc.to(out_ptr.dtype.element_ty), - mask=cmask, - ) - - @triton.autotune(configs=_BD_CONFIGS, key=["channels"]) - @triton.jit - def _bd_back_so2_bwd_kernel( - go_ptr, - xl_ptr, - w_ptr, - gxl_ptr, - gw_ptr, - n_edge, - channels, - go_se, - go_sd, - go_sc, - xl_se, - xl_sf, - xl_sr, - xl_sc, - w_se, - w_sr, - w_sk, - gxl_se, - gxl_sf, - gxl_sr, - gxl_sc, - gw_se, - gw_sr, - gw_sk, - LMAX: tl.constexpr, - FOCUS_DIM: tl.constexpr, - BLOCK_C: tl.constexpr, - ): - """Backward of :func:`_bd_back_so2_fwd_kernel`. - - Writes ``grad_x_local`` in the per-focus layout (decoding the channel as - ``(f, cf)`` exactly as the forward) and accumulates ``grad_W`` over the - full channel width, i.e. summed across focus streams. - """ - edge = tl.program_id(0).to(tl.int64) - chan = tl.arange(0, BLOCK_C) - cmask = chan < channels - xl_co = (chan // FOCUS_DIM) * xl_sf + (chan % FOCUS_DIM) * xl_sc - gxl_co = (chan // FOCUS_DIM) * gxl_sf + (chan % FOCUS_DIM) * gxl_sc - for l in tl.static_range(0, LMAX + 1): - base = l * l - r0 = base + l - xl0 = tl.load( - xl_ptr + edge * xl_se + l * xl_sr + xl_co, mask=cmask, other=0.0 - ).to(tl.float32) - gxl0 = tl.zeros((BLOCK_C,), dtype=tl.float32) - if l >= 1: - xl_m = tl.load( - xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + xl_co, - mask=cmask, - other=0.0, - ).to(tl.float32) - xl_p = tl.load( - xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + xl_co, - mask=cmask, - other=0.0, - ).to(tl.float32) - gxl_m = tl.zeros((BLOCK_C,), dtype=tl.float32) - gxl_p = tl.zeros((BLOCK_C,), dtype=tl.float32) - for j in tl.static_range(0, 2 * l + 1): - d = base + j - go_d = tl.load( - go_ptr + edge * go_se + d * go_sd + chan * go_sc, - mask=cmask, - other=0.0, - ).to(tl.float32) - gxl0 += tl.load(w_ptr + edge * w_se + d * w_sr + r0 * w_sk) * go_d - tl.store( - gw_ptr + edge * gw_se + d * gw_sr + r0 * gw_sk, - tl.sum(go_d * xl0).to(gw_ptr.dtype.element_ty), - ) - if l >= 1: - gxl_m += ( - tl.load(w_ptr + edge * w_se + d * w_sr + (r0 - 1) * w_sk) * go_d - ) - gxl_p += ( - tl.load(w_ptr + edge * w_se + d * w_sr + (r0 + 1) * w_sk) * go_d - ) - tl.store( - gw_ptr + edge * gw_se + d * gw_sr + (r0 - 1) * gw_sk, - tl.sum(go_d * xl_m).to(gw_ptr.dtype.element_ty), - ) - tl.store( - gw_ptr + edge * gw_se + d * gw_sr + (r0 + 1) * gw_sk, - tl.sum(go_d * xl_p).to(gw_ptr.dtype.element_ty), - ) - tl.store( - gxl_ptr + edge * gxl_se + l * gxl_sr + gxl_co, - gxl0.to(gxl_ptr.dtype.element_ty), - mask=cmask, - ) - if l >= 1: - tl.store( - gxl_ptr + edge * gxl_se + (LMAX + l) * gxl_sr + gxl_co, - gxl_m.to(gxl_ptr.dtype.element_ty), - mask=cmask, - ) - tl.store( - gxl_ptr + edge * gxl_se + (2 * LMAX + l) * gxl_sr + gxl_co, - gxl_p.to(gxl_ptr.dtype.element_ty), - mask=cmask, - ) - - -# ====================================================================== -# Triton launch wrappers -# ====================================================================== -def _grid_over_rows(n_edge: int, rows: int): - """Grid callable: one program per (edge, BLOCK_M-sized row tile).""" - return lambda meta: (n_edge, triton.cdiv(rows, meta["BLOCK_M"])) - - -def _has_no_edges(n_edge: int) -> bool: - """Return true for eager zero-edge calls without guarding symbolic edges. - - Under ``torch.library.triton_op`` decomposition (AOTInductor freeze), the - edge dimension can be a data-dependent SymInt produced from the neighbour - list. Converting it to a Python ``int`` would force a guard such as - ``u0 + 2`` and abort export. We only need the zero-edge early return in - eager Python execution; compiled production graphs always see a non-empty - representative trace and use dynamic shapes for later calls. - """ - return type(n_edge) is int and n_edge == 0 - - -def _inverse_index(coeff_index: Tensor, dim_full: int) -> Tensor: - """Inverse permutation ``inv[k] = m`` where ``coeff_index[m] == k`` else ``-1``. - - Maps a full packed position ``k`` back to its reduced-layout slot. Used by the - ``rotate_back`` kernels so they can read dense Wigner rows (coalesced) and - gather/scatter the small ``x_local`` instead of gathering Wigner columns. - """ - inv = torch.full((int(dim_full),), -1, dtype=torch.int64, device=coeff_index.device) - inv[coeff_index] = torch.arange( - coeff_index.numel(), dtype=torch.int64, device=coeff_index.device - ) - return inv - - -def _launch_rotate_to_local_fwd( - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - n_edge = src.shape[0] - reduced_dim = int(coeff_index.shape[0]) - channels = int(x.shape[2]) - out = torch.empty((n_edge, reduced_dim, channels), dtype=x.dtype, device=x.device) - if _has_no_edges(n_edge): - return out - wrap_triton(_to_local_fwd_kernel)[_grid_over_rows(n_edge, reduced_dim)]( - x, - src, - wigner, - coeff_index, - out, - n_edge, - reduced_dim, - dim_full, - channels, - x.stride(0), - x.stride(1), - x.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - BLOCK_N=_tile_dim(channels), - ) - return out - - -def _launch_rotate_to_local_bwd( - grad_out: Tensor, - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> tuple[Tensor, Tensor]: - n_edge = src.shape[0] - reduced_dim = int(coeff_index.shape[0]) - channels = int(x.shape[2]) - grad_x = torch.zeros_like(x) - grad_wigner = torch.zeros_like(wigner) - if _has_no_edges(n_edge): - return grad_x, grad_wigner - - # --- grad_x: per-edge GEMM atomically scattered into grad_x by src --- - wrap_triton(_to_local_bwd_dx_kernel)[_grid_over_rows(n_edge, dim_full)]( - grad_out, - src, - wigner, - coeff_index, - grad_x, - n_edge, - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - grad_x.stride(0), - grad_x.stride(1), - grad_x.stride(2), - BLOCK_N=_tile_dim(channels), - ) - - # --- grad_wigner: per-edge GEMM written into rows ``coeff_index`` --- - wrap_triton(_to_local_bwd_dw_kernel)[_grid_over_rows(n_edge, reduced_dim)]( - grad_out, - x, - src, - coeff_index, - grad_wigner, - n_edge, - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x.stride(0), - x.stride(1), - x.stride(2), - grad_wigner.stride(0), - grad_wigner.stride(1), - grad_wigner.stride(2), - BLOCK_N=_tile_dim(dim_full), - ) - return grad_x, grad_wigner - - -def _launch_rotate_back_fwd( - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - n_edge = x_local.shape[0] - reduced_dim = int(coeff_index.shape[0]) - channels = int(x_local.shape[2]) - out = torch.empty( - (n_edge, dim_full, channels), dtype=x_local.dtype, device=x_local.device - ) - if _has_no_edges(n_edge): - return out - inv_index = _inverse_index(coeff_index, dim_full) - wrap_triton(_back_fwd_kernel)[_grid_over_rows(n_edge, dim_full)]( - x_local, - wigner, - inv_index, - out, - n_edge, - reduced_dim, - dim_full, - channels, - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - BLOCK_N=_tile_dim(channels), - ) - return out - - -def _launch_rotate_back_bwd( - grad_out: Tensor, - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> tuple[Tensor, Tensor]: - n_edge = x_local.shape[0] - reduced_dim = int(coeff_index.shape[0]) - channels = int(x_local.shape[2]) - grad_x_local = torch.empty_like(x_local) - grad_wigner = torch.zeros_like(wigner) - if _has_no_edges(n_edge): - return grad_x_local, grad_wigner - - inv_index = _inverse_index(coeff_index, dim_full) - wrap_triton(_back_bwd_dx_kernel)[_grid_over_rows(n_edge, dim_full)]( - grad_out, - wigner, - inv_index, - grad_x_local, - n_edge, - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - grad_x_local.stride(0), - grad_x_local.stride(1), - grad_x_local.stride(2), - BLOCK_N=_tile_dim(channels), - ) - wrap_triton(_back_bwd_dw_kernel)[_grid_over_rows(n_edge, dim_full)]( - grad_out, - x_local, - inv_index, - grad_wigner, - n_edge, - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - grad_wigner.stride(0), - grad_wigner.stride(1), - grad_wigner.stride(2), - BLOCK_N=_tile_dim(dim_full), - ) - return grad_x_local, grad_wigner - - -# ====================================================================== -# Block-diagonal launch wrappers (mmax == 1) -# ====================================================================== -def _launch_bd_to_local_fwd( - x: Tensor, src: Tensor, wigner: Tensor, lmax: int -) -> Tensor: - n_edge = src.shape[0] - channels = int(x.shape[2]) - out = torch.empty((n_edge, 3 * lmax + 1, channels), dtype=x.dtype, device=x.device) - if _has_no_edges(n_edge): - return out - wrap_triton(_bd_to_local_fwd_kernel)[(n_edge,)]( - x, - src, - wigner, - out, - n_edge, - channels, - x.stride(0), - x.stride(1), - x.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - LMAX=lmax, - BLOCK_C=_tile_dim(channels), - ) - return out - - -def _launch_bd_to_local_bwd( - grad_out: Tensor, x: Tensor, src: Tensor, wigner: Tensor, lmax: int -) -> tuple[Tensor, Tensor]: - n_edge = src.shape[0] - channels = int(x.shape[2]) - grad_x = torch.zeros_like(x) - grad_wigner = torch.zeros_like(wigner) - if _has_no_edges(n_edge): - return grad_x, grad_wigner - wrap_triton(_bd_to_local_bwd_kernel)[(n_edge,)]( - grad_out, - x, - src, - wigner, - grad_x, - grad_wigner, - n_edge, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x.stride(0), - x.stride(1), - x.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - grad_x.stride(0), - grad_x.stride(1), - grad_x.stride(2), - grad_wigner.stride(0), - grad_wigner.stride(1), - grad_wigner.stride(2), - LMAX=lmax, - BLOCK_C=_tile_dim(channels), - ) - return grad_x, grad_wigner - - -def _launch_bd_back_fwd(x_local: Tensor, wigner: Tensor, lmax: int) -> Tensor: - n_edge = x_local.shape[0] - channels = int(x_local.shape[2]) - dim_full = (lmax + 1) ** 2 - out = torch.empty( - (n_edge, dim_full, channels), dtype=x_local.dtype, device=x_local.device - ) - if _has_no_edges(n_edge): - return out - wrap_triton(_bd_back_fwd_kernel)[(n_edge,)]( - x_local, - wigner, - out, - n_edge, - channels, - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - LMAX=lmax, - BLOCK_C=_tile_dim(channels), - ) - return out - - -def _launch_bd_back_bwd( - grad_out: Tensor, x_local: Tensor, wigner: Tensor, lmax: int -) -> tuple[Tensor, Tensor]: - n_edge = x_local.shape[0] - channels = int(x_local.shape[2]) - grad_x_local = torch.empty_like(x_local) - grad_wigner = torch.zeros_like(wigner) - if _has_no_edges(n_edge): - return grad_x_local, grad_wigner - wrap_triton(_bd_back_bwd_kernel)[(n_edge,)]( - grad_out, - x_local, - wigner, - grad_x_local, - grad_wigner, - n_edge, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - grad_x_local.stride(0), - grad_x_local.stride(1), - grad_x_local.stride(2), - grad_wigner.stride(0), - grad_wigner.stride(1), - grad_wigner.stride(2), - LMAX=lmax, - BLOCK_C=_tile_dim(channels), - ) - return grad_x_local, grad_wigner - - -def _launch_bd_back_so2_fwd(x_local_4d: Tensor, wigner: Tensor, lmax: int) -> Tensor: - n_edge = x_local_4d.shape[0] - n_focus = int(x_local_4d.shape[1]) - focus_dim = int(x_local_4d.shape[3]) - channels = n_focus * focus_dim - dim_full = (lmax + 1) ** 2 - out = torch.empty( - (n_edge, dim_full, channels), dtype=x_local_4d.dtype, device=x_local_4d.device - ) - if _has_no_edges(n_edge): - return out - wrap_triton(_bd_back_so2_fwd_kernel)[(n_edge,)]( - x_local_4d, - wigner, - out, - n_edge, - channels, - x_local_4d.stride(0), - x_local_4d.stride(1), - x_local_4d.stride(2), - x_local_4d.stride(3), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - LMAX=lmax, - FOCUS_DIM=focus_dim, - BLOCK_C=_tile_dim(channels), - ) - return out - - -def _launch_bd_back_so2_bwd( - grad_out: Tensor, x_local_4d: Tensor, wigner: Tensor, lmax: int -) -> tuple[Tensor, Tensor]: - n_edge = x_local_4d.shape[0] - n_focus = int(x_local_4d.shape[1]) - focus_dim = int(x_local_4d.shape[3]) - channels = n_focus * focus_dim - grad_x_local = torch.empty_like(x_local_4d) - grad_wigner = torch.zeros_like(wigner) - if _has_no_edges(n_edge): - return grad_x_local, grad_wigner - wrap_triton(_bd_back_so2_bwd_kernel)[(n_edge,)]( - grad_out, - x_local_4d, - wigner, - grad_x_local, - grad_wigner, - n_edge, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x_local_4d.stride(0), - x_local_4d.stride(1), - x_local_4d.stride(2), - x_local_4d.stride(3), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - grad_x_local.stride(0), - grad_x_local.stride(1), - grad_x_local.stride(2), - grad_x_local.stride(3), - grad_wigner.stride(0), - grad_wigner.stride(1), - grad_wigner.stride(2), - LMAX=lmax, - FOCUS_DIM=focus_dim, - BLOCK_C=_tile_dim(channels), - ) - return grad_x_local, grad_wigner - - -# ====================================================================== -# Dispatch helpers (triton on CUDA float, eager otherwise) -# ====================================================================== -def _use_triton(tensor: Tensor) -> bool: - return ( - TRITON_ROTATION_AVAILABLE - and tensor.is_cuda - and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32) - ) - - -def _rotate_to_local_impl( - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - if not _use_triton(x): - return rotate_to_local_reference(x, src, wigner, coeff_index, dim_full) - return _launch_rotate_to_local_fwd( - x, src.contiguous(), wigner, coeff_index.contiguous(), int(dim_full) - ) - - -def _rotate_to_local_bwd_impl( - grad_out: Tensor, - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> tuple[Tensor, Tensor]: - if not _use_triton(x): - return _rotate_to_local_bwd_eager( - grad_out, x, src, wigner, coeff_index, dim_full - ) - return _launch_rotate_to_local_bwd( - grad_out.contiguous(), - x, - src.contiguous(), - wigner, - coeff_index.contiguous(), - int(dim_full), - ) - - -def _rotate_back_impl( - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - if not _use_triton(x_local): - return rotate_back_reference(x_local, wigner, coeff_index, dim_full) - return _launch_rotate_back_fwd( - x_local, wigner, coeff_index.contiguous(), int(dim_full) - ) - - -def _rotate_back_bwd_impl( - grad_out: Tensor, - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> tuple[Tensor, Tensor]: - if not _use_triton(x_local): - return _rotate_back_bwd_eager(grad_out, x_local, wigner, coeff_index, dim_full) - return _launch_rotate_back_bwd( - grad_out.contiguous(), - x_local, - wigner, - coeff_index.contiguous(), - int(dim_full), - ) - - -# --- block-diagonal impls (mmax == 1; assume block-diagonal Wigner-D) --- -def _block_rotate_to_local_impl( - x: Tensor, src: Tensor, wigner: Tensor, lmax: int -) -> Tensor: - if not _use_triton(x): - coeff = build_m_major_index(int(lmax), 1, device=x.device) - return rotate_to_local_reference(x, src, wigner, coeff, (int(lmax) + 1) ** 2) - return _launch_bd_to_local_fwd(x, src.contiguous(), wigner, int(lmax)) - - -def _block_rotate_to_local_bwd_impl( - grad_out: Tensor, x: Tensor, src: Tensor, wigner: Tensor, lmax: int -) -> tuple[Tensor, Tensor]: - if not _use_triton(x): - coeff = build_m_major_index(int(lmax), 1, device=x.device) - return _rotate_to_local_bwd_eager( - grad_out, x, src, wigner, coeff, (int(lmax) + 1) ** 2 - ) - return _launch_bd_to_local_bwd( - grad_out.contiguous(), x, src.contiguous(), wigner, int(lmax) - ) - - -def _block_rotate_back_impl(x_local: Tensor, wigner: Tensor, lmax: int) -> Tensor: - if not _use_triton(x_local): - coeff = build_m_major_index(int(lmax), 1, device=x_local.device) - return rotate_back_reference(x_local, wigner, coeff, (int(lmax) + 1) ** 2) - return _launch_bd_back_fwd(x_local, wigner, int(lmax)) - - -def _block_rotate_back_bwd_impl( - grad_out: Tensor, x_local: Tensor, wigner: Tensor, lmax: int -) -> tuple[Tensor, Tensor]: - if not _use_triton(x_local): - coeff = build_m_major_index(int(lmax), 1, device=x_local.device) - return _rotate_back_bwd_eager( - grad_out, x_local, wigner, coeff, (int(lmax) + 1) ** 2 - ) - return _launch_bd_back_bwd(grad_out.contiguous(), x_local, wigner, int(lmax)) - - -# ====================================================================== -# Functional triton_op + fake + autograd registration -# ====================================================================== -# Forward and backward are both *functional* triton_ops (mutates_args=()), so -# functionalization keeps the full gradient path -- including grad w.r.t. -# ``wigner`` -- intact under ``torch.compile``. ``triton_op`` (vs ``custom_op``) -# additionally lets Inductor see through to the wrapped Triton kernel and bake -# the cubin into the AOTInductor ``.pt2`` so the LAMMPS C++ runtime needs no -# Python registration. - -_rotate_to_local_op = torch.library.triton_op( - "dpa4_triton::rotate_to_local", mutates_args=() -)(_rotate_to_local_impl) - -_rotate_to_local_bwd_op = torch.library.triton_op( - "dpa4_triton::rotate_to_local_bwd", mutates_args=() -)(_rotate_to_local_bwd_impl) - -_rotate_back_op = torch.library.triton_op("dpa4_triton::rotate_back", mutates_args=())( - _rotate_back_impl -) - -_rotate_back_bwd_op = torch.library.triton_op( - "dpa4_triton::rotate_back_bwd", mutates_args=() -)(_rotate_back_bwd_impl) - - -@_rotate_to_local_op.register_fake -def _(x, src, wigner, coeff_index, dim_full): - return x.new_empty((src.shape[0], coeff_index.shape[0], x.shape[2])) - - -@_rotate_to_local_bwd_op.register_fake -def _(grad_out, x, src, wigner, coeff_index, dim_full): - return torch.empty_like(x), torch.empty_like(wigner) - - -@_rotate_back_op.register_fake -def _(x_local, wigner, coeff_index, dim_full): - return x_local.new_empty((x_local.shape[0], dim_full, x_local.shape[2])) - - -@_rotate_back_bwd_op.register_fake -def _(grad_out, x_local, wigner, coeff_index, dim_full): - return torch.empty_like(x_local), torch.empty_like(wigner) - - -def _rotate_to_local_setup_context(ctx, inputs, output): - x, src, wigner, coeff_index, dim_full = inputs - ctx.save_for_backward(x, src, wigner, coeff_index) - ctx.dim_full = dim_full - - -def _rotate_to_local_backward(ctx, grad_out): - x, src, wigner, coeff_index = ctx.saved_tensors - grad_x, grad_wigner = _rotate_to_local_bwd_op( - grad_out, x, src, wigner, coeff_index, ctx.dim_full - ) - return grad_x, None, grad_wigner, None, None - - -def _rotate_back_setup_context(ctx, inputs, output): - x_local, wigner, coeff_index, dim_full = inputs - ctx.save_for_backward(x_local, wigner, coeff_index) - ctx.dim_full = dim_full - - -def _rotate_back_backward(ctx, grad_out): - x_local, wigner, coeff_index = ctx.saved_tensors - grad_x_local, grad_wigner = _rotate_back_bwd_op( - grad_out, x_local, wigner, coeff_index, ctx.dim_full - ) - return grad_x_local, grad_wigner, None, None - - -_rotate_to_local_op.register_autograd( - _rotate_to_local_backward, setup_context=_rotate_to_local_setup_context -) -_rotate_back_op.register_autograd( - _rotate_back_backward, setup_context=_rotate_back_setup_context -) - - -# --- block-diagonal custom ops (carry only ``lmax``; no coeff_index tensor) --- -_block_to_local_op = torch.library.triton_op( - "dpa4_triton::rotate_to_local_block", mutates_args=() -)(_block_rotate_to_local_impl) - -_block_to_local_bwd_op = torch.library.triton_op( - "dpa4_triton::rotate_to_local_block_bwd", mutates_args=() -)(_block_rotate_to_local_bwd_impl) - -_block_back_op = torch.library.triton_op( - "dpa4_triton::rotate_back_block", mutates_args=() -)(_block_rotate_back_impl) - -_block_back_bwd_op = torch.library.triton_op( - "dpa4_triton::rotate_back_block_bwd", mutates_args=() -)(_block_rotate_back_bwd_impl) - - -@_block_to_local_op.register_fake -def _(x, src, wigner, lmax): - return x.new_empty((src.shape[0], 3 * int(lmax) + 1, x.shape[2])) - - -@_block_to_local_bwd_op.register_fake -def _(grad_out, x, src, wigner, lmax): - return torch.empty_like(x), torch.empty_like(wigner) - - -@_block_back_op.register_fake -def _(x_local, wigner, lmax): - return x_local.new_empty((x_local.shape[0], (int(lmax) + 1) ** 2, x_local.shape[2])) - - -@_block_back_bwd_op.register_fake -def _(grad_out, x_local, wigner, lmax): - return torch.empty_like(x_local), torch.empty_like(wigner) - - -def _block_to_local_setup_context(ctx, inputs, output): - x, src, wigner, lmax = inputs - ctx.save_for_backward(x, src, wigner) - ctx.lmax = lmax - - -def _block_to_local_backward(ctx, grad_out): - x, src, wigner = ctx.saved_tensors - grad_x, grad_wigner = _block_to_local_bwd_op(grad_out, x, src, wigner, ctx.lmax) - return grad_x, None, grad_wigner, None - - -def _block_back_setup_context(ctx, inputs, output): - x_local, wigner, lmax = inputs - ctx.save_for_backward(x_local, wigner) - ctx.lmax = lmax - - -def _block_back_backward(ctx, grad_out): - x_local, wigner = ctx.saved_tensors - grad_x_local, grad_wigner = _block_back_bwd_op(grad_out, x_local, wigner, ctx.lmax) - return grad_x_local, grad_wigner, None - - -_block_to_local_op.register_autograd( - _block_to_local_backward, setup_context=_block_to_local_setup_context -) -_block_back_op.register_autograd( - _block_back_backward, setup_context=_block_back_setup_context -) - - -# ====================================================================== -# Public API -# ====================================================================== -# --- Public entry points ----------------------------------------------------- -def rotate_to_local_dense( - x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int -) -> Tensor: - """Apply the general ``global -> local`` rotation. - - This entry point honors every value in ``coeff_index`` and supports any - reduced coefficient layout. It computes the same operation as - ``rotate_to_local_reference`` while avoiding materialized gather operands on - CUDA. - """ - return _rotate_to_local_op(x, src, wigner, coeff_index, int(dim_full)) - - -def rotate_back_dense( - x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int -) -> Tensor: - """Apply the general ``local -> global`` rotation. - - This entry point honors every value in ``coeff_index`` and supports any - reduced coefficient layout. It computes the same operation as - ``rotate_back_reference`` while avoiding materialized gather operands on - CUDA. - """ - return _rotate_back_op(x_local, wigner, coeff_index, int(dim_full)) - - -def rotate_to_local_block(x: Tensor, src: Tensor, wigner: Tensor, lmax: int) -> Tensor: - """Apply the block-diagonal ``global -> local`` rotation. - - Use this when the caller owns the invariant that the reduced layout is the - canonical m-major ``mmax=1`` layout for ``lmax``. The block kernel derives - the reduced row order from ``lmax`` and does not consume a coefficient-index - tensor. - """ - return _block_to_local_op(x, src, wigner, int(lmax)) - - -def rotate_back_block(x_local: Tensor, wigner: Tensor, lmax: int) -> Tensor: - """Apply the block-diagonal ``local -> global`` rotation. - - Use this when the caller owns the invariant that ``x_local`` is ordered in - the canonical m-major ``mmax=1`` layout for ``lmax``. The block kernel - derives the reduced column order from ``lmax`` and does not consume a - coefficient-index tensor. - """ - return _block_back_op(x_local, wigner, int(lmax)) - - -# ====================================================================== -# Layout-aware block rotate_back (per-focus SO(2) layout, mmax == 1) -# ====================================================================== -# Consumes the (E, F, D_m, Cf) focus layout produced by the SO(2) layers so the -# caller can skip the ``transpose(1, 2).contiguous()`` that would otherwise -# materialize (E, D_m, F * Cf) before the inverse rotation. - - -def _block_rotate_back_so2_impl( - x_local_4d: Tensor, wigner: Tensor, lmax: int -) -> Tensor: - if not _use_triton(x_local_4d): - n_edge, n_focus, reduced_dim, focus_dim = x_local_4d.shape - x_std = x_local_4d.transpose(1, 2).reshape( - n_edge, reduced_dim, n_focus * focus_dim - ) - coeff = build_m_major_index(int(lmax), 1, device=x_local_4d.device) - return rotate_back_reference(x_std, wigner, coeff, (int(lmax) + 1) ** 2) - return _launch_bd_back_so2_fwd(x_local_4d, wigner, int(lmax)) - - -def _block_rotate_back_so2_bwd_impl( - grad_out: Tensor, x_local_4d: Tensor, wigner: Tensor, lmax: int -) -> tuple[Tensor, Tensor]: - if not _use_triton(x_local_4d): - n_edge, n_focus, reduced_dim, focus_dim = x_local_4d.shape - x_std = x_local_4d.transpose(1, 2).reshape( - n_edge, reduced_dim, n_focus * focus_dim - ) - coeff = build_m_major_index(int(lmax), 1, device=x_local_4d.device) - grad_x_std, grad_wigner = _rotate_back_bwd_eager( - grad_out, x_std, wigner, coeff, (int(lmax) + 1) ** 2 - ) - grad_x_local = grad_x_std.reshape( - n_edge, reduced_dim, n_focus, focus_dim - ).transpose(1, 2) - return grad_x_local, grad_wigner - return _launch_bd_back_so2_bwd(grad_out.contiguous(), x_local_4d, wigner, int(lmax)) - - -_block_back_so2_op = torch.library.triton_op( - "dpa4_triton::rotate_back_block_so2", mutates_args=() -)(_block_rotate_back_so2_impl) - -_block_back_so2_bwd_op = torch.library.triton_op( - "dpa4_triton::rotate_back_block_so2_bwd", mutates_args=() -)(_block_rotate_back_so2_bwd_impl) - - -@_block_back_so2_op.register_fake -def _(x_local_4d, wigner, lmax): - n_edge, n_focus, _reduced, focus_dim = x_local_4d.shape - return x_local_4d.new_empty((n_edge, (int(lmax) + 1) ** 2, n_focus * focus_dim)) - - -@_block_back_so2_bwd_op.register_fake -def _(grad_out, x_local_4d, wigner, lmax): - return torch.empty_like(x_local_4d), torch.empty_like(wigner) - - -def _block_back_so2_setup_context(ctx, inputs, output): - x_local_4d, wigner, lmax = inputs - ctx.save_for_backward(x_local_4d, wigner) - ctx.lmax = lmax - - -def _block_back_so2_backward(ctx, grad_out): - x_local_4d, wigner = ctx.saved_tensors - grad_x_local, grad_wigner = _block_back_so2_bwd_op( - grad_out, x_local_4d, wigner, ctx.lmax - ) - return grad_x_local, grad_wigner, None - - -_block_back_so2_op.register_autograd( - _block_back_so2_backward, setup_context=_block_back_so2_setup_context -) - - -def rotate_back_block_so2(x_local_4d: Tensor, wigner: Tensor, lmax: int) -> Tensor: - """Block-diagonal ``local -> global`` rotation reading the per-focus layout. - - Parameters - ---------- - x_local_4d : Tensor - Local features with shape (E, F, reduced_dim, Cf) in the canonical m-major - ``mmax=1`` layout, where C_wide = F * Cf. - wigner : Tensor - Transposed Wigner-D with shape (E, D, D), D = (lmax + 1) ** 2. - lmax : int - Maximum degree. - - Returns - ------- - Tensor - Global-frame message with shape (E, D, C_wide). The per-focus to packed - channel mapping ``c = f * Cf + cf`` folds the inverse transpose into the - kernel addressing, avoiding an explicit copy. - """ - return _block_back_so2_op(x_local_4d, wigner, int(lmax)) diff --git a/deepmd/pt_expt/descriptor/dpa4_nn/wignerd.py b/deepmd/pt_expt/descriptor/dpa4_nn/wignerd.py new file mode 100644 index 0000000000..91c52c5159 --- /dev/null +++ b/deepmd/pt_expt/descriptor/dpa4_nn/wignerd.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""pt_expt Wigner-D calculator with an opt-in fused Triton monomial fast path. + +The dpmodel :class:`WignerDCalculator` is array-API only and evaluates the +degree ``l >= 2`` monomial design matrices through the dense power-table chain. +This wrapper injects the reference pt inference fast path around the two +monomial hot paths -- the shared ``l >= 3`` kernel and the ``l = 2`` degree-4 +contraction -- mirroring ``deepmd.pt.model.descriptor.sezm_nn.wignerd``. + +The fused monomial operator is sourced from the central +:mod:`deepmd.kernels.triton.sezm.wigner_monomials` package and gated by the +integer inference level ``DP_TRITON_INFER`` (see +:func:`deepmd.kernels.utils.triton_infer_level`); the fast path requires level +``>= 1``. It runs only during inference (``not self.training``) on CUDA, and +the operator self-guards Triton availability and falls back to an eager +reference off CUDA / on fp64, so importing this module is safe on CPU-only +environments; training and CPU / fp64 inference use the dpmodel dense path. +""" + +from __future__ import ( + annotations, +) + +from itertools import ( + product, +) +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.dpmodel import ( + DEFAULT_PRECISION, +) +from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator as WignerDCalculatorDP, +) +from deepmd.kernels.utils import ( + triton_infer_level, +) +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, + torch_module, +) + + +@torch_module +class WignerDCalculator(WignerDCalculatorDP): + """Wigner-D calculator with an opt-in fused Triton monomial inference path.""" + + def __init__( + self, + lmax: int, + *, + eps: float = 1e-7, + precision: str = DEFAULT_PRECISION, + ) -> None: + super().__init__(lmax, eps=eps, precision=precision) + # Inference fast-path gate (``DP_TRITON_INFER >= 1``): read once at + # construction so it is a compile-time constant in the traced + # (``make_fx``) graph, and it only takes effect during inference. + self._use_triton_monomials = triton_infer_level() >= 1 + if self.lmax >= 2: + # Flatten the monomial exponent tables to Python constants in + # eager context: the fused monomial operator bakes them into the + # kernel at compile time, and a trace-time ``.tolist()`` would + # create unbacked symbols under ``make_fx`` and abort export. + self._monomial_exponents_flat: dict[str, list[int]] = {} + for exp_name in ("exp_l3", "exp_l4", "exp_l5", "exp_l6"): + exps = getattr(self.small_order_kernels, exp_name, None) + if exps is not None: + self._monomial_exponents_flat[exp_name] = [ + int(v) for v in exps.reshape(-1).tolist() + ] + # The l = 2 contraction tensor collapsed onto the 35 unique + # degree-4 monomials: column m of the coefficient matrix sums + # C_l2[:, :, p] over the 4^4 index tuples p whose component + # multiplicities equal the monomial exponents. + exp_l2: list[int] = [] + columns: list[np.ndarray] = [] + index_of: dict[tuple[int, int, int, int], int] = {} + c_l2 = self.small_order_kernels.C_l2 + for p in product(range(4), repeat=4): + counts = (p.count(0), p.count(1), p.count(2), p.count(3)) + if counts not in index_of: + index_of[counts] = len(index_of) + exp_l2.extend(counts) + columns.append(np.zeros_like(c_l2[:, :, 0, 0, 0, 0])) + columns[index_of[counts]] = ( + columns[index_of[counts]] + c_l2[:, :, p[0], p[1], p[2], p[3]] + ) + self._monomial_exponents_flat["exp_l2"] = exp_l2 + # Assigned as a numpy array so ``dpmodel_setattr`` registers it as a + # torch buffer (fp64, matching the other dpmodel Wigner constants). + self._l2_monomial_coeff = np.stack([c.reshape(-1) for c in columns], axis=0) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) + + def _monomial_matrix( + self, + edge_quaternion: torch.Tensor, + exp_name: str, + max_power: int, + ) -> torch.Tensor: + """Evaluate one degree kernel's monomial basis, with the fused fast path. + + On the CUDA inference path the fused operator evaluates the monomials + in registers with the exponent table baked in at compile time (see + :mod:`deepmd.kernels.triton.sezm.wigner_monomials`); construction-time + solves and CPU targets keep the dense power-table chain. + """ + exps = self._monomial_exponents_flat.get(exp_name) + if ( + self._use_triton_monomials + and exps is not None + and edge_quaternion.is_cuda + and not self.training + ): + from deepmd.kernels.triton.sezm.wigner_monomials import ( + wigner_monomials, + ) + + return wigner_monomials(edge_quaternion, exps, max_power) + return super()._monomial_matrix(edge_quaternion, exp_name, max_power) + + def _compute_l2_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: + """Compute the ``l=2`` block from the degree-4 quaternion contraction. + + The fused inference path collapses the 256 rank-4 index tuples onto + the 35 unique degree-4 monomials, replacing the ``(E, 4, 4, 4, 4)`` + outer product with a monomial evaluation and one ``(E, 35) x (35, 25)`` + product with no large intermediate. + """ + exps = self._monomial_exponents_flat.get("exp_l2") + if ( + self._use_triton_monomials + and exps is not None + and edge_quaternion.is_cuda + and not self.training + ): + from deepmd.kernels.triton.sezm.wigner_monomials import ( + wigner_monomials, + ) + + monomials = wigner_monomials(edge_quaternion, exps, 4) + # ``_l2_monomial_coeff`` is stored as the fp64 dpmodel constant; cast + # it to the monomial dtype so the fused fp32 path multiplies operands + # of one dtype (mirrors the base's runtime cast of the Wigner + # constants to the edge dtype). + coeff = self._l2_monomial_coeff.to(monomials.dtype) + return torch.matmul(monomials, coeff).view(-1, 5, 5) + return super()._compute_l2_block(edge_quaternion) + + +# WignerDCalculator.deserialize raises NotImplementedError by design (its +# tables are derived constants); rebuild from the stored constructor args. +register_dpmodel_mapping( + WignerDCalculatorDP, + lambda v: WignerDCalculator(v.lmax, eps=v.eps, precision=v.precision), +) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 89a724fc2a..58b62aaf56 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -1620,17 +1620,40 @@ def _eval_model_spin( charge_spin_t = self._make_charge_spin_input(nframes, charge_spin) - # Call the model with spin. - model_inputs = ( - ext_coord_t, - ext_atype_t, - ext_spin_t, - nlist_t, - mapping_t, - fparam_t, - aparam_t, - charge_spin_t, - ) + # Build the lower inputs for the model's spin ABI. The native scheme + # shares the energy edge contract and feeds the owned-atom spins (the + # descriptor only needs local spins; ghost neighbours resolve to their + # local owners). The deepspin scheme keeps the extended nlist contract. + if self.metadata.get("lower_input_kind") == "edge_vec": + edge_schema = edge_schema_from_extended( + ext_coord_t, + ext_atype_t[:, :natoms], + nlist_t, + mapping_t, + ) + model_inputs = ( + edge_schema.coord, + edge_schema.atype, + edge_schema.edge_index, + edge_schema.edge_vec, + edge_schema.edge_scatter_index, + edge_schema.edge_mask, + spin_t, + fparam_t, + aparam_t, + charge_spin_t, + ) + else: + model_inputs = ( + ext_coord_t, + ext_atype_t, + ext_spin_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + charge_spin_t, + ) if self._is_pt2: model_ret = self._pt2_runner(*model_inputs) else: diff --git a/deepmd/pt_expt/utils/edge_schema.py b/deepmd/pt_expt/utils/edge_schema.py index 5532bc9b6e..80871fdee9 100644 --- a/deepmd/pt_expt/utils/edge_schema.py +++ b/deepmd/pt_expt/utils/edge_schema.py @@ -141,7 +141,13 @@ def edge_schema_from_extended( edge_scatter_index, ) schema.coord = coord[:, :nloc, :].contiguous() if scatter_to_local else coord - schema.atype = atype[:, :nloc] + # The local-atom slice is a stride-(nall, 1) view when nloc < nall (always so + # with ghost atoms, and for the spin path where the source carries 2*nall + # columns). The compiled core flattens ``atype`` via ``reshape(-1)``, which + # ``torch.compile`` lowers to ``aten.view`` and rejects on a non-contiguous + # layout under symbolic shapes. Materialize a contiguous copy here, mirroring + # ``coord`` above. + schema.atype = atype[:, :nloc].contiguous() return schema diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 84b04e02ba..8afd7c52a6 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -29,6 +29,12 @@ from deepmd.utils.argcheck_nvnmd import ( nvnmd_args, ) +from deepmd.utils.eval_metrics import ( + ENERGY_FULL_VALIDATION_PROFILE, + FULL_VALIDATION_PROFILES, + SPIN_FULL_VALIDATION_PROFILE, + FullValidationMetricProfile, +) from deepmd.utils.plugin import ( Plugin, ) @@ -145,8 +151,9 @@ def spin_args() -> list[Argument]: doc_use_spin = ( "Whether to use atomic spin model for each atom type. " "List of boolean values with the shape of [ntypes] to specify which types use spin, " - f"or a list of integer values {doc_only_pt_supported} " - "to indicate the index of the type that uses spin." + f"or, {doc_only_pt_supported}, a list of the magnetic types given either as type " + 'indices or as element symbols (e.g. `["Fe"]`), which is expanded against ' + "`type_map` so that a large type map only needs its magnetic species named." ) doc_spin_norm = "The magnitude of atomic spin for each atom type with spin" doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin" @@ -156,11 +163,25 @@ def spin_args() -> list[Argument]: "This factor is defined as the virtual distance divided by the magnitude of atomic spin " "for each atom type with spin. The virtual coordinate is defined as the real coordinate " "plus spin * virtual_scale. List of float values with shape of [ntypes] or [ntypes_spin] " - "or one single float value for all types, only used when use_spin is True for each atom type." + "or one single float value for all types, only used when use_spin is True for each atom type. " + "Required for the `deepspin` scheme; ignored by the `native` scheme." + ) + doc_scheme = ( + "The spin implementation scheme, only effective for the DPA4/SeZM model. " + "`native` injects the per-atom spin vector as an equivariant feature " + "(l=0 magnitude and l=1 direction) directly into the descriptor and " + "derives the magnetic force as the negative spin gradient of the energy, " + "without virtual atoms. `deepspin` uses the classical DeepSpin virtual-atom " + "representation and is the default. Other models always use the `deepspin` scheme." + ) + doc_allow_missing_label = ( + "Whether to admit training systems that lack a `spin` data file, filling their " + "per-atom spin with zeros instead of raising. Supported only by the SeZM/DPA4 " + "spin model; defaults to false." ) return [ - Argument("use_spin", [list[bool], list[int]], doc=doc_use_spin), + Argument("use_spin", [list[bool], list[int], list[str]], doc=doc_use_spin), Argument( "spin_norm", list[float], @@ -179,6 +200,20 @@ def spin_args() -> list[Argument]: optional=True, doc=doc_only_pt_supported + doc_virtual_scale, ), + Argument( + "scheme", + str, + optional=True, + default="deepspin", + doc=doc_only_pt_supported + doc_scheme, + ), + Argument( + "allow_missing_label", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_allow_missing_label, + ), ] @@ -431,7 +466,15 @@ def descrpt_se_zm_args() -> list[Argument]: "block `i` uses node degree `l_schedule[i] + extra_node_l`, while SO(2) " "message passing still uses `l_schedule[i]`." ) - doc_n_blocks = "Number of blocks (only used when `l_schedule` is None)." + doc_n_blocks = ( + "Number of interaction blocks (only used when `l_schedule` is None). " + "`0` disables the interaction blocks and builds the zero-block " + "descriptor: type embedding, optional env FiLM and geometric initial " + "embedding, then the final SO(3) read-out. The backbone degree is taken " + "from `lmax` (plus `extra_node_l`); geometry then enters only through " + "the geometric initial embedding, so `use_env_seed=True` with " + "`lmax + extra_node_l > 0` is required for a non-trivial descriptor." + ) doc_block_attn_res = ( "Descriptor-level block attention residual mode over block history " "`[x0, b1, b2, ...]`, where each block summary is the sum of the SO(2) " @@ -589,6 +632,12 @@ def descrpt_se_zm_args() -> list[Argument]: "read-out degree equals the node degree of the last interaction block; " "the Wigner-D frame order follows `kmax`." ) + doc_readout_layers = ( + "Number of stacked equivariant residual read-out FFNs (default 1). Each " + "layer is an `x + FFN(x)` residual block sharing the read-out degree; " + "intermediate layers keep the full SO(3) tensor so high-degree geometry " + "keeps folding into l=0, and only the final layer slices the l=0 channel." + ) doc_lebedev_quadrature = ( "Either one boolean applied to both S2 branches, or two booleans " "`[so2_enabled, ffn_enabled]` aligned with `s2_activation`. If a branch " @@ -947,6 +996,15 @@ def descrpt_se_zm_args() -> list[Argument]: extra_check_errmsg="must be one of 'none', 'glu', or 'mlp'", doc=doc_only_pt_supported + doc_so3_readout, ), + Argument( + "readout_layers", + int, + optional=True, + default=1, + extra_check=lambda x: x >= 1, + extra_check_errmsg="must be >= 1", + doc=doc_only_pt_supported + doc_readout_layers, + ), Argument( "lebedev_quadrature", [bool, list[bool]], @@ -5502,14 +5560,11 @@ def training_extra_check(data: dict | None) -> bool: ) -FULL_VALIDATION_METRIC_PREFS = { - "e:mae": ("start_pref_e", "limit_pref_e"), - "e:rmse": ("start_pref_e", "limit_pref_e"), - "f:mae": ("start_pref_f", "limit_pref_f"), - "f:rmse": ("start_pref_f", "limit_pref_f"), - "v:mae": ("start_pref_v", "limit_pref_v"), - "v:rmse": ("start_pref_v", "limit_pref_v"), -} +def _full_validation_profile_for_loss(loss_type: str) -> FullValidationMetricProfile: + """Return the full validation metric profile for a loss type.""" + if loss_type == "ener_spin": + return SPIN_FULL_VALIDATION_PROFILE + return ENERGY_FULL_VALIDATION_PROFILE def normalize_full_validation_metric(metric: str) -> str: @@ -5518,19 +5573,26 @@ def normalize_full_validation_metric(metric: str) -> str: def is_valid_full_validation_metric(metric: str) -> bool: - """Check whether a full validation metric is supported.""" - return normalize_full_validation_metric(metric) in FULL_VALIDATION_METRIC_PREFS + """Check whether a metric is supported by any full validation profile.""" + normalized_metric = normalize_full_validation_metric(metric) + return any( + normalized_metric in profile.metric_key_map + for profile in FULL_VALIDATION_PROFILES.values() + ) -def get_full_validation_metric_prefactors(metric: str) -> tuple[str, str]: - """Get the prefactor keys required by a full validation metric.""" +def get_full_validation_metric_prefactors( + metric: str, profile: FullValidationMetricProfile +) -> tuple[str, str]: + """Get the loss prefactor keys required by a full validation metric.""" normalized_metric = normalize_full_validation_metric(metric) - if normalized_metric not in FULL_VALIDATION_METRIC_PREFS: - valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + if normalized_metric not in profile.prefactor_by_metric: + valid_metrics = ", ".join(item.upper() for item in profile.prefactor_by_metric) raise ValueError( - f"validating.validation_metric must be one of {valid_metrics}, got {metric!r}." + "validating.validation_metric must be one of " + f"{valid_metrics} for {profile.name} training, got {metric!r}." ) - return FULL_VALIDATION_METRIC_PREFS[normalized_metric] + return profile.prefactor_by_metric[normalized_metric] def resolve_full_validation_start_step( @@ -5547,7 +5609,21 @@ def resolve_full_validation_start_step( def validating_args() -> Argument: """Generate full validation arguments.""" - valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + energy_metrics = ", ".join( + item.upper() for item in ENERGY_FULL_VALIDATION_PROFILE.metric_key_map + ) + spin_metrics = ", ".join( + item.upper() for item in SPIN_FULL_VALIDATION_PROFILE.metric_key_map + ) + valid_metrics = ", ".join( + sorted( + { + metric.upper() + for profile in FULL_VALIDATION_PROFILES.values() + for metric in profile.metric_key_map + } + ) + ) doc_full_validation_supported = ( "(Supported Backend: PyTorch, PyTorch Experimental, JAX) " ) @@ -5555,7 +5631,7 @@ def validating_args() -> Argument: "Whether to run an additional full validation pass over the entire " "validation dataset during training. This flow is independent from the " "display-time validation controlled by `training.disp_freq`. Only " - "single-task energy training is supported. Multi-task, spin-energy, " + "single-task energy or spin-energy training is supported. Multi-task " "and `training.zero_stage >= 2` are not supported." ) doc_validation_freq = ( @@ -5571,12 +5647,12 @@ def validating_args() -> Argument: "`training.save_ckpt`." ) doc_ema_full_validation = ( - "Whether to additionally run the same full validation flow on the " - "EMA-smoothed model when `validating.full_validation=true`. This reuses " - "the existing full validation schedule, metric, start step, and " - "best-checkpoint settings, writes results to an EMA-specific validation " - "log such as `val_ema.log`, and saves EMA best checkpoints with a " - "`best_ema.ckpt` prefix. Requires " + "Whether to run the full validation flow on the EMA-smoothed model. " + "This is independent from `validating.full_validation` and may be " + "enabled on its own to validate only the EMA model. It reuses the full " + "validation schedule, metric, and start step, writes results to an " + "EMA-specific validation log such as `val_ema.log`, and saves EMA best " + "checkpoints with a `best_ema.ckpt` prefix. Requires " "`training.enable_ema=true`." ) doc_max_best_ckpt = ( @@ -5586,9 +5662,12 @@ def validating_args() -> Argument: ) doc_validation_metric = ( "Metric used to determine the best checkpoint during full validation. " - f"Supported values are {valid_metrics}. The string is case-insensitive. " - "`E` and `V` are per-atom metrics; `F` uses component-wise force errors, " - "matching `dp test`. The corresponding loss prefactors must not both be 0." + "The string is case-insensitive. For energy training the supported " + f"values are {energy_metrics}; for spin-energy training they are " + f"{spin_metrics}. `E` and `V` are per-atom metrics, `F` and `FR` use " + "component-wise force errors, and `FM` uses magnetic-force errors, " + "matching `dp test`. The corresponding loss prefactors must not both " + "be 0." ) doc_full_val_file = "The file for writing full validation results only. This file is independent from `training.disp_file`." doc_full_val_start = ( @@ -5616,6 +5695,15 @@ def validating_args() -> Argument: "precedence over this option. This does not affect training forwards, " "which are controlled by `model.enable_tf32`." ) + doc_amp_infer = ( + "Whether to enable bf16 automatic mixed precision for eval-time forwards " + "(including regular validation and full validation). When `true`, this " + "flag is translated into `DP_AMP_INFER=1` at trainer startup before any " + "model is constructed. A manually exported `DP_AMP_INFER` takes " + "precedence over this option. This only affects SeZM/DPA4 descriptors " + "with `descriptor.use_amp=true`; training AMP remains controlled by " + "`descriptor.use_amp`." + ) args = [ Argument( "full_validation", @@ -5670,10 +5758,7 @@ def validating_args() -> Argument: default="E:MAE", doc=doc_full_validation_supported + doc_validation_metric, extra_check=is_valid_full_validation_metric, - extra_check_errmsg=( - "must be one of " - + ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) - ), + extra_check_errmsg="must be one of " + valid_metrics, ), Argument( "full_val_file", @@ -5705,6 +5790,13 @@ def validating_args() -> Argument: default=False, doc=doc_only_pt_supported + doc_tf32_infer, ), + Argument( + "amp_infer", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_amp_infer, + ), ] return Argument( "validating", @@ -5714,7 +5806,8 @@ def validating_args() -> Argument: optional=True, default={}, doc=doc_full_validation_supported - + "Independent full validation options for single-task energy training.", + + "Independent full validation options for single-task energy or " + + "spin-energy training.", ) @@ -5726,7 +5819,7 @@ def validate_full_validation_config( training_params = data.get("training", {}) or {} full_validation_enabled = bool(validating.get("full_validation", False)) ema_full_validation_enabled = bool(validating.get("ema_full_validation", False)) - if not full_validation_enabled: + if not full_validation_enabled and not ema_full_validation_enabled: return if ema_full_validation_enabled and not training_params.get("enable_ema", False): raise ValueError( @@ -5735,13 +5828,6 @@ def validate_full_validation_config( if float(validating.get("full_val_start", 0.0)) == 1.0: return - metric = str(validating.get("validation_metric", "E:MAE")) - if not is_valid_full_validation_metric(metric): - valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) - raise ValueError( - f"validating.validation_metric must be one of {valid_metrics}, got {metric!r}." - ) - if multi_task: raise ValueError( "validating.full_validation only supports single-task energy training; multi-task training is not supported." @@ -5749,20 +5835,24 @@ def validate_full_validation_config( loss_params = data.get("loss", {}) loss_type = loss_params.get("type", "ener") - if loss_type == "ener_spin": + if loss_type not in ("ener", "ener_spin"): raise ValueError( - "validating.full_validation only supports single-task energy " - "training; spin-energy training is not supported." + "validating.full_validation only supports single-task energy or " + f"spin-energy training; got loss.type={loss_type!r}." ) - if loss_type != "ener": + profile = _full_validation_profile_for_loss(loss_type) + + metric = str(validating.get("validation_metric", "E:MAE")) + if normalize_full_validation_metric(metric) not in profile.metric_key_map: + valid_metrics = ", ".join(item.upper() for item in profile.metric_key_map) raise ValueError( - "validating.full_validation only supports single-task energy " - f"training with loss.type='ener'; got loss.type={loss_type!r}." + "validating.validation_metric must be one of " + f"{valid_metrics} for {profile.name} training, got {metric!r}." ) if not training_params.get("validation_data"): raise ValueError( - "full validation requires `training.validation_data`. It is only supported for single-task energy training." + "full validation requires `training.validation_data`. It is only supported for single-task energy or spin-energy training." ) zero_stage = int(training_params.get("zero_stage", 0)) @@ -5772,7 +5862,9 @@ def validate_full_validation_config( "training with training.zero_stage < 2." ) - pref_start_key, pref_limit_key = get_full_validation_metric_prefactors(metric) + pref_start_key, pref_limit_key = get_full_validation_metric_prefactors( + metric, profile + ) pref_start = float(loss_params.get(pref_start_key, 0.0)) pref_limit = float(loss_params.get(pref_limit_key, 0.0)) if pref_start == 0.0 or pref_limit == 0.0: diff --git a/deepmd/utils/eval_metrics.py b/deepmd/utils/eval_metrics.py index ed210c9b78..d75c0b7385 100644 --- a/deepmd/utils/eval_metrics.py +++ b/deepmd/utils/eval_metrics.py @@ -7,30 +7,22 @@ from dataclasses import ( dataclass, ) +from typing import ( + TYPE_CHECKING, +) import numpy as np -FULL_VALIDATION_METRIC_KEY_MAP = { - "e:mae": "mae_e_per_atom", - "e:rmse": "rmse_e_per_atom", - "f:mae": "mae_f", - "f:rmse": "rmse_f", - "v:mae": "mae_v_per_atom", - "v:rmse": "rmse_v_per_atom", -} +if TYPE_CHECKING: + from collections.abc import ( + Callable, + ) + FULL_VALIDATION_WEIGHTED_METRIC_KEYS = { "energy_per_atom": ("mae_e_per_atom", "rmse_e_per_atom"), "force": ("mae_f", "rmse_f"), "virial_per_atom": ("mae_v_per_atom", "rmse_v_per_atom"), } -FULL_VALIDATION_METRIC_FAMILY_BY_KEY = { - "mae_e_per_atom": "e", - "rmse_e_per_atom": "e", - "mae_f": "f", - "rmse_f": "f", - "mae_v_per_atom": "v", - "rmse_v_per_atom": "v", -} DP_TEST_WEIGHTED_METRIC_KEYS = { "energy": ("mae_e", "rmse_e"), "energy_per_atom": ("mae_ea", "rmse_ea"), @@ -224,3 +216,289 @@ def compute_spin_force_metrics( force_real=force_real, force_magnetic=force_magnetic, ) + + +def _spin_force_metrics_from_prediction( + prediction: dict[str, np.ndarray], + test_data: dict[str, np.ndarray], +) -> SpinForceEvalMetrics: + """Align predicted and reference forces into real and magnetic subsets. + + Real forces cover all atoms, while magnetic forces are restricted to the + magnetic atoms selected by the boolean ``mask_mag`` of shape + ``(nframes, natoms)``. The magnetic term is produced only when + ``find_force_mag`` is set and both prediction and reference magnetic + forces are present, matching the ``dp test`` spin convention. + + Parameters + ---------- + prediction : dict[str, np.ndarray] + Model predictions containing ``force`` and, for the magnetic term, + ``force_mag`` and ``mask_mag``. + test_data : dict[str, np.ndarray] + Reference labels and ``find_*`` availability flags for one system. + + Returns + ------- + SpinForceEvalMetrics + The real-atom and (optionally) magnetic-atom force errors. + """ + force_real_prediction = prediction["force"].reshape(-1, 3) + force_real_reference = test_data["force"].reshape(-1, 3) + has_force_mag = ( + bool(test_data.get("find_force_mag", 0.0)) + and "force_mag" in prediction + and "force_mag" in test_data + ) + if not has_force_mag: + return compute_spin_force_metrics( + force_real_prediction=force_real_prediction, + force_real_reference=force_real_reference, + ) + magnetic_mask = prediction["mask_mag"].reshape(-1).astype(bool) + return compute_spin_force_metrics( + force_real_prediction=force_real_prediction, + force_real_reference=force_real_reference, + force_magnetic_prediction=prediction["force_mag"].reshape(-1, 3)[magnetic_mask], + force_magnetic_reference=test_data["force_mag"].reshape(-1, 3)[magnetic_mask], + ) + + +def compute_full_validation_energy_metrics( + prediction: dict[str, np.ndarray], + test_data: dict[str, np.ndarray], + natoms: int, + has_pbc: bool, +) -> dict[str, tuple[float, float]]: + """Compute energy-type full validation metrics for one system. + + Parameters + ---------- + prediction : dict[str, np.ndarray] + Model predictions containing ``energy``, ``force`` and optionally + ``virial``. + test_data : dict[str, np.ndarray] + Reference labels and ``find_*`` availability flags for one system. + natoms : int + The number of atoms per frame, used for per-atom normalization. + has_pbc : bool + Whether the system is periodic, gating the virial metrics. + + Returns + ------- + dict[str, tuple[float, float]] + Weighted-average-ready ``(value, weight)`` pairs keyed by metric. + """ + metrics = compute_energy_type_metrics(prediction, test_data, natoms, has_pbc) + return metrics.as_weighted_average_errors(FULL_VALIDATION_WEIGHTED_METRIC_KEYS) + + +def compute_full_validation_spin_metrics( + prediction: dict[str, np.ndarray], + test_data: dict[str, np.ndarray], + natoms: int, + has_pbc: bool, +) -> dict[str, tuple[float, float]]: + """Compute spin-energy full validation metrics for one system. + + The energy term reuses per-atom energy errors. Forces are split into a + real-atom term over all atoms and a magnetic term over the magnetic atoms + selected by ``mask_mag``. Spin models do not report virial, so no virial + metric is produced. + + Parameters + ---------- + prediction : dict[str, np.ndarray] + Model predictions containing ``energy``, ``force``, ``force_mag`` and + the boolean ``mask_mag``. + test_data : dict[str, np.ndarray] + Reference labels and ``find_*`` availability flags for one system. + natoms : int + The number of atoms per frame, used for per-atom normalization. + has_pbc : bool + Unused; spin full validation never reports virial. Present to keep a + uniform profile signature. + + Returns + ------- + dict[str, tuple[float, float]] + Weighted-average-ready ``(value, weight)`` pairs keyed by metric. + """ + errors: dict[str, tuple[float, float]] = {} + if bool(test_data.get("find_energy", 0.0)): + energy_per_atom = compute_error_stat( + prediction["energy"].reshape(-1, 1), + test_data["energy"].reshape(-1, 1), + scale=1.0 / natoms, + ) + errors.update( + energy_per_atom.as_weighted_average_errors( + "mae_e_per_atom", "rmse_e_per_atom" + ) + ) + if bool(test_data.get("find_force", 0.0)): + spin_metrics = _spin_force_metrics_from_prediction(prediction, test_data) + errors.update( + spin_metrics.as_weighted_average_errors(DP_TEST_SPIN_WEIGHTED_METRIC_KEYS) + ) + return errors + + +@dataclass(frozen=True) +class FullValidationMetricProfile: + """Metric family definition for one full validation model class. + + Bundles every aspect that differs between energy-type and spin-energy full + validation so the validator stays data-driven instead of branching on the + model class: + + - ``column_order`` defines the ``val.log`` table layout as + ``(header_label, metric_key)`` pairs. + - ``metric_key_map`` maps a normalized ``validation_metric`` token (such as + ``"e:mae"``) to an internal metric key (such as ``"mae_e_per_atom"``). + - ``metric_family_by_key`` maps an internal metric key back to its family, + used for display-unit lookup. + - ``unit_by_family`` maps a family to its ``(display_unit, scale)``. + - ``prefactor_by_metric`` maps a metric token to the loss prefactor keys + that must both be active for the metric to be trainable. + - ``needs_spin`` indicates whether the model consumes a spin input and + emits magnetic forces. + - ``log_header_note`` is the one-line table legend written to ``val.log``. + - ``compute_system_metrics`` turns one system's prediction and reference + into weighted ``(value, weight)`` metric pairs. + + Attributes + ---------- + name : str + Profile identifier, either ``"energy"`` or ``"spin"``. + column_order : tuple[tuple[str, str], ...] + Ordered ``(header_label, metric_key)`` pairs for the log table. + metric_key_map : dict[str, str] + Normalized metric token to internal metric key. + metric_family_by_key : dict[str, str] + Internal metric key to family identifier. + unit_by_family : dict[str, tuple[str, float]] + Family identifier to ``(display_unit, scale)``. + prefactor_by_metric : dict[str, tuple[str, str]] + Normalized metric token to ``(start_pref_key, limit_pref_key)``. + needs_spin : bool + Whether the profile requires spin input and magnetic-force outputs. + log_header_note : str + One-line legend describing the metric columns. + compute_system_metrics : Callable + Routine computing weighted metric pairs for one system, with signature + ``(prediction, test_data, natoms, has_pbc) -> dict``. + """ + + name: str + column_order: tuple[tuple[str, str], ...] + metric_key_map: dict[str, str] + metric_family_by_key: dict[str, str] + unit_by_family: dict[str, tuple[str, float]] + prefactor_by_metric: dict[str, tuple[str, str]] + needs_spin: bool + log_header_note: str + compute_system_metrics: Callable[ + [dict[str, np.ndarray], dict[str, np.ndarray], int, bool], + dict[str, tuple[float, float]], + ] + + +ENERGY_FULL_VALIDATION_PROFILE = FullValidationMetricProfile( + name="energy", + column_order=( + ("E_MAE", "mae_e_per_atom"), + ("E_RMSE", "rmse_e_per_atom"), + ("F_MAE", "mae_f"), + ("F_RMSE", "rmse_f"), + ("V_MAE", "mae_v_per_atom"), + ("V_RMSE", "rmse_v_per_atom"), + ), + metric_key_map={ + "e:mae": "mae_e_per_atom", + "e:rmse": "rmse_e_per_atom", + "f:mae": "mae_f", + "f:rmse": "rmse_f", + "v:mae": "mae_v_per_atom", + "v:rmse": "rmse_v_per_atom", + }, + metric_family_by_key={ + "mae_e_per_atom": "e", + "rmse_e_per_atom": "e", + "mae_f": "f", + "rmse_f": "f", + "mae_v_per_atom": "v", + "rmse_v_per_atom": "v", + }, + unit_by_family={ + "e": ("meV/atom", 1000.0), + "f": ("meV/Å", 1000.0), + "v": ("meV/atom", 1000.0), + }, + prefactor_by_metric={ + "e:mae": ("start_pref_e", "limit_pref_e"), + "e:rmse": ("start_pref_e", "limit_pref_e"), + "f:mae": ("start_pref_f", "limit_pref_f"), + "f:rmse": ("start_pref_f", "limit_pref_f"), + "v:mae": ("start_pref_v", "limit_pref_v"), + "v:rmse": ("start_pref_v", "limit_pref_v"), + }, + needs_spin=False, + log_header_note=( + "# E uses per-atom energy, F uses component-wise force errors, " + "and V uses virial normalized by natoms.\n" + ), + compute_system_metrics=compute_full_validation_energy_metrics, +) + +SPIN_FULL_VALIDATION_PROFILE = FullValidationMetricProfile( + name="spin", + column_order=( + ("E_MAE", "mae_e_per_atom"), + ("E_RMSE", "rmse_e_per_atom"), + ("FR_MAE", "mae_fr"), + ("FR_RMSE", "rmse_fr"), + ("FM_MAE", "mae_fm"), + ("FM_RMSE", "rmse_fm"), + ), + metric_key_map={ + "e:mae": "mae_e_per_atom", + "e:rmse": "rmse_e_per_atom", + "fr:mae": "mae_fr", + "fr:rmse": "rmse_fr", + "fm:mae": "mae_fm", + "fm:rmse": "rmse_fm", + }, + metric_family_by_key={ + "mae_e_per_atom": "e", + "rmse_e_per_atom": "e", + "mae_fr": "fr", + "rmse_fr": "fr", + "mae_fm": "fm", + "rmse_fm": "fm", + }, + unit_by_family={ + "e": ("meV/atom", 1000.0), + "fr": ("meV/Å", 1000.0), + "fm": ("meV/μB", 1000.0), + }, + prefactor_by_metric={ + "e:mae": ("start_pref_e", "limit_pref_e"), + "e:rmse": ("start_pref_e", "limit_pref_e"), + "fr:mae": ("start_pref_fr", "limit_pref_fr"), + "fr:rmse": ("start_pref_fr", "limit_pref_fr"), + "fm:mae": ("start_pref_fm", "limit_pref_fm"), + "fm:rmse": ("start_pref_fm", "limit_pref_fm"), + }, + needs_spin=True, + log_header_note=( + "# E uses per-atom energy, FR uses component-wise real-atom force " + "errors, and FM uses magnetic-atom force errors.\n" + ), + compute_system_metrics=compute_full_validation_spin_metrics, +) + +FULL_VALIDATION_PROFILES: dict[str, FullValidationMetricProfile] = { + ENERGY_FULL_VALIDATION_PROFILE.name: ENERGY_FULL_VALIDATION_PROFILE, + SPIN_FULL_VALIDATION_PROFILE.name: SPIN_FULL_VALIDATION_PROFILE, +} diff --git a/deepmd/utils/spin.py b/deepmd/utils/spin.py index aed82cae8b..b03cfb07bd 100644 --- a/deepmd/utils/spin.py +++ b/deepmd/utils/spin.py @@ -30,18 +30,25 @@ class Spin: The virtual coordinate is defined as the real coordinate plus spin * virtual_scale. List of float values with shape of [ntypes] or [ntypes_spin] or one single float value for all types, only used when use_spin is True for each atom type. + allow_missing_label: bool + Whether a training system that lacks a ``spin`` data file is admitted by + filling its per-atom spin with zeros instead of raising. Supported only by + the SeZM/DPA4 spin model. As a data-loading option it is excluded from + serialization. """ def __init__( self, use_spin: list[bool], virtual_scale: list[float] | float, + allow_missing_label: bool = False, ) -> None: type_dtype = np.int32 self.ntypes_real = len(use_spin) self.ntypes_spin = use_spin.count(True) self.use_spin = np.array(use_spin) self.spin_mask = self.use_spin.astype(np.int64) + self.allow_missing_label = bool(allow_missing_label) self.ntypes_real_and_spin = self.ntypes_real + self.ntypes_spin self.ntypes_placeholder = self.ntypes_real - self.ntypes_spin self.ntypes_input = 2 * self.ntypes_real # with placeholder for input types diff --git a/doc/model/dpa4.md b/doc/model/dpa4.md index 325f1074b1..94270a8098 100644 --- a/doc/model/dpa4.md +++ b/doc/model/dpa4.md @@ -79,8 +79,10 @@ unnecessary and not recommended (see [Hardware selection](#hardware-selection)). descriptors. On the conservative **energy** path it is only an initial neighbor-search capacity that grows on demand, so it never truncates the neighbor list and you do not need to size it to the true maximum neighbor count. -Only the denoising (`dens`) and spin paths cap the list at `sum(sel)`. You can -also set `sel` to `auto` or `auto:factor` to size it from the training data. +The native spin scheme shares this energy path, so it grows on demand too; only +the denoising (`dens`) path and the `deepspin` spin scheme cap the list at +`sum(sel)`. You can also set `sel` to `auto` or `auto:factor` to size it from +the training data. ::: ### Main options @@ -174,8 +176,77 @@ default training path. See `examples/water/dpa4/input_dens.json` for an example. ### Spin -DPA4/SeZM supports the DeePMD-kit spin convention. Keep the model type and add -the standard `model.spin` block: +DPA4/SeZM supports the DeePMD-kit spin convention through the standard +`model.spin` block. Two schemes are available, selected by `model.spin.scheme`: + +- `native` — the per-atom spin vector enters the descriptor as an + equivariant feature, and the magnetic force is the spin gradient of the + energy. No virtual atoms are introduced, so the neighbor list and type map + keep their real-system sizes. +- `deepspin` (default) — the classical DeepSpin representation, in which each + magnetic atom is paired with a virtual atom displaced along its spin. It is + the default, so a `model.spin` block without an explicit `scheme` reproduces + the classical behaviour shared with every non-SeZM spin model. + +Both schemes train against the conservative `ener_spin` loss, share the same +`spin` / `force_mag` data convention, and are not combined with the `dens` mode. +Because the dataset and loss are identical, switching `scheme` does not require +any change to the data or the loss block. Complete inputs are in +`examples/spin/dpa4/`: `input.json` for the native scheme and +`input-deepspin.json` for the deepspin scheme. See +[training spin energy models](train-energy-spin.md) for the general workflow. + +`use_spin` accepts a per-type boolean list or, to avoid enumerating a large +`type_map`, the list of magnetic species as type indices or element symbols +(e.g. `"use_spin": ["Fe"]`), expanded against `type_map`. For the native scheme, +`model.spin.allow_missing_label` additionally admits training systems that lack a +`spin` data file by filling their spin with zeros; since a zero spin reduces the +native descriptor to its spin-free form, a model can be trained on a mixture of +spin-labelled and spin-free systems, or pretrained on spin-free data and +fine-tuned on spin-labelled data without changing its type map or parameters. + +#### Native scheme + +```json +{ + "model": { + "type": "dpa4", + "type_map": [ + "Ni", + "O" + ], + "spin": { + "use_spin": [ + true, + false + ], + "scheme": "native" + }, + "descriptor": { + "rcut": 6.0 + } + } +} +``` + +`use_spin` marks which atom types carry spin. Both the conservative force and +the magnetic force come from a single energy gradient: + +```math +\mathbf{F}_i = -\frac{\partial E}{\partial \mathbf{r}_i}, +\qquad +\mathbf{F}^{\mathrm{mag}}_i = -\frac{\partial E}{\partial \mathbf{s}_i}, +``` + +where $\mathbf{s}_i$ is the input spin vector. The model does not rescale the +spin internally, so $\mathbf{s}_i$ keeps the dataset's `spin` convention and the +magnetic force is reported in the matching units of `force_mag` -- there is no +`virtual_scale` factor in this scheme. The magnetic force is reported on the +magnetic atom types only, matching the `force_mag` label. The native scheme +relies on the descriptor's angular degrees to represent the spin direction, so +it requires `lmax >= 1` (the default). + +#### DeepSpin virtual-atom scheme (default) ```json { @@ -192,7 +263,8 @@ the standard `model.spin` block: ], "virtual_scale": [ 0.314 - ] + ], + "scheme": "deepspin" }, "descriptor": { "sel": 120, @@ -202,9 +274,10 @@ the standard `model.spin` block: } ``` -The spin path uses the conservative `ener_spin` loss and is not combined with -the `dens` mode. See [training spin energy models](train-energy-spin.md) and -`examples/water/dpa4/input-spin.json`. +The `deepspin` scheme augments each magnetic atom with a virtual atom at +`coord + spin * virtual_scale`, which doubles the internal neighbor capacity and +type map. `virtual_scale` is required by this scheme and ignored by the native +scheme. ### Multi-task / shared fitting @@ -294,29 +367,36 @@ Three options control training precision and the compiled path: Inference behavior is controlled by environment variables, each with an equivalent input-file option used during training validation: -| Environment variable | Input-file option | Default | Effect | -| -------------------- | --------------------------- | ------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `DP_COMPILE_INFER` | `validating.compiled_infer` | off | Use the compile path for evaluation/inference. Same `torch==2.11` / CUDA ≥ 12.6 requirements as `model.use_compile`. | -| `DP_TF32_INFER` | `validating.tf32_infer` | `0` (highest) | float32 matmul precision for inference: `0` highest, `1` high, `2` medium. Higher values improve throughput but make the potential energy surface less smooth. | -| `DP_TRITON_INFER` | — | off | Fused block-diagonal Triton kernels for the SO(2) Wigner-D rotation (CUDA eval only). Lower latency and peak memory, numerically equivalent to the dense path with full float32 accumulation. Compatible with `DP_COMPILE_INFER`. | +| Environment variable | Input-file option | Default | Effect | +| -------------------- | --------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `DP_COMPILE_INFER` | `validating.compiled_infer` | off | Use the compile path for evaluation/inference. Same `torch==2.11` / CUDA ≥ 12.6 requirements as `model.use_compile`. | +| `DP_TF32_INFER` | `validating.tf32_infer` | `0` (highest) | float32 matmul precision for inference: `0` highest, `1` high, `2` medium. Higher values improve throughput but make the potential energy surface less smooth. | +| `DP_AMP_INFER` | `validating.amp_infer` | off | bf16 autocast inside the descriptor interaction blocks for inference when `descriptor.use_amp=true`. Usually keeps aggregate MAE similar but can make the potential energy surface less smooth. | +| `DP_TRITON_INFER` | — | `0` | Triton inference kernel level `0`-`3` (CUDA eval only, compatible with `DP_COMPILE_INFER`). `1`: universal fused kernels, numerically equivalent to the dense path with full float32 accumulation. `2`: adds the table-configured fused SO(2) value path and edge-block backward kernels (still exact float32). `3`: additionally runs the SO(2) mixing stack on fp16 tensor cores with split compensation — roughly float32-level accuracy (maximum force deviation about 4e-6 eV/Å on a 4-thousand-atom system) at a substantial speedup; only shapes validated by the tuning sweep are affected. Levels 2 and 3 read launch tables tuned per GPU model (H20 ships built in); on other GPUs the kernels fall back to conservative configurations, and `dp --pt freeze` tunes the missing entries automatically on the local GPU before exporting (a one-off sweep of a few minutes, baked into the `.pt2`). | -Accepted boolean values are `1`/`true`/`yes`/`on` and `0`/`false`/`no`/`off`. +Accepted boolean values for the other switches are `1`/`true`/`yes`/`on` and +`0`/`false`/`no`/`off`; `DP_TRITON_INFER` accepts only the numeric levels. Shell exports take precedence over the input-file options and over values written in the input; they are read when the model is constructed and changing them afterward has no effect. For molecular dynamics and other workflows sensitive to the smoothness of the -potential energy surface, keep `DP_TF32_INFER=0`. `DP_TRITON_INFER=1` retains -full float32 accumulation regardless of `DP_TF32_INFER` and is therefore safe -for those workflows. +potential energy surface, keep `DP_TF32_INFER=0` and `DP_AMP_INFER=0`. +`DP_AMP_INFER` can coexist with `DP_TF32_INFER`, but bf16 autocast dominates +the eligible operations it covers, so TF32 usually adds little extra throughput +there. `DP_TRITON_INFER` levels `1` and `2` retain full float32 accumulation +regardless of the precision policy and are therefore safe for those workflows; +level `3` perturbs forces at the 2^-22 rounding scale (three orders of +magnitude finer than TF32) and is the recommended fast setting once validated +for the target system. :::{important} Set these variables **before** running `dp --pt freeze`. The exported `.pt2` is -an AOTInductor artifact, so the SO(2) rotation branch (`DP_TRITON_INFER`) and -the matmul precision (`DP_TF32_INFER`) are captured into the graph at export -time and are **not** re-evaluated when the `.pt2` is later loaded by ASE or -LAMMPS. A frozen `.pt2` runs a forward-only package, so training-time -memory-saving switches do not apply to it. +an AOTInductor artifact, so the SO(2) rotation branch (`DP_TRITON_INFER`), the +matmul precision (`DP_TF32_INFER`), and inference AMP (`DP_AMP_INFER`) are +captured into the graph at export time and are **not** re-evaluated when the +`.pt2` is later loaded by ASE or LAMMPS. A frozen `.pt2` runs a forward-only +package, so training-time memory-saving switches do not apply to it. ::: ### Hardware selection diff --git a/examples/water/dpa4/input-spin.json b/examples/spin/dpa4/input-deepspin.json similarity index 92% rename from examples/water/dpa4/input-spin.json rename to examples/spin/dpa4/input-deepspin.json index c82754212d..c3824c947e 100644 --- a/examples/water/dpa4/input-spin.json +++ b/examples/spin/dpa4/input-deepspin.json @@ -1,5 +1,5 @@ { - "_comment": "DPA4/SeZM spin-energy training example using the DeePMD spin convention.", + "_comment": "DPA4/SeZM spin-energy training example using the virtual-atom (DeepSpin) scheme; see input.json for the native scheme.", "model": { "type": "DPA4", "type_map": [ @@ -7,12 +7,13 @@ "O" ], "spin": { + "scheme": "deepspin", + "virtual_scale": [ + 0.314 + ], "use_spin": [ true, false - ], - "virtual_scale": [ - 0.314 ] }, "descriptor": { @@ -82,14 +83,14 @@ "stat_file": "./dpa4_spin.hdf5", "training_data": { "systems": [ - "../../spin/data_reformat/data_0", - "../../spin/data_reformat/data_1" + "../data_reformat/data_0", + "../data_reformat/data_1" ], "batch_size": 1 }, "validation_data": { "systems": [ - "../../spin/data_reformat/data_2" + "../data_reformat/data_2" ], "batch_size": 1, "numb_batch": 1 diff --git a/examples/spin/dpa4/input.json b/examples/spin/dpa4/input.json new file mode 100644 index 0000000000..ec3bb8d2bb --- /dev/null +++ b/examples/spin/dpa4/input.json @@ -0,0 +1,116 @@ +{ + "_comment": "DPA4/SeZM native-spin training example: spin enters the descriptor as an equivariant feature (no virtual atoms).", + "model": { + "type": "DPA4", + "type_map": [ + "Ni", + "O" + ], + "spin": { + "scheme": "native", + "use_spin": [ + true, + false + ] + }, + "descriptor": { + "sel": 120, + "rcut": 6.0, + "channels": 32, + "n_radial": 16, + "use_env_seed": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 2, + "so2_layers": 3, + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 2, + "focus_dim": 0, + "n_atten_head": 1, + "ffn_neurons": 0, + "ffn_so3_grid": true, + "grid_mlp": false, + "grid_branch": 1, + "ffn_blocks": 2, + "sandwich_norm": [ + false, + true, + true, + false + ], + "message_node_so3": true, + "use_amp": true, + "precision": "float32", + "seed": 42 + }, + "fitting_net": { + "neuron": [ + 0 + ], + "precision": "float32", + "seed": 42 + }, + "use_compile": false, + "enable_tf32": true + }, + "learning_rate": { + "type": "wsd", + "start_lr": 4.5e-4, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss": { + "type": "ener_spin", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_fr": 1000, + "limit_pref_fr": 1, + "start_pref_fm": 1000, + "limit_pref_fm": 1 + }, + "optimizer": { + "type": "HybridMuon", + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa4_spin.hdf5", + "training_data": { + "systems": [ + "../data_reformat/data_0", + "../data_reformat/data_1" + ], + "batch_size": 1 + }, + "validation_data": { + "systems": [ + "../data_reformat/data_2" + ], + "batch_size": 1, + "numb_batch": 1 + }, + "numb_steps": 2000000, + "gradient_max_norm": 5.0, + "save_freq": 2000, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 1000, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 1, + "seed": 42 + } +} diff --git a/examples/spin/dpa4/lmp/README.md b/examples/spin/dpa4/lmp/README.md new file mode 100644 index 0000000000..ac1a75712c --- /dev/null +++ b/examples/spin/dpa4/lmp/README.md @@ -0,0 +1,38 @@ +# LAMMPS example for DPA4 / SeZM native spin + +Runs a native-spin DPA4 / SeZM model in LAMMPS through `pair_style deepspin` +backed by an AOTInductor `.pt2` archive. For the classical DeepSpin +(virtual-atom) scheme, see `examples/spin/lmp`. + +## Files + +| File | Description | +| ----------- | --------------------------------------------------------- | +| `in.lammps` | Single-point evaluation of a 4-atom NiO cell. | +| `init.data` | `atom_style spin` data: 2 magnetic Ni + 2 non-magnetic O. | + +## Usage + +Train (configuration in `../input.json`) and freeze to a `.pt2` archive. The +freeze CLI detects DPA4 / SeZM and rewrites the suffix to `.pt2`; the archive is +target-specific and is not shipped, so freeze locally: + +```bash +dp --pt train ../input.json --skip-neighbor-stat +dp --pt freeze -c model.ckpt.pt -o frozen_model +``` + +Run: + +```bash +lmp -in in.lammps +``` + +`spin.dump` holds the per-atom force (`fx fy fz`) and magnetic force +(`c_fmag[1..3]`), which is non-zero on the magnetic Ni atoms and zero on O. + +The same archive runs under domain decomposition without any change: + +```bash +mpirun -np 2 lmp -in in.lammps +``` diff --git a/examples/spin/dpa4/lmp/in.lammps b/examples/spin/dpa4/lmp/in.lammps new file mode 100644 index 0000000000..f20c278441 --- /dev/null +++ b/examples/spin/dpa4/lmp/in.lammps @@ -0,0 +1,23 @@ +units metal +atom_style spin +boundary p p p +atom_modify map yes + +neighbor 2.0 bin +neigh_modify every 10 delay 0 check no + +read_data init.data +mass 1 58.69 +mass 2 16.00 + +pair_style deepspin frozen_model.pt2 +pair_coeff * * Ni O + +# per-atom magnetic force fm = -dE/dspin +compute fmag all property/atom fmx fmy fmz + +thermo_style custom step pe +thermo 1 +dump 1 all custom 1 spin.dump id type x y z fx fy fz c_fmag[1] c_fmag[2] c_fmag[3] +dump_modify 1 sort id format float %.12e +run 0 diff --git a/examples/spin/dpa4/lmp/init.data b/examples/spin/dpa4/lmp/init.data new file mode 100644 index 0000000000..e8c378b907 --- /dev/null +++ b/examples/spin/dpa4/lmp/init.data @@ -0,0 +1,21 @@ +LAMMPS data file via dp-raw_to_lmp_data python script + +4 atoms +2 atom types + +0.00000000e+00 5.12307051e+00 xlo xhi +0.00000000e+00 2.83444838e+00 ylo yhi +0.00000000e+00 2.52508896e+00 zlo zhi +4.26742132e+00 4.26744884e+00 1.28772700e+00 xy xz yz + +Masses + +1 58.69 +2 16.00 + +Atoms # spin + +1 1 0.00000000e+00 0.00000000e+00 0.00000000e+00 -3.33632650e-01 8.03816499e-01 4.92512224e-01 1.27370000e+00 +2 1 6.82897034e+00 2.06108769e+00 1.26254448e+00 3.33632650e-01 -8.03816499e-01 -4.92512224e-01 1.27370000e+00 +3 2 3.41448517e+00 1.03054385e+00 6.31272240e-01 0.00000000e+00 0.00000000e+00 1.00000000e+00 0.00000000e+00 +4 2 1.02434555e+01 3.09163154e+00 1.89381672e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00 0.00000000e+00 diff --git a/examples/water/dpa4/README.md b/examples/water/dpa4/README.md index 15d6a383f0..3cbb0dde49 100644 --- a/examples/water/dpa4/README.md +++ b/examples/water/dpa4/README.md @@ -9,7 +9,6 @@ Input files: - `input.json`: baseline conservative energy training, using a compact DPA4-Neo-style parameter set. - `input-zbl.json`: energy training with ZBL zone bridging. -- `input-spin.json`: spin-energy training with the DeePMD spin convention. - `input_dens.json`: direct-force denoising training. - `input_multitask.json`: multitask training with a shared descriptor and case-conditioned shared fitting network. diff --git a/examples/water/dpa4/input.json b/examples/water/dpa4/input.json index c0e12b9be4..1126c7107a 100644 --- a/examples/water/dpa4/input.json +++ b/examples/water/dpa4/input.json @@ -123,6 +123,7 @@ "validating": { "compiled_infer": false, "tf32_infer": false, + "amp_infer": false, "save_best_dir": "ckpt_best" } } diff --git a/pyproject.toml b/pyproject.toml index fd7e156a87..4f3a707977 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -460,6 +460,7 @@ runtime-evaluated-base-classes = ["torch.nn.Module"] "backend/**" = ["ANN"] "data/**" = ["ANN"] "deepmd/_vendors/**" = ["ALL"] +"deepmd/kernels/**" = ["TID253", "B905"] "deepmd/tf/**" = ["TID253"] "deepmd/tf2/**" = ["TID253"] "deepmd/pt/**" = ["TID253", "B905"] diff --git a/source/api_cc/include/DeepPot.h b/source/api_cc/include/DeepPot.h index 177513ebb0..4a197cb064 100644 --- a/source/api_cc/include/DeepPot.h +++ b/source/api_cc/include/DeepPot.h @@ -307,6 +307,24 @@ class DeepPotBackend : public DeepBaseModelBackend { computew_mixed_type(ener, force, virial, atom_energy, atom_virial, nframes, coord, atype, box, fparam, aparam, atomic); } + /** + * @brief GPU-resident edge-input inference for the SeZM/DPA4 graph-form .pt2: + *given device edge tensors, write per-atom energy / force / virial back to + * the device output pointers. The PyTorch Exportable backend overrides this; + * every other backend inherits the throwing default. The signature is + * intentionally torch-free so the dispatcher stays backend-agnostic (no + * dynamic_cast into a PyTorch-heavy type, so ``libdeepmd_cc`` need not link + * PyTorch). + */ + virtual void compute_edges_gpu(double* d_atom_energy, + double* d_force, + double* d_atom_virial, + const double* d_coord, + const int* d_atype, + const int* d_edge_index, + const double* d_edge_vec, + const int nloc, + const int nedge); }; /** @@ -664,6 +682,36 @@ class DeepPot : public DeepBaseModel { const std::vector& charge_spin = std::vector()); /** @} */ + /** + * @brief Fully device-resident edge inference for single-domain SeZM/DPA4. + * + * Forwards to the PyTorch Exportable (.pt2) backend's GPU edge path; raising + * if the active backend is not ``DeepPotPTExpt``. All pointers reference GPU + * memory on the model's device. See + * ``DeepPotPTExpt::compute_edges_gpu`` for the edge contract. This signature + * is intentionally torch-free so MD-engine call sites need no PyTorch + * headers. + * + * @param[out] d_atom_energy Per-atom energy, GPU [nloc]. + * @param[out] d_force Per-atom force, GPU [nloc * 3] row-major. + * @param[out] d_atom_virial Per-atom virial, GPU [nloc * 9] row-major. + * @param[in] d_coord Local coordinates, GPU [nloc * 3] row-major. + * @param[in] d_atype Local atom types, GPU [nloc]. + * @param[in] d_edge_index Local edge graph, GPU [2 * nedge]. + * @param[in] d_edge_vec Minimum-image bond vectors, GPU [nedge * 3]. + * @param[in] nloc Number of local atoms. + * @param[in] nedge Number of physical edges. + */ + void compute_edges_gpu(double* d_atom_energy, + double* d_force, + double* d_atom_virial, + const double* d_coord, + const int* d_atype, + const int* d_edge_index, + const double* d_edge_vec, + const int nloc, + const int nedge); + int dim_chg_spin() const; protected: diff --git a/source/api_cc/include/DeepPotPTExpt.h b/source/api_cc/include/DeepPotPTExpt.h index ddaea35646..6c30fae991 100644 --- a/source/api_cc/include/DeepPotPTExpt.h +++ b/source/api_cc/include/DeepPotPTExpt.h @@ -290,6 +290,36 @@ class DeepPotPTExpt : public DeepPotBackend { const std::vector& charge_spin, const bool atomic) override; + /** + * @brief Fully device-resident edge inference for single-domain SeZM/DPA4. + * + * Runs the exported model directly on a GPU-built compact edge schema, + * keeping coordinates, the edge graph and the outputs on the device. All + * pointers reference GPU memory on the model's device. ``edge_index`` is the + * flattened [2, nedge] local edge graph (row 0 = neighbor/source, row 1 = + * center/destination); ``edge_vec`` is the matching minimum-image bond vector + * ``r_neighbor - r_center``. Outputs are written device-to-device. + * + * @param[out] d_atom_energy Per-atom energy, GPU [nloc]. + * @param[out] d_force Per-atom force, GPU [nloc * 3] row-major. + * @param[out] d_atom_virial Per-atom virial, GPU [nloc * 9] row-major. + * @param[in] d_coord Local coordinates, GPU [nloc * 3] row-major. + * @param[in] d_atype Local atom types, GPU [nloc]. + * @param[in] d_edge_index Local edge graph, GPU [2 * nedge]. + * @param[in] d_edge_vec Minimum-image bond vectors, GPU [nedge * 3]. + * @param[in] nloc Number of local atoms. + * @param[in] nedge Number of physical edges (dummy edges added internally). + */ + void compute_edges_gpu(double* d_atom_energy, + double* d_force, + double* d_atom_virial, + const double* d_coord, + const int* d_atype, + const int* d_edge_index, + const double* d_edge_vec, + const int nloc, + const int nedge) override; + private: bool inited; int ntypes; diff --git a/source/api_cc/include/DeepSpinPTExpt.h b/source/api_cc/include/DeepSpinPTExpt.h index e21d10ec36..a4d0311d98 100644 --- a/source/api_cc/include/DeepSpinPTExpt.h +++ b/source/api_cc/include/DeepSpinPTExpt.h @@ -191,10 +191,15 @@ class DeepSpinPTExpt : public DeepSpinBackend { std::vector type_map; std::vector output_keys; bool do_atomic_virial; // whether model was exported with atomic virial corr - int nnei; // expected nlist nnei dimension (= sum(sel)) + // Whether the exported graph consumes the compact edge schema (native spin, + // shared with DeepPotPTExpt) rather than the deepspin-scheme nlist contract. + bool lower_input_is_edge_ = false; + int nnei; // expected nlist nnei dimension (= sum(sel)) NeighborListData nlist_data; - at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) - at::Tensor firstneigh_tensor; // cached nlist tensor (LAMMPS path) + at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) + at::Tensor firstneigh_tensor; // cached nlist tensor (LAMMPS path) + at::Tensor edge_index_tensor; // cached local-folded edges (edge path) + at::Tensor edge_index_ext_tensor; // cached extended edges (edge path) std::unique_ptr loader; // Optional with-comm artifact for multi-rank GNN spin inference. bool has_comm_artifact_ = false; @@ -212,6 +217,39 @@ class DeepSpinPTExpt : public DeepSpinBackend { const torch::Tensor& fparam, const torch::Tensor& aparam); + /** + * @brief Run the native-spin edge artifact: the energy edge schema plus the + * per-local-atom spin leaf. + */ + std::vector run_model_edges( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& edge_index, + const torch::Tensor& edge_vec, + const torch::Tensor& edge_scatter_index, + const torch::Tensor& edge_mask, + const torch::Tensor& spin, + const torch::Tensor& fparam, + const torch::Tensor& aparam); + + /** + * @brief Run the native-spin parallel edge artifact: the energy edge + * with-comm schema (coord and extended types span the extended node set) + * plus the EXTENDED per-node spin leaf, then the 8 border_op comm tensors. + */ + std::vector run_model_edges_with_comm( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& extended_atype, + const torch::Tensor& edge_index, + const torch::Tensor& edge_vec, + const torch::Tensor& edge_scatter_index, + const torch::Tensor& edge_mask, + const torch::Tensor& spin, + const torch::Tensor& fparam, + const torch::Tensor& aparam, + const std::vector& comm_tensors); + /** * @brief Run with-comm spin artifact: 5-7 base inputs (incl. * extended_spin) + 8 comm tensors. diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index 02ee1d24b3..d2b9d773ea 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -564,6 +564,49 @@ template void DeepPot::compute_mixed_type( const std::vector& aparam, const std::vector& charge_spin); +void DeepPotBackend::compute_edges_gpu(double* d_atom_energy, + double* d_force, + double* d_atom_virial, + const double* d_coord, + const int* d_atype, + const int* d_edge_index, + const double* d_edge_vec, + const int nloc, + const int nedge) { + (void)d_atom_energy; + (void)d_force; + (void)d_atom_virial; + (void)d_coord; + (void)d_atype; + (void)d_edge_index; + (void)d_edge_vec; + (void)nloc; + (void)nedge; + throw deepmd::deepmd_exception( + "compute_edges_gpu (GPU-resident edge inference) is only supported by " + "the " + "PyTorch Exportable (.pt2) backend."); +} + +void DeepPot::compute_edges_gpu(double* d_atom_energy, + double* d_force, + double* d_atom_virial, + const double* d_coord, + const int* d_atype, + const int* d_edge_index, + const double* d_edge_vec, + const int nloc, + const int nedge) { + // Polymorphic dispatch to the loaded backend: the PyTorch Exportable backend + // overrides ``compute_edges_gpu``; other backends inherit the throwing + // default. This replaces a ``dynamic_cast`` into the PyTorch-heavy + // ``DeepPotPTExpt``, which the "load backends as plugins" refactor made + // uncompilable in the backend-agnostic ``libdeepmd_cc`` (it does not link + // PyTorch), so the cast branch was always stubbed out. + dp->compute_edges_gpu(d_atom_energy, d_force, d_atom_virial, d_coord, d_atype, + d_edge_index, d_edge_vec, nloc, nedge); +} + int DeepPot::dim_chg_spin() const { return dp->dim_chg_spin(); } DeepPotModelDevi::DeepPotModelDevi() { diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index cff23388b6..9088762fcb 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -2,6 +2,7 @@ #include "DeepPotPTExpt.h" #if defined(BUILD_PYTORCH) && BUILD_PT_EXPT +#include #include #include @@ -1746,4 +1747,154 @@ void DeepPotPTExpt::computew_mixed_type(std::vector& ener, charge_spin, atomic); }); } + +void DeepPotPTExpt::compute_edges_gpu(double* d_atom_energy, + double* d_force, + double* d_atom_virial, + const double* d_coord, + const int* d_atype, + const int* d_edge_index, + const double* d_edge_vec, + const int nloc, + const int nedge) { + // Fully device-resident edge inference for single-domain SeZM/DPA4 models. + // + // The caller (an MD engine such as GPUMD) builds the neighbor list and the + // compact edge schema on the GPU and passes raw device pointers. This entry + // keeps every tensor on the GPU: coordinates, the edge graph and the model + // outputs never leave the device, eliminating the host neighbor-list build + // and the per-step host-device transfers of the standalone ``compute`` path. + // + // Edge contract (single domain, minimum-image): ``edge_index`` and + // ``edge_scatter_index`` coincide and index local atoms; ``edge_vec`` carries + // the minimum-image bond vector ``r_neighbor - r_center``. The SeZM force is + // ``dE/d(edge_vec)`` scattered through ``edge_scatter_index``, so the + // per-atom force lands directly on local atoms with no ghost fold-back. + if (!gpu_enabled) { + throw deepmd::deepmd_exception( + "compute_edges_gpu requires a CUDA device but the model was loaded on " + "CPU."); + } + if (!do_atomic_virial) { + throw deepmd::deepmd_exception( + "compute_edges_gpu always returns the per-atom virial, but this .pt2 " + "model was exported without it (do_atomic_virial=False)."); + } + if (!lower_input_is_edge_ && !lower_input_is_graph_) { + throw deepmd::deepmd_exception( + "compute_edges_gpu requires an edge-input (SeZM/DPA4) or graph-input " + "(DPA1/DPA2/DPA3) .pt2 model."); + } + translate_error([&] { + const torch::Device device(torch::kCUDA, gpu_id); + const c10::DeviceGuard device_guard(device); + const auto opt_f64 = + torch::TensorOptions().dtype(torch::kFloat64).device(device); + const auto opt_i32 = + torch::TensorOptions().dtype(torch::kInt32).device(device); + const auto opt_i64 = + torch::TensorOptions().dtype(torch::kInt64).device(device); + const auto opt_bool = + torch::TensorOptions().dtype(torch::kBool).device(device); + + // === Step 1. Wrap caller GPU buffers as tensors (no copy) === + at::Tensor coord_t = + torch::from_blob(const_cast(d_coord), {1, nloc, 3}, opt_f64); + at::Tensor atype_t = + torch::from_blob(const_cast(d_atype), {1, nloc}, opt_i32) + .to(torch::kInt64); + at::Tensor edge_index_real = + torch::from_blob(const_cast(d_edge_index), {2, nedge}, opt_i32) + .to(torch::kInt64); + at::Tensor edge_vec_real = + torch::from_blob(const_cast(d_edge_vec), {nedge, 3}, opt_f64); + + // === Step 2. Append two masked dummy edges (exported-graph contract) === + at::Tensor edge_index = + torch::cat({edge_index_real, torch::zeros({2, 2}, opt_i64)}, 1); + at::Tensor edge_vec = + torch::cat({edge_vec_real, torch::zeros({2, 3}, opt_f64)}, 0); + at::Tensor edge_mask = torch::cat( + {torch::ones({nedge}, opt_bool), torch::zeros({2}, opt_bool)}); + // Single-domain scheme: the force-scatter index is the (local) edge graph. + const at::Tensor& edge_scatter_index = edge_index; + + // === Step 3. Optional model inputs (defaults for fparam/charge_spin) === + at::Tensor fparam_tensor; + if (dfparam > 0) { + if (!(has_default_fparam_ && !default_fparam_.empty())) { + throw deepmd::deepmd_exception( + "compute_edges_gpu: model requires fparam but no default_fparam is " + "stored in the .pt2 metadata."); + } + fparam_tensor = + torch::from_blob( + const_cast(default_fparam_.data()), + {1, static_cast(default_fparam_.size())}, + torch::TensorOptions().dtype(torch::kFloat64)) + .clone() + .to(device); + } else { + fparam_tensor = torch::zeros({0}, opt_f64); + } + if (daparam > 0) { + throw deepmd::deepmd_exception( + "compute_edges_gpu: aparam models are not supported by the GPU edge " + "path."); + } + at::Tensor aparam_tensor = torch::zeros({0}, opt_f64); + at::Tensor charge_spin_tensor; + if (dchgspin > 0) { + if (default_chg_spin_.empty()) { + throw deepmd::deepmd_exception( + "compute_edges_gpu: model requires charge_spin but no " + "default_chg_spin is stored in the .pt2 metadata."); + } + charge_spin_tensor = + torch::from_blob(const_cast(default_chg_spin_.data()), + {1, dchgspin}, + torch::TensorOptions().dtype(torch::kFloat64)) + .clone() + .to(device); + } + + // === Step 4. Run the exported model and read the per-atom outputs === + // The two lower forms share the masked edge tensors but differ in both the + // input set and the output naming, so each form runs and unpacks itself; + // the result is always per-atom energy (nloc), force (nloc, 3) and virial + // (nloc, 9), copied out below. + at::Tensor ae, force_t, av; + std::map out; + if (lower_input_is_graph_) { + // Graph (DPA1/DPA2/DPA3 NeighborGraph): single-frame node count and a + // flat node-major atype; the model returns the high-level per-atom + // quantities. + const at::Tensor n_node = + torch::full({1}, static_cast(nloc), opt_i64); + extract_outputs( + out, run_model_graph(atype_t.reshape({nloc}), n_node, edge_index, + edge_vec, edge_mask, fparam_tensor, + aparam_tensor, charge_spin_tensor)); + ae = out["atom_energy"].reshape({nloc}).contiguous(); + force_t = out["force"].reshape({nloc, 3}).contiguous(); + av = out["atom_virial"].reshape({nloc, 9}).contiguous(); + } else { + // Edge (SeZM/DPA4): coord + edge_scatter_index; the model returns the raw + // reduced-energy derivatives (force/virial per extended atom). + extract_outputs( + out, run_model_edges(coord_t, atype_t, edge_index, edge_vec, + edge_scatter_index, edge_mask, fparam_tensor, + aparam_tensor, charge_spin_tensor)); + ae = out["energy"].reshape({nloc}).contiguous(); + force_t = + out["energy_derv_r"].squeeze(-2).reshape({nloc, 3}).contiguous(); + av = out["energy_derv_c"].squeeze(-2).reshape({nloc, 9}).contiguous(); + } + + // === Step 5. Copy per-atom outputs into caller GPU buffers (D2D) === + torch::from_blob(d_atom_energy, {nloc}, opt_f64).copy_(ae); + torch::from_blob(d_force, {nloc, 3}, opt_f64).copy_(force_t); + torch::from_blob(d_atom_virial, {nloc, 9}, opt_f64).copy_(av); + }); +} #endif diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index 6209409f55..c908be18c7 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -163,6 +163,15 @@ void DeepSpinPTExpt::init(const std::string& model, } } + // Native spin shares the energy edge ABI; the deepspin scheme keeps the nlist + // contract. Pre-edge spin archives lack the field and default to nlist. + if (metadata.obj_val.count("lower_input_kind")) { + lower_input_is_edge_ = + metadata["lower_input_kind"].as_string() == "edge_vec"; + } else { + lower_input_is_edge_ = false; + } + type_map.clear(); for (const auto& v : metadata["type_map"].as_array()) { type_map.push_back(v.as_string()); @@ -259,6 +268,88 @@ std::vector DeepSpinPTExpt::run_model( return loader->run(inputs); } +std::vector DeepSpinPTExpt::run_model_edges( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& edge_index, + const torch::Tensor& edge_vec, + const torch::Tensor& edge_scatter_index, + const torch::Tensor& edge_mask, + const torch::Tensor& spin, + const torch::Tensor& fparam, + const torch::Tensor& aparam) { + // Native-spin edge ABI: the energy edge inputs followed by the + // per-local-atom spin leaf, then the optional fparam / aparam / charge_spin. + std::vector inputs = { + coord, atype, edge_index, edge_vec, edge_scatter_index, edge_mask, spin}; + if (dfparam > 0) { + inputs.push_back(fparam); + } + if (daparam > 0) { + inputs.push_back(aparam); + } + if (dim_chg_spin > 0) { + auto charge_spin = torch::tensor(default_chg_spin_, coord.options()) + .view({1, dim_chg_spin}) + .expand({coord.size(0), dim_chg_spin}) + .contiguous(); + inputs.push_back(charge_spin); + } + return loader->run(inputs); +} + +std::vector DeepSpinPTExpt::run_model_edges_with_comm( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& extended_atype, + const torch::Tensor& edge_index, + const torch::Tensor& edge_vec, + const torch::Tensor& edge_scatter_index, + const torch::Tensor& edge_mask, + const torch::Tensor& spin, + const torch::Tensor& fparam, + const torch::Tensor& aparam, + const std::vector& comm_tensors) { + if (!with_comm_loader) { + throw deepmd::deepmd_exception( + "DeepSpinPTExpt::run_model_edges_with_comm called but the with-comm " + "artifact is not available. Either the .pt2 file has no with-comm " + "artifact compiled, or the artifact was present in the .pt2 metadata " + "but failed to load at init time (see earlier stderr log). Multi-rank " + "LAMMPS requires a working with-comm artifact."); + } + if (comm_tensors.size() != 8) { + throw deepmd::deepmd_exception( + "DeepSpinPTExpt::run_model_edges_with_comm: comm_tensors must contain " + "exactly 8 tensors. Got " + + std::to_string(comm_tensors.size()) + "."); + } + // Native-spin parallel ABI: the energy edge with-comm inputs (coord and the + // extended types span the extended node set) followed by the EXTENDED + // per-node spin leaf, then the optional fparam / aparam / charge_spin, then + // the eight border_op communication tensors. + std::vector inputs = {coord, atype, extended_atype, + edge_index, edge_vec, edge_scatter_index, + edge_mask, spin}; + if (dfparam > 0) { + inputs.push_back(fparam); + } + if (daparam > 0) { + inputs.push_back(aparam); + } + if (dim_chg_spin > 0) { + auto charge_spin = torch::tensor(default_chg_spin_, coord.options()) + .view({1, dim_chg_spin}) + .expand({coord.size(0), dim_chg_spin}) + .contiguous(); + inputs.push_back(charge_spin); + } + for (const auto& t : comm_tensors) { + inputs.push_back(t); + } + return with_comm_loader->run(inputs); +} + std::vector DeepSpinPTExpt::run_model_with_comm( const torch::Tensor& coord, const torch::Tensor& atype, @@ -482,13 +573,12 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, if (ago == 0) { nlist_data.copy_from_nlist(lmp_list, nall - nghost); nlist_data.shuffle_exclude_empty(fwd_map); - nlist_data.padding(); - // Rebuild mapping tensor. Phantom slots (when phantom_n > 0) get - // identity entries — they index into their own row and never appear - // in any other atom's nlist (their nlist rows are all -1 below). + // Rebuild mapping. Phantom slots (when phantom_n > 0) get identity + // entries — they index into their own row and never appear in any other + // atom's nlist (their nlist rows are all -1 below). + std::vector mapping(nall_real); if (lmp_list.mapping) { - std::vector mapping(nall_real); for (int ii = 0; ii < phantom_n; ii++) { mapping[ii] = ii; } @@ -504,36 +594,49 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii - phantom_n]]] + phantom_n; } - mapping_tensor = - torch::from_blob(mapping.data(), {1, nall_real}, int_option) - .clone() - .to(device); } else { // Identity fallback. See DeepPotPTExpt::compute_inner for the // invariant rationale: this branch is only reached when the // model is non-message-passing, nghost==0, or use_with_comm is // true (border_op fills ghosts); other configurations were // rejected by the fail-fast above. - std::vector mapping(nall_real); for (int ii = 0; ii < nall_real; ii++) { mapping[ii] = ii; } - mapping_tensor = - torch::from_blob(mapping.data(), {1, nall_real}, int_option) - .clone() - .to(device); } + mapping_tensor = + torch::from_blob(mapping.data(), {1, nall_real}, int_option) + .clone() + .to(device); - // Flatten raw nlist — the .pt2 model sorts by distance on-device. - // Phantom rows (all -1) are prepended below so the AOTI graph sees - // nloc == phantom_n + nloc_real_orig instead of 0. - firstneigh_tensor = - createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device); - if (phantom_n > 0) { - auto phantom_rows = torch::full( - {1, phantom_n, nnei}, static_cast(-1), - torch::TensorOptions().dtype(torch::kInt64).device(device)); - firstneigh_tensor = torch::cat({phantom_rows, firstneigh_tensor}, 1); + if (lower_input_is_edge_) { + // Native spin reuses the energy edge ABI: cache only the real skin + // topology and recompute the model-cutoff edge vectors on-device every + // step (see DeepPotPTExpt.cc). Single-rank folds ghost neighbours onto + // their local owners (``fold_to_local=true``); multi-rank indexes the + // extended node set directly (``fold_to_local=false``) so ghost node + // features -- including the per-node spin embedding -- can be exchanged + // across ranks via border_op. + const auto edge_tensors = createEdgeTensors( + nlist_data.jlist, dcoord, mapping, nloc, nall_real, device, + /*with_geometry=*/false, /*row_centers=*/&nlist_data.ilist, + /*fold_to_local=*/!use_with_comm); + edge_index_tensor = edge_tensors.edge_index; + edge_index_ext_tensor = edge_tensors.edge_index_ext; + } else { + nlist_data.padding(); + // Flatten raw nlist — the .pt2 model sorts by distance on-device. + // Phantom rows (all -1) are prepended below so the AOTI graph sees + // nloc == phantom_n + nloc_real_orig instead of 0. + firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei) + .to(torch::kInt64) + .to(device); + if (phantom_n > 0) { + auto phantom_rows = torch::full( + {1, phantom_n, nnei}, static_cast(-1), + torch::TensorOptions().dtype(torch::kInt64).device(device)); + firstneigh_tensor = torch::cat({phantom_rows, firstneigh_tensor}, 1); + } } } @@ -599,7 +702,39 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, if (use_with_comm) { bool has_null_atoms = (nall_real < nall); std::vector comm_tensors; - if (has_null_atoms) { + if (phantom_n > 0) { + // Empty subdomain: the phantom prefix shifts every node index by + // ``phantom_n`` (received ghost features land at [phantom_n, nall)), so + // the forwarded send-list -- built in the real-atom node space, remapped + // when NULL-type atoms were filtered -- is offset to match. Without the + // offset border_op forwards the zeroed phantom slots instead of the + // relayed ghost features (see DeepPotPTExpt.cc for the full rationale). + if (has_null_atoms) { + deepmd::remap_comm_sendlist(remapped_sendlist, remapped_sendnum, + remapped_recvnum, lmp_list, fwd_map); + } else { + remapped_sendlist.resize(lmp_list.nswap); + remapped_sendnum.assign(lmp_list.sendnum, + lmp_list.sendnum + lmp_list.nswap); + remapped_recvnum.assign(lmp_list.recvnum, + lmp_list.recvnum + lmp_list.nswap); + for (int iswap = 0; iswap < lmp_list.nswap; ++iswap) { + remapped_sendlist[iswap].assign( + lmp_list.sendlist[iswap], + lmp_list.sendlist[iswap] + lmp_list.sendnum[iswap]); + } + } + remapped_sendlist_ptrs.resize(lmp_list.nswap); + for (int iswap = 0; iswap < lmp_list.nswap; ++iswap) { + for (int& idx : remapped_sendlist[iswap]) { + idx += phantom_n; + } + remapped_sendlist_ptrs[iswap] = remapped_sendlist[iswap].data(); + } + comm_tensors = deepmd::ptexpt::build_comm_tensors_positional( + lmp_list, remapped_sendlist_ptrs.data(), remapped_sendnum.data(), + remapped_recvnum.data(), phantom_n, nghost_real); + } else if (has_null_atoms) { comm_tensors = deepmd::ptexpt::build_comm_tensors_positional_with_virtual_atoms( lmp_list, fwd_map, nloc, nghost_real, remapped_sendlist, @@ -609,9 +744,56 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, lmp_list, lmp_list.sendlist, lmp_list.sendnum, lmp_list.recvnum, nloc, nghost_real); } - flat_outputs = run_model_with_comm( - coord_Tensor, atype_Tensor, spin_Tensor, firstneigh_tensor, - mapping_tensor, fparam_tensor, aparam_tensor, comm_tensors); + if (lower_input_is_edge_) { + // Native spin multi-rank: edges index the extended node set + // (fold_to_local=false above), the EXTENDED per-node spin feeds the + // descriptor (ghost spins arrive via the LAMMPS sp forward-comm), and + // border_op exchanges ghost node features between interaction blocks. + // The conservative and magnetic forces both return extended and are + // folded onto owners by the LAMMPS force / spin reverse-comm. + if (phantom_n > 0) { + // Empty rank: coord/atype/spin already carry the phantom prefix; supply + // two masked self-edges (edge_mask=false) so the graph runs at + // nedge>=2 / nloc>=2 with zero physical contribution. Real ghost + // features still arrive via border_op at slots [phantom_n, nall). + const auto bool_option = + torch::TensorOptions().device(torch::kCPU).dtype(torch::kBool); + at::Tensor ph_edge_index = torch::zeros({2, 2}, int_option).to(device); + at::Tensor ph_edge_vec = torch::zeros({2, 3}, options).to(device); + at::Tensor ph_edge_mask = torch::zeros({2}, bool_option).to(device); + flat_outputs = run_model_edges_with_comm( + coord_Tensor, atype_Tensor.slice(1, 0, nloc), atype_Tensor, + ph_edge_index, ph_edge_vec, ph_edge_index, ph_edge_mask, + spin_Tensor, fparam_tensor, aparam_tensor, comm_tensors); + } else { + const auto edge_tensors = + compactEdgeTensors(edge_index_tensor, edge_index_ext_tensor, + coord_Tensor, static_cast(rcut)); + flat_outputs = run_model_edges_with_comm( + coord_Tensor, atype_Tensor.slice(1, 0, nloc), atype_Tensor, + edge_tensors.edge_index, edge_tensors.edge_vec, + edge_tensors.edge_index_ext, edge_tensors.edge_mask, spin_Tensor, + fparam_tensor, aparam_tensor, comm_tensors); + } + } else { + flat_outputs = run_model_with_comm( + coord_Tensor, atype_Tensor, spin_Tensor, firstneigh_tensor, + mapping_tensor, fparam_tensor, aparam_tensor, comm_tensors); + } + } else if (lower_input_is_edge_) { + // Native spin edge path (single-rank): recompute the model-cutoff edge + // vectors from the cached skin topology and feed only the owned-atom + // spins; the conservative force stays extended (folded back like the + // energy model), while the magnetic force is already per-local-atom + // (zero-padded to nall inside the graph). + const auto edge_tensors = + compactEdgeTensors(edge_index_tensor, edge_index_ext_tensor, + coord_Tensor, static_cast(rcut)); + flat_outputs = run_model_edges( + coord_Tensor, atype_Tensor.slice(1, 0, nloc), edge_tensors.edge_index, + edge_tensors.edge_vec, edge_tensors.edge_index_ext, + edge_tensors.edge_mask, spin_Tensor.slice(1, 0, nloc), fparam_tensor, + aparam_tensor); } else { flat_outputs = run_model(coord_Tensor, atype_Tensor, spin_Tensor, firstneigh_tensor, @@ -887,14 +1069,23 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, torch::from_blob(atype_64.data(), {1, nall}, int_options) .clone() .to(device); - // Flatten raw nlist — the .pt2 model sorts by distance on-device. - at::Tensor nlist_tensor = - createNlistTensor(nlist_raw, nnei).to(torch::kInt64).to(device); std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); at::Tensor mapping_tensor = torch::from_blob(mapping_64.data(), {1, nall}, int_options) .clone() .to(device); + at::Tensor nlist_tensor; + EdgeTensorPack edge_tensors; + if (lower_input_is_edge_) { + // Native spin edge ABI: build the full edge schema once (no cached skin + // topology in the standalone path), folding ghosts onto local owners. + edge_tensors = createEdgeTensors(nlist_raw, coord_cpy_d, mapping_64, nloc, + nall, device); + } else { + // Flatten raw nlist — the .pt2 model sorts by distance on-device. + nlist_tensor = + createNlistTensor(nlist_raw, nnei).to(torch::kInt64).to(device); + } // Build fparam/aparam tensors auto valuetype_options = std::is_same::value @@ -936,10 +1127,20 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } - // 5. Run the .pt2 model (7 args for spin) - auto flat_outputs = - run_model(coord_Tensor, atype_Tensor, spin_Tensor, nlist_tensor, - mapping_tensor, fparam_tensor, aparam_tensor); + // 5. Run the .pt2 model: native spin uses the energy edge ABI plus the + // owned-atom spins; the deepspin scheme keeps the 7-arg nlist contract. + std::vector flat_outputs; + if (lower_input_is_edge_) { + flat_outputs = run_model_edges( + coord_Tensor, atype_Tensor.slice(1, 0, nloc), edge_tensors.edge_index, + edge_tensors.edge_vec, edge_tensors.edge_index_ext, + edge_tensors.edge_mask, spin_Tensor.slice(1, 0, nloc), fparam_tensor, + aparam_tensor); + } else { + flat_outputs = + run_model(coord_Tensor, atype_Tensor, spin_Tensor, nlist_tensor, + mapping_tensor, fparam_tensor, aparam_tensor); + } // 6. Extract outputs std::map output_map; diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index 7af7198317..3b87e859f9 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -991,6 +991,27 @@ void PairDeepMD::coeff(int narg, char** arg) { } } if (narg <= 2) { + // A bare `pair_coeff * *` maps LAMMPS atom types onto the model's first + // ntypes elements by position, which is only unambiguous when the model + // type_map has exactly ntypes entries. A larger type_map (e.g. a + // periodic-table pretrained or fine-tuned model) would silently mislabel + // the species, so the elements must be named explicitly. + std::string type_map_str; + deep_pot.get_type_map(type_map_str); + std::istringstream iss(type_map_str); + std::string type_name; + int model_ntypes = 0; + while (iss >> type_name) { + ++model_ntypes; + } + if (model_ntypes > n) { + error->all(FLERR, "pair_coeff * * is ambiguous: the model defines " + + std::to_string(model_ntypes) + + " element types but the system has " + + std::to_string(n) + + " atom types; list the elements explicitly, e.g. " + "pair_coeff * * O H."); + } type_idx_map.resize(n); for (int ii = 0; ii < n; ++ii) { type_idx_map[ii] = ii; diff --git a/source/lmp/pair_deepspin.cpp b/source/lmp/pair_deepspin.cpp index 30ca48576f..2dad2912f0 100644 --- a/source/lmp/pair_deepspin.cpp +++ b/source/lmp/pair_deepspin.cpp @@ -861,6 +861,27 @@ void PairDeepSpin::coeff(int narg, char** arg) { } } if (narg <= 2) { + // A bare `pair_coeff * *` maps LAMMPS atom types onto the model's first + // ntypes elements by position, which is only unambiguous when the model + // type_map has exactly ntypes entries. A larger type_map (e.g. a + // periodic-table pretrained or fine-tuned model) would silently mislabel + // the species, so the elements must be named explicitly. + std::string type_map_str; + deep_spin.get_type_map(type_map_str); + std::istringstream iss(type_map_str); + std::string type_name; + int model_ntypes = 0; + while (iss >> type_name) { + ++model_ntypes; + } + if (model_ntypes > n) { + error->all(FLERR, "pair_coeff * * is ambiguous: the model defines " + + std::to_string(model_ntypes) + + " element types but the system has " + + std::to_string(n) + + " atom types; list the elements explicitly, e.g. " + "pair_coeff * * Fe C."); + } type_idx_map.resize(n); for (int ii = 0; ii < n; ++ii) { type_idx_map[ii] = ii; diff --git a/source/op/pt/comm.cc b/source/op/pt/comm.cc index 31691d5e7d..e00938883a 100644 --- a/source/op/pt/comm.cc +++ b/source/op/pt/comm.cc @@ -354,6 +354,22 @@ class Border : public torch::autograd::Function { recv_g1_tensor.slice(0, 0, nrecv)); } } + // When the forward ran swaps it overwrites every ghost row + // (g_out[ghost] = g_in[owner]), so a ghost INPUT value never reaches the + // output and its gradient is exactly zero. The reverse-comm loop above has + // already routed every ghost output-gradient into its owner row, but it + // leaves the upstream ghost-row gradients in ``d_local_g1_tensor`` + // untouched. Zeroing them is what makes this the true Jacobian-vector + // product: without it the op returns a spurious + // dL/dg_in[ghost] = dL/dg_out[ghost], which is harmless for an + // edge-geometry leaf (ghost rows feed no edge there) but corrupts any + // per-node leaf that flows through the exchanged features (e.g. the native + // spin magnitude/direction carried in the node state). With ``nswap == 0`` + // the forward is the identity and the ghost inputs are preserved, so the + // gradient passes through unchanged. + if (nswap > 0 && nghost > 0) { + d_local_g1_tensor.slice(0, nlocal, ntotal).zero_(); + } #ifdef USE_MPI // Drain pending eager-send ACKs before returning — see forward_t // for the full rationale. Backward has the same asymmetric diff --git a/source/tests/common/dpmodel/test_dpa4_so3_projector.py b/source/tests/common/dpmodel/test_dpa4_so3_projector.py index 8aa2a69ec5..be851ce4cb 100644 --- a/source/tests/common/dpmodel/test_dpa4_so3_projector.py +++ b/source/tests/common/dpmodel/test_dpa4_so3_projector.py @@ -140,10 +140,14 @@ def test_kmax_zero_zonal() -> None: projector.to_grid_mat[:, 1:], zonal, atol=1e-14, rtol=1e-14 ) - # the single-frame projector still round-trips legal coefficients + # the single-frame projector still round-trips legal coefficients; the + # lmax=6 recovery residual sits at ~1e-12 (float64 re-association in the + # Wigner-D monomial products accumulates through the round-trip), so it is + # asserted to 1e-11 for the same reason as + # ``test_roundtrip_preserves_legal_frame_coeffs``. rng = np.random.default_rng(99) x = rng.standard_normal((2, projector.coeff_dim, 2)).astype(np.float64) mask = _legal_so3_frame_mask(projector) x[:, ~mask, :] = 0.0 y = projector.from_grid(projector.to_grid(x)) - np.testing.assert_allclose(y[:, mask, :], x[:, mask, :], atol=1e-12, rtol=1e-12) + np.testing.assert_allclose(y[:, mask, :], x[:, mask, :], atol=1e-11, rtol=1e-11) diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index e47e6e2837..111436fd3b 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -49,6 +49,8 @@ p_examples / "dos" / "train" / "input_torch.json", p_examples / "spin" / "se_e2_a" / "input_tf.json", p_examples / "spin" / "se_e2_a" / "input_torch.json", + p_examples / "spin" / "dpa4" / "input.json", + p_examples / "spin" / "dpa4" / "input-deepspin.json", p_examples / "dprc" / "normal" / "input.json", p_examples / "dprc" / "pairwise" / "input.json", p_examples / "dprc" / "generalized_force" / "input.json", @@ -62,7 +64,6 @@ p_examples / "water" / "dpa3" / "input_torch_dynamic.json", p_examples / "water" / "dpa4" / "input.json", p_examples / "water" / "dpa4" / "input-zbl.json", - p_examples / "water" / "dpa4" / "input-spin.json", p_examples / "water" / "dpa4" / "lmp" / "input.json", p_examples / "property" / "train" / "input_torch.json", p_examples / "water" / "se_e3_tebd" / "input_torch.json", diff --git a/source/tests/pt/model/test_descriptor_sezm.py b/source/tests/pt/model/test_descriptor_sezm.py index 093ef2b8d0..161e8a87e6 100644 --- a/source/tests/pt/model/test_descriptor_sezm.py +++ b/source/tests/pt/model/test_descriptor_sezm.py @@ -1,7 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import itertools import math +import os import unittest +from contextlib import ( + nullcontext, +) +from types import ( + SimpleNamespace, +) +from unittest import ( + mock, +) import torch @@ -16,6 +26,7 @@ NodeCartesianTensorProduct, SeZMDirectForceHead, SO2Linear, + SpinEmbedding, WignerDCalculator, build_cartesian_basis, build_edge_cartesian_tensors, @@ -160,6 +171,37 @@ def _assert_forward_backward_smoke(self, **model_kwargs) -> DescrptSeZM: self.assertTrue(torch.all(torch.isfinite(extended_coord.grad))) return model + def test_amp_infer_env_controls_eval_autocast(self) -> None: + """Inference AMP is sampled from env and still gated by ``use_amp``.""" + with mock.patch.dict(os.environ, {"DP_AMP_INFER": "1"}, clear=False): + enabled_model = DescrptSeZM(**_descriptor_kwargs(use_amp=True)) + disabled_model = DescrptSeZM(**_descriptor_kwargs(use_amp=False)) + + enabled_model.eval() + disabled_model.eval() + + with mock.patch("torch.autocast", return_value=nullcontext()) as autocast_mock: + with enabled_model._compute_mode_ctx(torch.device("cuda")): + pass + autocast_mock.assert_called_once_with( + device_type="cuda", + dtype=torch.bfloat16, + enabled=True, + ) + + with mock.patch("torch.autocast", return_value=nullcontext()) as autocast_mock: + with disabled_model._compute_mode_ctx(torch.device("cuda")): + pass + autocast_mock.assert_not_called() + + with mock.patch.dict(os.environ, {"DP_AMP_INFER": "0"}, clear=False): + default_model = DescrptSeZM(**_descriptor_kwargs(use_amp=True)) + default_model.eval() + with mock.patch("torch.autocast", return_value=nullcontext()) as autocast_mock: + with default_model._compute_mode_ctx(torch.device("cuda")): + pass + autocast_mock.assert_not_called() + def test_cartesian_config_wiring(self) -> None: """Each Cartesian/mixing config builds the intended submodules. @@ -284,6 +326,112 @@ def test_so3_readout_empty_edge_shrinking_schedule(self) -> None: self.assertEqual(desc.shape, (1, 2, 4)) self.assertTrue(torch.all(torch.isfinite(desc))) + def test_zero_block_descriptor(self) -> None: + """``n_blocks=0`` builds the interaction-free descriptor end to end. + + Covers the schedule collapse (empty schedules, backbone degrees falling + back to ``lmax`` plus ``extra_node_l``, no blocks), both forward entry + points, the conservative-force path, and serialization. The + coordinate-carrying paths (FiLM, GIE, read-out) are zero-initialized, so + a parameter perturbation activates them before the force check; + ``so3_readout`` glu/mlp fold the GIE ``l>0`` geometry into the scalar, + while ``none`` keeps only the env-seed scalar. + """ + dtype = torch.float32 + coord, atype, nlist = _tiny_two_atom_system(self.device, dtype=dtype) + atol, rtol = _forward_tols(dtype) + for readout, extra in (("mlp", 0), ("glu", 0), ("none", 0), ("mlp", 1)): + with self.subTest(so3_readout=readout, extra_node_l=extra): + model = DescrptSeZM( + **_descriptor_kwargs( + l_schedule=None, + n_blocks=0, + lmax=2, + mmax=1, + kmax=1, + extra_node_l=extra, + use_env_seed=True, + so3_readout=readout, + ) + ) + # Empty schedules; backbone degrees collapse onto lmax(+extra). + self.assertEqual(model.n_blocks, 0) + self.assertEqual(model.l_schedule, []) + self.assertEqual(model.m_schedule, []) + self.assertEqual(len(model.blocks), 0) + self.assertEqual(model.node_init_lmax, 2 + extra) + self.assertEqual(model.node_readout_lmax, 2 + extra) + self.assertTrue(model.use_gie) + + # Activate the zero-initialized geometry paths so the force check + # exercises genuine coordinate dependence rather than the + # all-zero gradient seen at initialization. + torch.manual_seed(0) + with torch.no_grad(): + for param in model.parameters(): + param.add_(torch.randn_like(param) * 0.2) + + extended_coord = coord.reshape(1, -1).detach().requires_grad_(True) + desc, *_ = model(extended_coord, atype, nlist) + self.assertEqual(desc.shape, (1, 2, 4)) + self.assertTrue(torch.all(torch.isfinite(desc))) + desc.sum().backward() + self.assertTrue(torch.all(torch.isfinite(extended_coord.grad))) + self.assertGreater(extended_coord.grad.abs().max().item(), 0.0) + + # Serialization restores the empty schedule and reproduces output. + data = model.serialize() + self.assertEqual(data["config"]["n_blocks"], 0) + self.assertEqual(data["config"]["l_schedule"], []) + restored = DescrptSeZM.deserialize(data) + desc2, *_ = restored(coord.reshape(1, -1), atype, nlist) + torch.testing.assert_close(desc.detach(), desc2, atol=atol, rtol=rtol) + + # Sparse-edge entry point: the latent keeps the initial degree + # ``(node_init_lmax + 1) ** 2`` since no pyramid shrinks it. + flat = coord.reshape(-1, 3) + edge_index = torch.tensor( + [[1, 0], [0, 1]], dtype=torch.long, device=self.device + ) + edge_vec = flat[edge_index[0]] - flat[edge_index[1]] + edge_mask = torch.ones(2, dtype=torch.bool, device=self.device) + desc_e, latent = model.forward_with_edges( + extended_coord=coord.reshape(1, -1), + extended_atype=atype, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + ) + self.assertEqual(desc_e.shape, (1, 2, 4)) + self.assertEqual(latent.shape, (2, (2 + extra + 1) ** 2, 1, 4)) + + def test_zero_block_without_env_seed_is_geometry_free(self) -> None: + """Without env-seed the zero-block descriptor loses all geometry. + + With no blocks, no env FiLM, and no GIE, the scalar output depends only + on the type embedding and is independent of the coordinates. This + documents that ``use_env_seed=True`` is required for a meaningful + zero-block descriptor, since the single ``use_env_seed`` switch gates the + only geometry source (GIE). + """ + coord, atype, nlist = _tiny_two_atom_system(self.device, dtype=torch.float32) + model = DescrptSeZM( + **_descriptor_kwargs( + l_schedule=None, + n_blocks=0, + lmax=2, + use_env_seed=False, + so3_readout="mlp", + ) + ) + self.assertFalse(model.use_gie) + extended_coord = coord.reshape(1, -1).detach().requires_grad_(True) + desc, *_ = model(extended_coord, atype, nlist) + self.assertTrue(torch.all(torch.isfinite(desc))) + # The descriptor does not depend on coordinates, so there is no force path. + (grad,) = torch.autograd.grad(desc.sum(), extended_coord, allow_unused=True) + self.assertIsNone(grad) + def test_forward_with_descriptor_variants(self) -> None: """Test forward/backward smoke paths for compact descriptor variants.""" cases = { @@ -616,6 +764,15 @@ def test_serialization_deserialization(self) -> None: radial_mlp=[6], ffn_neurons=8, ), + "native_spin": _descriptor_kwargs( + precision="float32", + channels=4, + n_radial=3, + radial_mlp=[6], + ffn_neurons=8, + use_spin=[True, False], + use_env_seed=True, + ), } dtype = PRECISION_DICT["float32"] for case_name, model_kwargs in cases.items(): @@ -751,6 +908,190 @@ def test_seed_reproducibility(self) -> None: ) +class TestSeZMSpinEmbedding(_SeZMTestCase): + """Test the native per-atom spin embedding and its descriptor injection.""" + + def _spin_edges( + self, dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """A two-atom edge system with one magnetic and one non-magnetic atom.""" + coord = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=dtype, device=self.device + ).view(1, -1, 3) + atype = torch.tensor([[0, 1]], dtype=torch.int64, device=self.device) + edge_index = torch.tensor( + [[1, 0], [0, 1]], dtype=torch.long, device=self.device + ) + edge_vec = torch.tensor( + [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]], dtype=dtype, device=self.device + ) + edge_mask = torch.ones(2, dtype=torch.bool, device=self.device) + return coord, atype, edge_index, edge_vec, edge_mask + + def test_cart_to_l1_intertwines_wigner_rotation(self) -> None: + """The l=1 spin map must rotate with the descriptor's Wigner-D block.""" + dtype = torch.float64 + module = SpinEmbedding( + ntypes=2, channels=4, use_spin=[True, False], dtype=dtype, seed=1 + ).to(self.device) + cart_to_l1 = module.cart_to_l1 + # cart_to_l1 == sqrt(2) * S with S the unit l=1 basis of WignerDCalculator. + expected = math.sqrt(2.0) * torch.tensor( + [[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], [1.0, 0.0, 0.0]], + dtype=dtype, + device=self.device, + ) + torch.testing.assert_close(cart_to_l1, expected, atol=1e-12, rtol=1e-12) + + quat = _random_quaternion(5, device=self.device, dtype=dtype) + rot = quaternion_to_rotation_matrix(quat) # (5, 3, 3) + wigner = WignerDCalculator(lmax=1, dtype=dtype).to(self.device) + d_full, _ = wigner(quat) + d1 = d_full[:, 1:4, 1:4] # (5, 3, 3) + vec = torch.randn(5, 3, dtype=dtype, device=self.device) + # cart_to_l1(R v) == D1 (cart_to_l1 v) + lhs = torch.einsum( + "dk,bk->bd", cart_to_l1, torch.einsum("bij,bj->bi", rot, vec) + ) + rhs = torch.einsum("bij,bj->bi", d1, torch.einsum("dk,bk->bd", cart_to_l1, vec)) + torch.testing.assert_close(lhs, rhs, atol=1e-10, rtol=1e-10) + + def test_spin_embedding_smooth_and_masked(self) -> None: + """l=1 is linear in spin (zero at s=0); non-magnetic types are gated off.""" + dtype = torch.float64 + module = SpinEmbedding( + ntypes=2, channels=4, use_spin=[True, False], dtype=dtype, seed=2 + ).to(self.device) + atype = torch.tensor([0, 0, 1], device=self.device) + spin = torch.randn(3, 3, dtype=dtype, device=self.device) + + # Linear in spin: vector(0) == 0 and vector(2 s) == 2 vector(s). + zero_scalar, zero_vec = module(torch.zeros_like(spin), atype) + self.assertTrue(torch.allclose(zero_vec, torch.zeros_like(zero_vec))) + _, vec1 = module(spin, atype) + _, vec2 = module(2.0 * spin, atype) + torch.testing.assert_close(vec2, 2.0 * vec1, atol=1e-12, rtol=1e-12) + + # Non-magnetic type (index 1) is gated to exactly zero on both branches. + scalar, vector = module(spin, atype) + self.assertTrue(torch.allclose(scalar[2], torch.zeros_like(scalar[2]))) + self.assertTrue(torch.allclose(vector[2], torch.zeros_like(vector[2]))) + + # The l=0 magnitude branch is smooth at s=0: its spin gradient vanishes. + spin_leaf = torch.zeros(2, 3, dtype=dtype, device=self.device).requires_grad_( + True + ) + scalar0, _ = module(spin_leaf, torch.tensor([0, 0], device=self.device)) + (grad,) = torch.autograd.grad(scalar0.sum(), spin_leaf) + self.assertTrue(torch.allclose(grad, torch.zeros_like(grad), atol=1e-12)) + + def test_edge_l1_equivariance_and_masking(self) -> None: + """The per-edge neighbor-spin l=1 message rotates as l=1, is linear and masked.""" + dtype = torch.float64 + module = SpinEmbedding( + ntypes=2, channels=4, use_spin=[True, False], dtype=dtype, seed=3 + ).to(self.device) + # Perturb the per-type neighbor weight off zero so the map is non-trivial. + with torch.no_grad(): + module.adam_spin_nbr_weight.copy_( + torch.randn_like(module.adam_spin_nbr_weight) + ) + atype = torch.tensor([0, 0, 1], device=self.device) # node 2 is non-magnetic + # Edges 0, 1 have magnetic sources; edge 2's source (node 2) is not. + src = torch.tensor([0, 1, 2], dtype=torch.long, device=self.device) + edge_cache = SimpleNamespace( + src=src, + edge_env=torch.rand(3, 1, dtype=dtype, device=self.device) + 0.1, + ) + spin = torch.randn(3, 3, dtype=dtype, device=self.device) + + msg = module.edge_l1(spin, atype, edge_cache) # (E, 3, C) + self.assertEqual(msg.shape, (3, 3, module.channels)) + + # Linear in spin: zero at s=0 and homogeneous of degree one. + zero = module.edge_l1(torch.zeros_like(spin), atype, edge_cache) + self.assertTrue(torch.allclose(zero, torch.zeros_like(zero))) + msg2 = module.edge_l1(2.0 * spin, atype, edge_cache) + torch.testing.assert_close(msg2, 2.0 * msg, atol=1e-12, rtol=1e-12) + + # The non-magnetic source (edge 2) contributes exactly zero. + torch.testing.assert_close(msg[2], torch.zeros_like(msg[2])) + self.assertGreater(msg[0].abs().max().item(), 0.0) + + # l=1 equivariance: rotating the spin rotates each edge message by D^1(R), + # the packed Wigner-D block conjugate to cart_to_l1. + quat = _random_quaternion(1, device=self.device, dtype=dtype) + rot = quaternion_to_rotation_matrix(quat)[0] # (3, 3) + d1 = module.cart_to_l1 @ rot @ torch.linalg.inv(module.cart_to_l1) + msg_rot = module.edge_l1( + torch.einsum("ij,nj->ni", rot, spin), atype, edge_cache + ) + expected = torch.einsum("de,nec->ndc", d1, msg) + torch.testing.assert_close(msg_rot, expected, atol=1e-10, rtol=1e-10) + + def test_descriptor_spin_joint_rotation_invariance(self) -> None: + """The scalar descriptor is invariant under a joint rotation of geometry and spin. + + Covers both spin injection routes by toggling ``use_env_seed``: the + env-seed branch adds the neighbor spin to the environment matrix, while + the backbone branch carries the on-site and neighbor-aggregated l=1. + """ + dtype = torch.float64 + coord, atype, edge_index, edge_vec, edge_mask = self._spin_edges(dtype) + spin = torch.zeros(1, 2, 3, dtype=dtype, device=self.device) + spin[0, 0] = torch.tensor([0.3, -0.7, 0.5], dtype=dtype, device=self.device) + quat = _random_quaternion(1, device=self.device, dtype=dtype) + rot = quaternion_to_rotation_matrix(quat)[0] # (3, 3) + + for use_env_seed in (False, True): + with self.subTest(use_env_seed=use_env_seed): + model = DescrptSeZM( + **_descriptor_kwargs( + precision="float64", + use_spin=[True, False], + use_env_seed=use_env_seed, + seed=7, + ) + ) + # Perturb away from the near-identity initialization so the + # rotation check exercises a non-trivial spin-dependent + # descriptor (env-seed output_proj is otherwise zero-init). + torch.manual_seed(0) + with torch.no_grad(): + for p in model.parameters(): + p.copy_(torch.randn_like(p) * 0.1) + model.eval() + + desc, _ = model.forward_with_edges( + extended_coord=coord, + extended_atype=atype, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + spin=spin, + ) + desc_rot, _ = model.forward_with_edges( + extended_coord=coord, + extended_atype=atype, + edge_index=edge_index, + edge_vec=torch.einsum("ij,ej->ei", rot, edge_vec), + edge_mask=edge_mask, + spin=torch.einsum("ij,nkj->nki", rot, spin), + ) + torch.testing.assert_close(desc, desc_rot, atol=1e-9, rtol=1e-9) + + # Spin actually changes the descriptor (injection is not a no-op). + desc_zero, _ = model.forward_with_edges( + extended_coord=coord, + extended_atype=atype, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + spin=torch.zeros_like(spin), + ) + self.assertFalse(torch.allclose(desc, desc_zero, atol=1e-6)) + + class TestBuildEdgeQuaternion(_SeZMTestCase): """Test the stable edge-quaternion chart used by SeZM.""" @@ -1160,16 +1501,18 @@ def test_equivariance_random_angles(self) -> None: ) dim_red = so2_linear.reduced_dim + # SO2Linear consumes the focus-major ``(F, E, D_m, C)`` contract, so + # the edge axis (the per-edge z-rotation batch) is axis 1. x = torch.randn( - batch, 1, dim_red, channels_in, device=self.device, dtype=dtype + 1, batch, dim_red, channels_in, device=self.device, dtype=dtype ) angles = torch.rand(batch, device=self.device, dtype=dtype) * 2 * 3.14159 Z = self._build_m_major_z_rotation(angles, lmax, mmax) - x_rotated = torch.einsum("bij,bfjc->bfic", Z, x) + x_rotated = torch.einsum("eij,fejc->feic", Z, x) lhs = so2_linear(x_rotated) - rhs = torch.einsum("bij,bfjc->bfic", Z, so2_linear(x)) + rhs = torch.einsum("eij,fejc->feic", Z, so2_linear(x)) torch.testing.assert_close( lhs, diff --git a/source/tests/pt/model/test_descriptor_sezm_triton.py b/source/tests/pt/model/test_descriptor_sezm_triton.py index e74c31d72a..4a1b77616f 100644 --- a/source/tests/pt/model/test_descriptor_sezm_triton.py +++ b/source/tests/pt/model/test_descriptor_sezm_triton.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Unit tests for the opt-in Triton inference kernels of the SeZM descriptor -(enabled via ``DP_TRITON_INFER``): the block-diagonal SO(2)/Wigner rotation and -the fused dynamic radial degree mixer. +(enabled via the ``DP_TRITON_INFER`` level, see +:func:`deepmd.kernels.utils.triton_infer_level`): the +block-diagonal SO(2)/Wigner rotation, the fused dynamic radial degree mixer, +the fused value path with its table-routed edge-block backwards, and the +level-3 fp16x3 mixing stack. For the rotation kernels two properties are checked against the eager PyTorch reference: @@ -20,6 +23,7 @@ """ import math +import typing import unittest import torch @@ -27,20 +31,14 @@ make_fx, ) -from deepmd.pt.model.descriptor.sezm_nn.indexing import ( - build_m_major_index, - get_so3_dim_of_lmax, +from deepmd.kernels.triton.sezm import ( + TRITON_AVAILABLE, ) -from deepmd.pt.model.descriptor.sezm_nn.so2 import ( - DynamicRadialDegreeMixer, -) -from deepmd.pt.model.descriptor.sezm_nn.triton.radial_mix import ( - RADIAL_MIX_TRITON_AVAILABLE, +from deepmd.kernels.triton.sezm.radial_mix import ( radial_mix_block, radial_mix_reference, ) -from deepmd.pt.model.descriptor.sezm_nn.triton.so2_rotation import ( - TRITON_ROTATION_AVAILABLE, +from deepmd.kernels.triton.sezm.so2_rotation import ( rotate_back_block, rotate_back_block_so2, rotate_back_dense, @@ -49,9 +47,24 @@ rotate_to_local_dense, rotate_to_local_reference, ) +from deepmd.pt.model.descriptor.sezm_nn.indexing import ( + build_m_major_index, + get_so3_dim_of_lmax, +) +from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + DynamicRadialDegreeMixer, +) _CUDA = torch.cuda.is_available() +# All SeZM Triton kernels ship together (every per-module availability flag +# reduces to "the triton package imports"), so one gate covers every +# GPU-kernel test class. Dispatch- and parsing-only classes stay ungated. +_GPU_KERNELS = unittest.skipUnless( + _CUDA and TRITON_AVAILABLE, + "CUDA and Triton are required for the SeZM GPU kernels", +) + def _block_diagonal_wigner(n_edge, lmax, device, dtype, generator): """Random Wigner-D that is block-diagonal by ``l`` (block ``l`` occupies @@ -129,8 +142,7 @@ def fn(x, src, wigner, coeff_index): self.assertNotIn("sezm_triton.rotate_to_local_block.default", graph_code) -@unittest.skipIf(not _CUDA, "CUDA is required for the Triton rotation kernels") -@unittest.skipIf(not TRITON_ROTATION_AVAILABLE, "Triton is not available") +@_GPU_KERNELS class TestSeZMTritonRotation(unittest.TestCase): def setUp(self): self.device = torch.device("cuda") @@ -461,8 +473,7 @@ def forward_and_grad(x_local, wigner): self.assertGreater(float(grad_w_eager.abs().max()), 0.0) -@unittest.skipIf(not _CUDA, "CUDA is required for the Triton radial-mix kernel") -@unittest.skipIf(not RADIAL_MIX_TRITON_AVAILABLE, "Triton is not available") +@_GPU_KERNELS class TestSeZMTritonRadialMix(unittest.TestCase): """Fused dynamic radial degree mixer (``degree_channel``, ``mmax == 1``). @@ -593,5 +604,947 @@ def forward_and_grad(compact, x_local): self.assertIn("sezm_triton.radial_mix_block", traced.code) +@_GPU_KERNELS +class TestSeZMTritonValuePath(unittest.TestCase): + """Cross-check the fused SO(2) value path against ``SO2Convolution``. + + The fused entry replaces ``so2_message(..., return_local=True)`` + end to end (rotation, radial degree mixing, gated mixing stack, focus + competition), so the reference is the module's own eager path. Cases + span the supported family axes: mixer-free ``lmax = 1``, the rank-1 + ``degree_channel`` mixer at small and large ``lmax``, a rank-2 mixer with + two focus streams, and a non-power-of-two focus width. + """ + + CASES: typing.ClassVar[list[tuple]] = [ + # (lmax, channels, n_focus, focus_dim, mixing_layers, mode, rank) + (1, 32, 1, 0, 3, "none", 0), + (2, 32, 1, 0, 3, "degree_channel", 1), + (4, 64, 1, 0, 4, "degree_channel", 1), + (5, 64, 2, 0, 4, "degree_channel", 2), + (3, 64, 2, 96, 4, "degree_channel", 1), + ] + + N_NODE = 512 + N_EDGE = 20000 + + def _build_conv(self, lmax, channels, n_focus, focus_dim, layers, mode, rank): + from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + SO2Convolution, + ) + + return ( + SO2Convolution( + lmax=lmax, + mmax=1, + channels=channels, + n_focus=n_focus, + focus_dim=focus_dim, + mixing_layers=layers, + radial_so2_mode=mode, + radial_so2_rank=rank, + n_atten_head=1, + dtype=torch.float32, + seed=7, + trainable=False, + ) + .to("cuda") + .eval() + ) + + def _edge_inputs(self, conv, lmax, channels): + from deepmd.pt.model.descriptor.sezm_nn.wignerd import ( + WignerDCalculator, + ) + + generator = torch.Generator(device="cuda").manual_seed(42) + dim = (lmax + 1) ** 2 + c_wide = conv.n_focus * conv.so2_focus_dim + x = ( + torch.randn(self.N_NODE, dim, c_wide, device="cuda", generator=generator) + * 0.5 + ) + src = torch.randint( + 0, self.N_NODE, (self.N_EDGE,), device="cuda", generator=generator + ) + radial = ( + torch.randn( + self.N_EDGE, lmax + 1, channels, device="cuda", generator=generator + ) + * 0.3 + ) + calculator = WignerDCalculator(lmax, dtype=torch.float32).to("cuda") + quaternion = torch.randn(self.N_EDGE, 4, device="cuda", generator=generator) + quaternion = quaternion / quaternion.norm(dim=-1, keepdim=True) + wigner = calculator(quaternion)[0] + + class _Cache: + pass + + cache = _Cache() + cache.src = src + cache.dst = torch.randint( + 0, self.N_NODE, (self.N_EDGE,), device="cuda", generator=generator + ) + cache.D_full = wigner + cache.D_to_m_cache = {} + return x, cache, radial + + def test_forward_backward_matches_reference_across_family(self): + from deepmd.kernels.triton.sezm.so2_value_path import ( + make_triton_value_path, + ) + + for case in self.CASES: + lmax, channels, n_focus, focus_dim, layers, mode, rank = case + with self.subTest(case=case): + conv = self._build_conv(*case) + fused = make_triton_value_path(conv) + self.assertIsNotNone(fused) + x, cache, radial = self._edge_inputs(conv, lmax, channels) + + x_ref = x.clone().requires_grad_(True) + rad_ref = radial.clone().requires_grad_(True) + wigner_ref = cache.D_full.clone().requires_grad_(True) + cache.D_full = wigner_ref + ref_local, _ = conv.so2_message( + x_ref, cache, rad_ref, return_local=True + ) + + x_fused = x.clone().requires_grad_(True) + rad_fused = radial.clone().requires_grad_(True) + wigner_fused = wigner_ref.detach().clone().requires_grad_(True) + cache.D_full = wigner_fused + out_local, _ = fused(x_fused, cache, rad_fused) + + scale = ref_local.abs().max().item() + torch.testing.assert_close( + out_local, ref_local, atol=5e-5 * max(scale, 1.0), rtol=1e-4 + ) + + grad_seed = torch.randn_like(ref_local) + ref_grads = torch.autograd.grad( + ref_local, [x_ref, rad_ref, wigner_ref], grad_seed + ) + fused_grads = torch.autograd.grad( + out_local, [x_fused, rad_fused, wigner_fused], grad_seed + ) + # The Wigner gradient is compared on the structural block + # diagonal only: off-block entries multiply exactly-zero + # Wigner values, so the model discards them. + mask = _block_mask(lmax, "cuda") + comparisons = [ + (ref_grads[0], fused_grads[0]), + (ref_grads[1], fused_grads[1]), + (ref_grads[2] * mask, fused_grads[2] * mask), + ] + for ref_grad, fused_grad in comparisons: + grad_scale = ref_grad.abs().max().item() + torch.testing.assert_close( + fused_grad, + ref_grad, + atol=1e-4 * max(grad_scale, 1.0), + rtol=1e-4, + ) + + def test_factory_rejects_unsupported_layouts(self): + from deepmd.kernels.triton.sezm.so2_value_path import ( + make_triton_value_path, + ) + + conv = self._build_conv(2, 32, 1, 0, 3, "degree_channel", 1) + conv.mmax = 2 + self.assertIsNone(make_triton_value_path(conv)) + + +@_GPU_KERNELS +class TestSeZMTritonWignerMonomials(unittest.TestCase): + """Check the fused quaternion monomial operator against its reference.""" + + def _exponents(self, degree): + """All exponent 4-tuples of the given total degree, flattened.""" + exps = [] + for a in range(degree + 1): + for b in range(degree + 1 - a): + for c in range(degree + 1 - a - b): + exps.extend((a, b, c, degree - a - b - c)) + return exps + + def test_forward_backward_matches_reference(self): + from deepmd.kernels.triton.sezm.wigner_monomials import ( + _monomials_reference, + wigner_monomials, + ) + + generator = torch.Generator(device="cuda").manual_seed(3) + for degree in (4, 6, 8, 10, 12): + with self.subTest(degree=degree): + exponents = self._exponents(degree) + q = torch.randn(4096, 4, device="cuda", generator=generator) + q = q / q.norm(dim=-1, keepdim=True) + + q_fused = q.clone().requires_grad_(True) + out = wigner_monomials(q_fused, exponents, degree) + q_ref = q.clone().requires_grad_(True) + ref = _monomials_reference(q_ref, exponents, degree) + torch.testing.assert_close(out, ref, atol=1e-6, rtol=1e-5) + + grad_seed = torch.randn_like(ref) + (grad_fused,) = torch.autograd.grad(out, q_fused, grad_seed) + (grad_ref,) = torch.autograd.grad(ref, q_ref, grad_seed) + torch.testing.assert_close(grad_fused, grad_ref, atol=1e-5, rtol=1e-5) + + def test_wigner_calculator_matches_reference_chain(self): + """The calculator's fused monomial path reproduces the dense chain.""" + import os + from unittest import ( + mock, + ) + + from deepmd.pt.model.descriptor.sezm_nn import wignerd as wignerd_module + + generator = torch.Generator(device="cuda").manual_seed(11) + q = torch.randn(2048, 4, device="cuda", generator=generator) + q = q / q.norm(dim=-1, keepdim=True) + for lmax in (2, 3, 4, 5): + with self.subTest(lmax=lmax): + with mock.patch.dict(os.environ, {"DP_TRITON_INFER": "1"}): + fused_calc = ( + wignerd_module.WignerDCalculator(lmax, dtype=torch.float32) + .to("cuda") + .eval() + ) + with mock.patch.dict(os.environ, {"DP_TRITON_INFER": "0"}): + ref_calc = ( + wignerd_module.WignerDCalculator(lmax, dtype=torch.float32) + .to("cuda") + .eval() + ) + self.assertTrue(fused_calc._use_triton_monomials) + got = fused_calc(q)[0] + want = ref_calc(q)[0] + torch.testing.assert_close(got, want, atol=1e-5, rtol=1e-5) + + +@_GPU_KERNELS +class TestSeZMTritonForceAssembly(unittest.TestCase): + """Check the segmented force / virial assembly against ``index_add``.""" + + def _topology(self, n_edge, n_ext, device, generator): + dst = torch.randint(0, n_ext, (n_edge,), device=device, generator=generator) + src = torch.randint(0, n_ext, (n_edge,), device=device, generator=generator) + dst_order = torch.argsort(dst) + src_order = torch.argsort(src) + boundaries = torch.arange(n_ext + 1, device=device, dtype=dst.dtype) + dst_row_ptr = torch.searchsorted(dst.index_select(0, dst_order), boundaries) + src_row_ptr = torch.searchsorted(src.index_select(0, src_order), boundaries) + return dst, src, dst_order, dst_row_ptr, src_order, src_row_ptr + + def test_matches_index_add_assembly(self): + from deepmd.kernels.triton.sezm.force_assembly import ( + edge_force_assembly, + ) + + generator = torch.Generator(device="cuda").manual_seed(5) + n_edge, n_ext = 50000, 700 + g = torch.randn(n_edge, 3, device="cuda", generator=generator) + edge_vec = torch.randn(n_edge, 3, device="cuda", generator=generator) + dst, src, dst_order, dst_row_ptr, src_order, src_row_ptr = self._topology( + n_edge, n_ext, "cuda", generator + ) + + force, virial = edge_force_assembly( + g, edge_vec, dst_order, dst_row_ptr, src_order, src_row_ptr + ) + + force_ref = torch.zeros(n_ext, 3, device="cuda") + force_ref.index_add_(0, dst, g) + force_ref.index_add_(0, src, -g) + half_w = -0.5 * torch.einsum("ek,ej->ekj", g, edge_vec).reshape(-1, 9) + virial_ref = torch.zeros(n_ext, 9, device="cuda") + virial_ref.index_add_(0, dst, half_w) + virial_ref.index_add_(0, src, half_w) + + torch.testing.assert_close(force, force_ref, atol=1e-4, rtol=1e-5) + torch.testing.assert_close(virial, virial_ref, atol=1e-4, rtol=1e-5) + + +@_GPU_KERNELS +class TestSeZMTritonFlashAttenSegmented(unittest.TestCase): + """Check the destination-segmented flash forward against the reference. + + Destinations are deliberately unsorted: the traced SeZM graph keeps + masked padding edges in arbitrary destination order, so the operator must + build its own sorted CSR topology (a sorted-input-only regression once + produced silently wrong aggregates on the compiled path). + """ + + def test_forward_matches_reference_on_unsorted_destinations(self): + from deepmd.kernels.triton.sezm.flash_atten import ( + build_row_ptr, + flash_atten_aggregate, + flash_atten_aggregate_reference, + ) + + generator = torch.Generator(device="cuda").manual_seed(9) + lmax, n_focus, focus_dim, n_head = 3, 2, 32, 2 + n_edge, n_node = 30000, 400 + reduced_dim = 3 * lmax + 1 + dim = (lmax + 1) ** 2 + + x_local = torch.randn( + n_edge, n_focus, reduced_dim, focus_dim, device="cuda", generator=generator + ) + wigner_dt = _block_diagonal_wigner( + n_edge, lmax, "cuda", torch.float32, generator + ) + rescale = torch.rand(dim, device="cuda", generator=generator) + 0.5 + alpha = torch.rand(n_edge, n_focus, n_head, device="cuda", generator=generator) + dst = torch.randint(0, n_node, (n_edge,), device="cuda", generator=generator) + row_ptr = build_row_ptr(torch.sort(dst).values, n_node) + + got = flash_atten_aggregate( + x_local, wigner_dt, rescale, alpha, row_ptr, dst, lmax, n_head + ) + want = flash_atten_aggregate_reference( + x_local, wigner_dt, rescale, alpha, dst, n_node, lmax, n_head + ) + torch.testing.assert_close(got, want, atol=1e-4, rtol=1e-5) + + def test_backward_matches_reference_on_both_dispatch_paths(self): + """The backward is exact on the edge-block and the per-edge dispatch. + + The routing table keys on ``(C_wide, lmax)``. Both routes are pinned + through explicit runtime registrations (an entry for the narrow case, + ``None`` for the wide one) so the test exercises both kernels + regardless of the built-in coverage of the running GPU. + """ + from deepmd.kernels.triton.sezm import ( + tile_configs, + ) + from deepmd.kernels.triton.sezm.flash_atten import ( + _flash_atten_backward_reference, + _flash_bwd_op, + ) + + runtime = tile_configs._RUNTIME["flash_bwd_block"] + saved = dict(runtime) + self.addCleanup(lambda: (runtime.clear(), runtime.update(saved))) + tile_configs.register_tile_configs( + "flash_bwd_block", {(64, 3): (4, 2, 2), (256, 3): None} + ) + + generator = torch.Generator(device="cuda").manual_seed(13) + cases = [ + # (lmax, n_focus, focus_dim, expects_block_dispatch) + (3, 2, 32, True), + (3, 2, 128, False), + ] + for lmax, n_focus, focus_dim, expects_block in cases: + with self.subTest(lmax=lmax, c_wide=n_focus * focus_dim): + self.assertEqual( + tile_configs.flash_bwd_block_config(n_focus * focus_dim, lmax) + is not None, + expects_block, + ) + n_edge, n_node, n_head = 20000, 300, 1 + reduced_dim = 3 * lmax + 1 + dim = (lmax + 1) ** 2 + grad_pre_gate = torch.randn( + n_node, dim, n_focus * focus_dim, device="cuda", generator=generator + ) + x_local = torch.randn( + n_edge, + n_focus, + reduced_dim, + focus_dim, + device="cuda", + generator=generator, + ) + wigner_dt = _block_diagonal_wigner( + n_edge, lmax, "cuda", torch.float32, generator + ) + rescale = torch.rand(dim, device="cuda", generator=generator) + 0.5 + alpha = torch.rand( + n_edge, n_focus, n_head, device="cuda", generator=generator + ) + dst = torch.randint( + 0, n_node, (n_edge,), device="cuda", generator=generator + ) + + got = _flash_bwd_op( + grad_pre_gate, x_local, wigner_dt, rescale, alpha, dst, lmax, n_head + ) + want = _flash_atten_backward_reference( + grad_pre_gate, x_local, wigner_dt, rescale, alpha, dst, lmax, n_head + ) + mask = _block_mask(lmax, "cuda") + comparisons = [ + (got[0], want[0]), + (got[1] * mask, want[1] * mask), + (got[2], want[2]), + ] + for got_grad, want_grad in comparisons: + scale = want_grad.abs().max().item() + torch.testing.assert_close( + got_grad, + want_grad, + atol=1e-4 * max(scale, 1.0), + rtol=1e-4, + ) + + +class TestTritonInferLevel(unittest.TestCase): + """Parse and reject semantics of the ``DP_TRITON_INFER`` level.""" + + def test_levels_parse_and_non_numeric_values_are_rejected(self): + import os + from unittest import ( + mock, + ) + + from deepmd.kernels.utils import ( + triton_infer_level, + ) + + for raw, expected in (("0", 0), ("1", 1), ("2", 2), ("3", 3), (" 2 ", 2)): + with mock.patch.dict(os.environ, {"DP_TRITON_INFER": raw}): + self.assertEqual(triton_infer_level(), expected) + with mock.patch.dict(os.environ, clear=False): + os.environ.pop("DP_TRITON_INFER", None) + self.assertEqual(triton_infer_level(), 0) + for raw in ("on", "off", "true", "false", "yes", "4", "-1"): + with ( + mock.patch.dict(os.environ, {"DP_TRITON_INFER": raw}), + self.assertRaises(ValueError), + ): + triton_infer_level() + + +@_GPU_KERNELS +class TestSeZMStackFP16x3(unittest.TestCase): + """Correctness of the level-3 fp16x3 mixing-stack operator. + + The accuracy contract is relative: the fp16x3 error against an fp64 + reference must stay within a small factor of the fp32 operator's own + rounding error on identical data (absolute thresholds mis-fire because + fp32 rounding grows with the reduction width). Finiteness is asserted + over the full tensors and across input magnitude scales spanning the + documented dynamic-range protections. + """ + + LMAX = 3 + FOCUS_DIM = 32 + N_FOCUS = 2 + N_LAYERS = 3 + N_EDGE = 20000 + + def setUp(self) -> None: + """Pin launch configurations so the tests run on any CUDA device. + + The fp16x3 operator refuses shapes without a table entry, and the + built-in tables only cover swept GPU models. On uncovered devices a + conservative single-stage configuration is registered for the test + shapes: ``num_stages = 1`` disables the software pipeliner, which is + the component whose miscompilation the validated tables guard + against, and the fp64 comparisons of this class independently verify + the numerics on whatever device runs the suite. + """ + from deepmd.kernels.triton.sezm import ( + tile_configs, + ) + + for key in ((self.FOCUS_DIM, self.LMAX), (32, 2)): + if tile_configs.stack_fp16x3_configs(*key) is None: + tile_configs.register_tile_configs( + "stack_fp16x3", {key: ((64, 64, 32, 4, 1),) * 4} + ) + self.addCleanup( + lambda key=key: tile_configs._RUNTIME["stack_fp16x3"].pop(key, None) + ) + + def _stack_inputs(self, generator): + lmax, focus_dim, n_focus = self.LMAX, self.FOCUS_DIM, self.N_FOCUS + m0 = (lmax + 1) * focus_dim + half = lmax * focus_dim + row = (3 * lmax + 1) * focus_dim + + def randn(*shape): + return torch.randn(*shape, device="cuda", generator=generator) + + u0 = randn(n_focus, self.N_EDGE, row) + alpha = ( + torch.rand(self.N_EDGE, n_focus, device="cuda", generator=generator) + 0.1 + ) + w0_all = randn(self.N_LAYERS, n_focus, m0, m0) * 0.2 + block_u = randn(self.N_LAYERS, n_focus, half, half) * 0.2 + block_v = randn(self.N_LAYERS, n_focus, half, half) * 0.2 + # The |m| = 1 weight carries the [[U, V], [-V, U]] complex structure + # of SO2Linear so the synthetic stack matches the production operator. + w1_all = torch.cat( + [ + torch.cat([block_u, block_v], dim=3), + torch.cat([-block_v, block_u], dim=3), + ], + dim=2, + ).contiguous() + gw_all = randn(self.N_LAYERS - 1, n_focus, focus_dim, half) * 0.3 + return u0, alpha, w0_all, w1_all, gw_all + + def _errors_against_fp64(self, op, u0, alpha, w0_all, w1_all, gw_all, grad_seed): + from deepmd.kernels.triton.sezm.so2_value_path import ( + _mixing_stack_reference, + ) + + u0_ref = u0.double().requires_grad_(True) + alpha_ref = alpha.double().requires_grad_(True) + x_ref, _ = _mixing_stack_reference( + u0_ref, + alpha_ref, + w0_all.double(), + w1_all.double(), + gw_all.double(), + self.LMAX, + self.FOCUS_DIM, + True, + ) + gu_ref, _ = torch.autograd.grad(x_ref, [u0_ref, alpha_ref], grad_seed.double()) + + u0_run = u0.clone().requires_grad_(True) + alpha_run = alpha.clone().requires_grad_(True) + x_run, z_run = op( + u0_run, alpha_run, w0_all, w1_all, gw_all, self.LMAX, self.FOCUS_DIM, True + ) + self.assertTrue(bool(torch.isfinite(x_run).all())) + self.assertTrue(bool(torch.isfinite(z_run).all())) + gu_run, _ = torch.autograd.grad(x_run, [u0_run, alpha_run], grad_seed) + self.assertTrue(bool(torch.isfinite(gu_run).all())) + + def relerr(a, b): + return float((a - b.float()).abs().max() / b.abs().max().clamp_min(1e-30)) + + return relerr(x_run, x_ref), relerr(gu_run, gu_ref) + + def test_matches_fp64_within_fp32_error_budget(self): + from deepmd.kernels.triton.sezm.so2_stack_fp16x3 import ( + mixing_stack_fp16x3, + ) + from deepmd.kernels.triton.sezm.so2_value_path import ( + _mixing_stack_op, + ) + + generator = torch.Generator(device="cuda").manual_seed(21) + inputs = self._stack_inputs(generator) + row = (3 * self.LMAX + 1) * self.FOCUS_DIM + grad_seed = torch.randn( + self.N_EDGE, self.N_FOCUS, row, device="cuda", generator=generator + ) + + fp32_fwd, fp32_bwd = self._errors_against_fp64( + _mixing_stack_op, *inputs, grad_seed + ) + x3_fwd, x3_bwd = self._errors_against_fp64( + mixing_stack_fp16x3, *inputs, grad_seed + ) + self.assertLess(x3_fwd, max(3.0 * fp32_fwd, 2e-6)) + self.assertLess(x3_bwd, max(3.0 * fp32_bwd, 8e-6)) + + def test_extreme_input_scales_stay_finite_and_accurate(self): + """The tail scaling and the activation prescale hold across magnitudes. + + Inputs four orders of magnitude below and two above the typical + working point must stay finite with the error budget intact; this + pins the ``2^11`` tail scaling (small magnitudes) and the ``2^-4`` + activation prescale (large magnitudes). + """ + from deepmd.kernels.triton.sezm.so2_stack_fp16x3 import ( + mixing_stack_fp16x3, + ) + from deepmd.kernels.triton.sezm.so2_value_path import ( + _mixing_stack_op, + ) + + generator = torch.Generator(device="cuda").manual_seed(22) + u0, alpha, w0_all, w1_all, gw_all = self._stack_inputs(generator) + row = (3 * self.LMAX + 1) * self.FOCUS_DIM + grad_seed = torch.randn( + self.N_EDGE, self.N_FOCUS, row, device="cuda", generator=generator + ) + for scale in (1e-4, 1e2): + with self.subTest(scale=scale): + scaled = (u0 * scale, alpha, w0_all, w1_all, gw_all) + fp32_fwd, fp32_bwd = self._errors_against_fp64( + _mixing_stack_op, *scaled, grad_seed + ) + x3_fwd, x3_bwd = self._errors_against_fp64( + mixing_stack_fp16x3, *scaled, grad_seed + ) + self.assertLess(x3_fwd, max(3.0 * fp32_fwd, 2e-6)) + self.assertLess(x3_bwd, max(3.0 * fp32_bwd, 8e-6)) + + def test_inductor_compiled_matches_eager(self): + """The Inductor-lowered operator is bitwise identical to eager. + + Guards the weight fp16 splits: the tail of a split is defined by an + ``fp32 -> fp16 -> fp32`` rounding round-trip, which Inductor's + pointwise fusion elides when the split is expressed in aten (the + intermediate stays in an fp32 register), zeroing the tails and + silently degrading the compiled operator to fp16-head weights. The + split therefore runs as a Triton kernel, and this test pins the + compiled-versus-eager parity through the same make_fx + Inductor + pipeline that model freezing uses. + """ + from torch._functorch.aot_autograd import ( + aot_module_simplified, + ) + from torch._inductor.compile_fx import ( + compile_fx_inner, + ) + from torch._inductor.decomposition import ( + select_decomp_table, + ) + from torch.fx.experimental.proxy_tensor import ( + make_fx, + ) + + from deepmd.kernels.triton.sezm.so2_stack_fp16x3 import ( + mixing_stack_fp16x3, + ) + + generator = torch.Generator(device="cuda").manual_seed(23) + inputs = self._stack_inputs(generator) + lmax, focus_dim = self.LMAX, self.FOCUS_DIM + + def fn(u0, alpha, w0_all, w1_all, gw_all): + x_local, z_all = mixing_stack_fp16x3( + u0, alpha, w0_all, w1_all, gw_all, lmax, focus_dim, True + ) + return (x_local, z_all) + + eager_x, eager_z = fn(*inputs) + graph = make_fx(fn, tracing_mode="symbolic")(*inputs) + # AOTAutograd's PhiloxStateTracker allocates tensors without an + # explicit device and would trip the pt-test default-device sentinel + # (source/tests/pt/__init__.py), so the sentinel is suspended here. + saved_device = torch.get_default_device() + torch.set_default_device(None) + try: + compiled = aot_module_simplified( + graph, + inputs, + fw_compiler=lambda gm, args: compile_fx_inner(gm, args), + decompositions=select_decomp_table(), + ) + with torch.no_grad(): + compiled_x, compiled_z = compiled(*inputs) + finally: + torch.set_default_device(saved_device) + torch.testing.assert_close(compiled_x, eager_x, atol=0.0, rtol=0.0) + torch.testing.assert_close(compiled_z, eager_z, atol=0.0, rtol=0.0) + + def test_dynamic_compile_survives_int32_stride_overflow_edge_counts(self): + """A graph traced on a small system must run beyond 2^31 / ROW edges. + + Triton specializes scalar kernel arguments to int32 when the first + compilation sees a small value, so any scalar that grows as + ``n_edge * ROW`` overflows the launcher on large systems (observed + at 1.1e7 edges, ``ROW = 224``). The mixing-stack kernels (fp32 and + fp16x3 alike) therefore derive their strides in-kernel from + constexpr layout flags; this test pins that contract for both stack + operators by tracing at 4e4 edges and running at 9.7e6 edges, where + ``n_edge * ROW`` exceeds int32. + """ + if torch.cuda.get_device_properties(0).total_memory < 60 * 2**30: + self.skipTest("requires ~40 GB of free device memory") + from deepmd.kernels.triton.sezm.so2_stack_fp16x3 import ( + mixing_stack_fp16x3, + ) + from deepmd.kernels.triton.sezm.so2_value_path import ( + _mixing_stack_op, + ) + + lmax, focus_dim, n_focus = 2, 32, 1 + generator = torch.Generator(device="cuda").manual_seed(29) + + def stack_inputs(n_edge): + m0, half = (lmax + 1) * focus_dim, lmax * focus_dim + row = (3 * lmax + 1) * focus_dim + u0 = torch.randn(n_focus, n_edge, row, device="cuda", generator=generator) + alpha = ( + torch.rand(n_edge, n_focus, device="cuda", generator=generator) + 0.1 + ) + w0 = torch.randn(3, n_focus, m0, m0, device="cuda", generator=generator) + bu = torch.randn(3, n_focus, half, half, device="cuda", generator=generator) + bv = torch.randn(3, n_focus, half, half, device="cuda", generator=generator) + w1 = torch.cat( + [torch.cat([bu, bv], 3), torch.cat([-bv, bu], 3)], 2 + ).contiguous() + gw = torch.randn( + 2, n_focus, focus_dim, half, device="cuda", generator=generator + ) + return u0, alpha, w0 * 0.2, w1 * 0.2, gw * 0.3 + + def make_fn(op): + def fn(u0, alpha, w0, w1, gw): + x_local, _ = op(u0, alpha, w0, w1, gw, lmax, focus_dim, True) + return (x_local,) + + return fn + + for stack_op in (mixing_stack_fp16x3, _mixing_stack_op): + with self.subTest(stack_op=getattr(stack_op, "__name__", str(stack_op))): + fn = make_fn(stack_op) + + small = stack_inputs(40000) + graph = make_fx(fn, tracing_mode="symbolic")(*small) + compiled = torch.compile(graph, backend="inductor", dynamic=True) + with torch.no_grad(): + compiled(*small) + big = stack_inputs(9_700_000) + out_big = compiled(*big)[0] + eager_big = fn(*big)[0] + torch.testing.assert_close(out_big, eager_big, atol=0.0, rtol=0.0) + del big, out_big, eager_big + torch.cuda.empty_cache() + + def test_value_path_selects_fp16x3_only_at_level_3(self): + """The stack operator selection follows the gate level and the table.""" + import os + from unittest import ( + mock, + ) + + from deepmd.kernels.triton.sezm.so2_stack_fp16x3 import ( + mixing_stack_fp16x3, + ) + from deepmd.kernels.triton.sezm.so2_value_path import ( + _mixing_stack_op, + make_triton_value_path, + ) + from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + SO2Convolution, + ) + + def build_conv(): + return SO2Convolution( + lmax=self.LMAX, + mmax=1, + channels=self.FOCUS_DIM, + n_focus=self.N_FOCUS, + focus_dim=0, + mixing_layers=3, + radial_so2_mode="degree_channel", + radial_so2_rank=1, + n_atten_head=1, + dtype=torch.float32, + seed=7, + trainable=False, + ) + + with mock.patch.dict(os.environ, {"DP_TRITON_INFER": "3"}): + entry = make_triton_value_path(build_conv()) + self.assertIs(entry._stack_op, mixing_stack_fp16x3) + + with mock.patch.dict(os.environ, {"DP_TRITON_INFER": "2"}): + entry = make_triton_value_path(build_conv()) + self.assertIs(entry._stack_op, _mixing_stack_op) + + def test_unswept_shape_has_no_config_and_operator_refuses_it(self): + from deepmd.kernels.triton.sezm.so2_stack_fp16x3 import ( + mixing_stack_fp16x3, + ) + from deepmd.kernels.triton.sezm.tile_configs import ( + stack_fp16x3_configs, + ) + + self.assertIsNone(stack_fp16x3_configs(48, 3)) + row = (3 * 3 + 1) * 48 + u0 = torch.zeros(1, 8, row, device="cuda") + alpha = torch.ones(8, 1, device="cuda") + w0 = torch.zeros(3, 1, 4 * 48, 4 * 48, device="cuda") + w1 = torch.zeros(3, 1, 6 * 48, 6 * 48, device="cuda") + gw = torch.zeros(2, 1, 48, 3 * 48, device="cuda") + with self.assertRaises(RuntimeError): + mixing_stack_fp16x3(u0, alpha, w0, w1, gw, 3, 48, True) + + +class _TileConfigRuntimeIsolation(unittest.TestCase): + """Base fixture: snapshot and restore the process-local runtime tables.""" + + def setUp(self) -> None: + from deepmd.kernels.triton.sezm import ( + tile_configs, + ) + + self.tile_configs = tile_configs + saved = {family: dict(table) for family, table in tile_configs._RUNTIME.items()} + + def restore() -> None: + for family, table in tile_configs._RUNTIME.items(): + table.clear() + table.update(saved[family]) + + self.addCleanup(restore) + + +class TestTileConfigLookup(_TileConfigRuntimeIsolation): + """Device-independent resolution semantics of the launch-config tables. + + Configurations resolve through the process-local runtime registrations, + then the built-in tables of the running GPU, then the family fallback; + ``has_tile_config`` distinguishes "swept, default won" (an explicit + ``None`` entry) from "never swept" so the freeze auto-tuner only sweeps + genuinely uncovered keys. These semantics hold identically on hosts + without CUDA, where the built-in layer is empty. + """ + + def test_runtime_registration_precedes_builtin_and_none_means_default(self): + tc = self.tile_configs + # An unswept key resolves to the family fallback and reports uncovered. + self.assertEqual(tc.gate_config(48, 3), (16, 8, 2)) + self.assertFalse(tc.has_tile_config("gate", (48, 3))) + # A registration serves lookups and marks the key covered. + tc.register_tile_configs("gate", {(48, 3): (32, 4, 1)}) + self.assertEqual(tc.gate_config(48, 3), (32, 4, 1)) + self.assertTrue(tc.has_tile_config("gate", (48, 3))) + # A None registration records "swept, default won": the lookup keeps + # the fallback while the key counts as covered. + tc.register_tile_configs("rotate_mix_fwd", {(96, 3): None}) + self.assertEqual(tc.rotate_mix_fwd_config(96, 3), (2, 2)) + self.assertTrue(tc.has_tile_config("rotate_mix_fwd", (96, 3))) + # Unknown families are rejected on every entry point. + with self.assertRaises(ValueError): + tc.register_tile_configs("bogus", {}) + with self.assertRaises(ValueError): + tc.has_tile_config("bogus", (32, 3)) + with self.assertRaises(ValueError): + tc._runtime_tile_configs("bogus") + + +@_GPU_KERNELS +class TestTileConfigLayering(_TileConfigRuntimeIsolation): + """GPU-bound layering behaviour: built-in dispatch, collection, tuning.""" + + def test_unknown_gpu_resolves_every_family_to_its_fallback(self): + from unittest import ( + mock, + ) + + tc = self.tile_configs + tc._builtin_tables.cache_clear() + self.addCleanup(tc._builtin_tables.cache_clear) + with mock.patch.object( + torch.cuda, "get_device_name", return_value="NVIDIA Imaginary GPU" + ): + self.assertEqual(tc.gate_config(32, 3), (16, 8, 2)) + self.assertEqual(tc.rotate_mix_fwd_config(64, 3), (2, 2)) + self.assertIsNone(tc.flash_bwd_block_config(64, 3)) + self.assertIsNone(tc.stack_fp16x3_configs(32, 3)) + self.assertFalse(tc.has_tile_config("gate", (32, 3))) + # Runtime registrations still resolve on an untuned GPU. + tc.register_tile_configs( + "stack_fp16x3", {(32, 3): ((64, 64, 32, 4, 1),) * 4} + ) + self.assertIsNotNone(tc.stack_fp16x3_configs(32, 3)) + tc._builtin_tables.cache_clear() + + def test_collect_model_shape_keys_reports_supported_convolutions(self): + from deepmd.kernels.triton.sezm.sweep_tile_configs import ( + collect_model_shape_keys, + ) + from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + SO2Convolution, + ) + + def build_conv(**overrides): + kwargs = { + "lmax": 3, + "mmax": 1, + "channels": 32, + "n_focus": 2, + "focus_dim": 0, + "mixing_layers": 3, + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_atten_head": 1, + "dtype": torch.float32, + "seed": 7, + "trainable": False, + } + kwargs.update(overrides) + return SO2Convolution(**kwargs) + + model = torch.nn.ModuleList( + [ + build_conv(), + build_conv(), # duplicate shape, deduplicated + build_conv(mmax=3), # unsupported layout, contributes no key + ] + ) + self.assertEqual(collect_model_shape_keys(model), [(32, 3, 2, 1)]) + + def test_tune_missing_configs_sweeps_only_uncovered_groups(self): + from unittest import ( + mock, + ) + + from deepmd.kernels.triton.sezm import ( + sweep_tile_configs, + ) + + tc = self.tile_configs + # Cover the pointwise and fp16x3 groups; leave the (C_wide, lmax) + # groups uncovered so only they should be swept at level 2. + tc.register_tile_configs("gate", {(48, 2): (16, 4, 1)}) + tc.register_tile_configs("stack_fp16x3", {(48, 2): None}) + calls: list[str] = [] + + def fake_sweep(group, family, key): + def run(cf, lmax, **kwargs): + calls.append(group) + return {family: {key: (4, 2, 1)}} + + return run + + fake_sweeps = { + "pointwise": fake_sweep("pointwise", "gate", (48, 2)), + "rotate_fwd": fake_sweep("rotate_fwd", "rotate_mix_fwd", (96, 2)), + "rotate_bwd": fake_sweep("rotate_bwd", "rotate_mix_bwd_block", (96, 2)), + "flash_bwd": fake_sweep("flash_bwd", "flash_bwd_block", (96, 2)), + "fp16x3": fake_sweep("fp16x3", "stack_fp16x3", (48, 2)), + } + shape_keys = [(48, 2, 2, 1)] + with mock.patch.dict(sweep_tile_configs._SWEEPS, fake_sweeps): + registered = sweep_tile_configs.tune_missing_configs( + shape_keys, level=2, device="cuda" + ) + self.assertEqual(sorted(calls), ["flash_bwd", "rotate_bwd", "rotate_fwd"]) + self.assertEqual( + sorted(registered), + ["flash_bwd_block", "rotate_mix_bwd_block", "rotate_mix_fwd"], + ) + # The registrations are now covered: a second tune is a no-op, and + # level 3 adds only the fp16x3 group (whose key was pre-covered by + # the explicit None above). + calls.clear() + with mock.patch.dict(sweep_tile_configs._SWEEPS, fake_sweeps): + self.assertEqual( + sweep_tile_configs.tune_missing_configs( + shape_keys, level=3, device="cuda" + ), + {}, + ) + self.assertEqual(calls, []) + # Levels below 2 never sweep. + with mock.patch.dict(sweep_tile_configs._SWEEPS, fake_sweeps): + self.assertEqual( + sweep_tile_configs.tune_missing_configs( + [(64, 5, 2, 1)], level=1, device="cuda" + ), + {}, + ) + self.assertEqual(calls, []) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/model/test_dpa4_dpmodel_parity.py b/source/tests/pt/model/test_dpa4_dpmodel_parity.py index 128073bb60..345583b092 100644 --- a/source/tests/pt/model/test_dpa4_dpmodel_parity.py +++ b/source/tests/pt/model/test_dpa4_dpmodel_parity.py @@ -880,8 +880,9 @@ def test_reduced_equivariant_rmsnorm(self, lmax, mmax, n_focus) -> None: } dp_mod = DPReducedEquivariantRMSNorm.deserialize(serialized) rng = np.random.default_rng(2044) - x = rng.normal(size=(17, n_focus, degree_index_m.size, self.channels)) - x[0] = 0.0 # all-zeros row exercises the eps path + # focus-major layout (F, E, D_m_trunc, C): the focus stream is axis 0. + x = rng.normal(size=(n_focus, 17, degree_index_m.size, self.channels)) + x[:, 0] = 0.0 # an all-zeros edge exercises the eps path assert_parity(dp_mod.call(x), pt_mod(to_pt(x))) def test_reduced_equivariant_rmsnorm_roundtrip(self) -> None: @@ -900,7 +901,8 @@ def test_reduced_equivariant_rmsnorm_roundtrip(self) -> None: ) dp_mod2 = DPReducedEquivariantRMSNorm.deserialize(dp_mod.serialize()) rng = np.random.default_rng(2045) - x = rng.normal(size=(17, 2, degree_index_m.size, self.channels)) + # focus-major layout (F, E, D_m_trunc, C): the focus stream is axis 0. + x = rng.normal(size=(2, 17, degree_index_m.size, self.channels)) np.testing.assert_array_equal( np.asarray(dp_mod.call(x)), np.asarray(dp_mod2.call(x)) ) @@ -2001,7 +2003,9 @@ def test_so2_linear(self, lmax, mmax, mlp_bias, n_focus) -> None: serialized = pt_mod.serialize() dp_mod = DPSO2Linear.deserialize(serialized) rng = np.random.default_rng(2053) - x = rng.normal(size=(13, n_focus, dp_mod.reduced_dim, 5)) + # SO2Linear consumes the focus-major layout (F, E, D_m, Cin): the focus + # stream is the batched-matmul axis and the edge axis follows. + x = rng.normal(size=(n_focus, 13, dp_mod.reduced_dim, 5)) assert_parity(dp_mod.call(x), pt_mod(to_pt(x))) def test_so2_linear_roundtrip(self) -> None: @@ -2020,7 +2024,8 @@ def test_so2_linear_roundtrip(self) -> None: ) dp_mod2 = DPSO2Linear.deserialize(dp_mod.serialize()) rng = np.random.default_rng(2054) - x = rng.normal(size=(9, 2, dp_mod.reduced_dim, 4)) + # focus-major (F, E, D_m, Cin) + x = rng.normal(size=(2, 9, dp_mod.reduced_dim, 4)) np.testing.assert_array_equal( np.asarray(dp_mod.call(x)), np.asarray(dp_mod2.call(x)) ) @@ -3651,6 +3656,164 @@ def test_descriptor_cross_deserialize(self) -> None: out2 = np.asarray(dp_mod2.call(*args, mapping=inp["mapping"])[0]) np.testing.assert_array_equal(out1, out2) + def test_descriptor_zero_blocks(self) -> None: + # n_blocks=0: no interaction blocks. Geometry then enters only through + # the Geometric Initial Embedding, which is active when use_env_seed=True + # and lmax + extra_node_l > 0 (lmax=3 here hosts the l>=1 GIE features). + pt_mod, dp_mod, _ = self._build_descr_pair(n_blocks=0, use_env_seed=True) + assert dp_mod.n_blocks == 0 + self._assert_descr_parity(pt_mod, dp_mod) + + def test_descriptor_native_spin(self) -> None: + # Native per-atom spin: ``use_spin`` conditions the l=0 type features on + # the per-type spin magnitude and injects an l=1 direction feature (needs + # a node degree >= 1, satisfied by lmax=3). Parity is checked with a real + # spin tensor and with spin=None, and spin=None is pinned to reproduce the + # genuine no-spin descriptor exactly on both backends. + from deepmd.dpmodel.descriptor.dpa4 import ( + DescrptDPA4, + ) + from deepmd.pt.model.descriptor.sezm import ( + DescrptSeZM, + ) + + use_spin = [True, False, False] # ntypes==3; type 0 is spin-active + pt_mod, dp_mod, _ = self._build_descr_pair(use_spin=use_spin) + inp = self._inputs() + coord, atype_ext, nlist, mp = ( + inp["coord"], + inp["atype_ext"], + inp["nlist"], + inp["mapping"], + ) + nf = coord.shape[0] + # local types include a spin-active type-0 atom so the spin path is live + assert (atype_ext[:, : self.nloc] == 0).any() + rng = np.random.default_rng(2170) + spin = rng.normal(size=(nf, self.nloc, 3)) + + def _call(spin_arg): + out_dp = dp_mod.call( + coord.reshape(nf, -1), atype_ext, nlist, mapping=mp, spin=spin_arg + ) + out_pt = pt_mod( + to_pt(coord), + to_pt(atype_ext), + to_pt(nlist), + mapping=to_pt(mp), + spin=None if spin_arg is None else to_pt(spin_arg), + ) + return out_dp, out_pt + + # spin path: pt vs dp parity (descriptor-level fp64 gate) + out_dp_s, out_pt_s = _call(spin) + assert out_dp_s[0].shape == tuple(out_pt_s[0].shape) + assert_parity(out_dp_s[0], out_pt_s[0], rtol=1e-10, atol=1e-12) + assert out_dp_s[1:] == (None, None, None, None) + + # spin=None path: pt vs dp parity + out_dp_n, out_pt_n = _call(None) + assert_parity(out_dp_n[0], out_pt_n[0], rtol=1e-10, atol=1e-12) + + # the spin tensor must actually move the descriptor (guards a no-op path) + d_s = np.asarray(out_dp_s[0]) + d_n = np.asarray(out_dp_n[0]) + assert np.abs(d_s - d_n).max() > 1e-3 + + # spin=None reproduces the genuine no-spin descriptor: copy the shared + # (non-spin) weights into a use_spin=None twin and check the l=0 output is + # bit-identical to the use_spin model evaluated with spin=None. + kwargs = self._descr_kwargs() + pt_nospin = DescrptSeZM(**kwargs, use_spin=None).double().eval() + sd_spin = pt_mod.state_dict() + pt_nospin.load_state_dict( + {k: sd_spin[k].clone() for k in pt_nospin.state_dict()} + ) + dp_nospin = DescrptDPA4.deserialize(pt_nospin.serialize()) + d_ns = np.asarray( + dp_nospin.call(coord.reshape(nf, -1), atype_ext, nlist, mapping=mp)[0] + ) + np.testing.assert_array_equal(d_n, d_ns) + p_ns = pt_nospin( + to_pt(coord), to_pt(atype_ext), to_pt(nlist), mapping=to_pt(mp) + )[0] + assert_parity(d_ns, p_ns, rtol=1e-10, atol=1e-12) + + @pytest.mark.parametrize( + "so3_readout", ["none", "mlp"] + ) # scalar readout vs SO(3) grid MLP readout + def test_descriptor_readout_layers(self, so3_readout) -> None: + # readout_layers=2 stacks a residual output-FFN layer before the final + # l=0 projection. pt is pinned to CPU (as in test_descriptor_so3_readout) + # so the strict fp64 gate holds under a CUDA default device. + from deepmd.dpmodel.descriptor.dpa4 import ( + DescrptDPA4, + ) + from deepmd.pt.model.descriptor.sezm import ( + DescrptSeZM, + ) + + kwargs = self._descr_kwargs(readout_layers=2, so3_readout=so3_readout) + pt_mod = DescrptSeZM(**kwargs).double().eval().to("cpu") + # output projections are zero-initialized; perturb for a nontrivial readout + rng = np.random.default_rng(2180) + with torch.no_grad(): + for p in pt_mod.parameters(): + p += torch.from_numpy(0.05 * rng.normal(size=tuple(p.shape))).to("cpu") + dp_mod = DescrptDPA4.deserialize(pt_mod.serialize()) + assert dp_mod.readout_layers == 2 + + inp = self._inputs() + coord, atype_ext, nlist, mp = ( + inp["coord"], + inp["atype_ext"], + inp["nlist"], + inp["mapping"], + ) + nf = coord.shape[0] + out_dp = np.asarray( + dp_mod.call(coord.reshape(nf, -1), atype_ext, nlist, mapping=mp)[0] + ) + out_pt = ( + pt_mod( + torch.from_numpy(coord).to("cpu"), + torch.from_numpy(atype_ext.astype(np.int64)).to("cpu"), + torch.from_numpy(nlist.astype(np.int64)).to("cpu"), + mapping=torch.from_numpy(mp.astype(np.int64)).to("cpu"), + )[0] + .detach() + .cpu() + .numpy() + ) + assert out_dp.shape == out_pt.shape + assert np.abs(out_dp).max() > 1e-6 # guards a trivially-zero readout + np.testing.assert_allclose(out_dp, out_pt, rtol=1e-10, atol=1e-12) + + def test_descriptor_focus_major_so2(self) -> None: + # Multi-stream focus-major SO(2) mixing: n_focus>1 carries the mixing + # activation as (F, E, D_m, Cf) with the focus stream on the batched + # matmul axis. Combined with multi-layer mixing (mixing_layers>=2), + # attention (n_atten_head>0), and the cross-focus competition that + # activates for n_focus>1, this validates the full focus-major path. + # so2_norm stays False here to isolate the mixing path; the + # n_focus>1 + so2_norm=True combination is covered by + # test_descriptor_focus_major_so2_norm. + pt_mod, dp_mod, _ = self._build_descr_pair( + n_focus=2, mixing_layers=2, n_atten_head=1 + ) + assert dp_mod.n_focus == 2 + self._assert_descr_parity(pt_mod, dp_mod) + + def test_descriptor_focus_major_so2_norm(self) -> None: + # n_focus>1 + so2_norm=True: the focus-major SO(2) mixing feeds + # ReducedEquivariantRMSNorm a (F, E, D_m, Cf) tensor, and the norm now + # applies its per-focus affine on the focus axis (axis 0), so the + # affine broadcast holds for E != n_focus. + pt_mod, dp_mod, _ = self._build_descr_pair( + n_focus=2, so2_norm=True, mixing_layers=2 + ) + self._assert_descr_parity(pt_mod, dp_mod) + class TestNoTorchImport: def test_dpa4_nn_does_not_import_torch(self) -> None: diff --git a/source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py b/source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py index 530c76edac..ce1ce7bb1e 100644 --- a/source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py +++ b/source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py @@ -181,6 +181,66 @@ def test_descriptor_grad_parity(self, use_env_seed) -> None: # math, where fp64 accumulation-order drift reaches ~3e-11 rel _assert_grad_trees_match(pt_mod, expt_mod, rtol=1e-10, atol=1e-12) + def test_descriptor_grad_parity_native_spin(self) -> None: + # Native per-atom spin (``use_spin``) adds trainable Parameters that + # pt_expt must promote from dpmodel numpy->buffer: + # ``SpinEmbedding.{adam_spin_vec_weight, adam_spin_nbr_weight}``, + # its ``mag_layer1/2`` weights (NativeLayer auto-promotes), and + # ``EnvironmentInitialEmbedding.spin_scale``. A missing promotion + # surfaces as the trainable-parameter-count mismatch asserted inside + # ``_assert_grad_trees_match`` (n_pt vs n_expt). Type 0 carries spin; + # the fixture's local types include type-0 atoms that are also edge + # sources, so both the on-site (l=0 magnitude + l=1 direction) and the + # neighbor-aggregation (edge l=1, gated by ``spin_scale``) paths fire. + use_spin = [True, False] # ntypes == 2; type 0 is spin-active + pt_mod, expt_mod = self._build_pair(use_spin=use_spin) + inp = self._inputs() + coord = inp["coord"].reshape(self.nf, -1) + atype_ext, nlist, mapping = inp["atype_ext"], inp["nlist"], inp["mapping"] + # a spin-active type-0 atom must be local for the on-site spin path + assert (atype_ext[:, : self.nloc] == 0).any() + rng = np.random.default_rng(2170) + spin = rng.normal(size=(self.nf, self.nloc, 3)) + + out_pt = pt_mod( + to_pt(inp["coord"]), + to_pt(atype_ext), + to_pt(nlist), + mapping=to_pt(mapping), + spin=to_pt(spin), + )[0] + out_expt = expt_mod( + to_pt(coord), + to_pt(atype_ext.astype(np.int64)), + to_pt(nlist.astype(np.int64)), + mapping=to_pt(mapping.astype(np.int64)), + spin=to_pt(spin), + )[0] + # guard: forward outputs must match before comparing gradients + np.testing.assert_allclose( + out_expt.detach().cpu().numpy(), + out_pt.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-12, + ) + # guard: the spin tensor must actually move the descriptor, otherwise + # the spin-parameter gradients below would be a trivial (zero) match + with torch.no_grad(): + out_pt_nospin = pt_mod( + to_pt(inp["coord"]), + to_pt(atype_ext), + to_pt(nlist), + mapping=to_pt(mapping), + spin=None, + )[0] + assert (out_pt - out_pt_nospin).abs().max().item() > 1e-3 + # quadratic loss -> dL/dw depends on the weights, not just the inputs + (out_pt**2).sum().backward() + (out_expt**2).sum().backward() + # count parity (validates spin-Parameter promotion) + name-aligned + # gradient parity across the full descriptor, spin parameters included + _assert_grad_trees_match(pt_mod, expt_mod, rtol=1e-10, atol=1e-12) + class TestFittingGradParity: nf = 2 diff --git a/source/tests/pt/model/test_sezm_export.py b/source/tests/pt/model/test_sezm_export.py index 2082b85293..5af20936cf 100644 --- a/source/tests/pt/model/test_sezm_export.py +++ b/source/tests/pt/model/test_sezm_export.py @@ -180,6 +180,18 @@ def _tiny_sezm_spin_model_params() -> dict: params["spin"] = { "use_spin": [True, False], "virtual_scale": 0.2, + "scheme": "deepspin", + } + return params + + +def _tiny_sezm_native_spin_model_params() -> dict: + """Minimal fp64 native-spin SeZM config for freeze routing tests.""" + params = copy.deepcopy(_tiny_sezm_model_params()) + params["type_map"] = ["O", "H"] + params["spin"] = { + "use_spin": [True, False], + "scheme": "native", } return params @@ -1010,6 +1022,57 @@ def fake_compile(_exported: torch.export.ExportedProgram, package_path: str): self.assertFalse(metadata["has_comm_artifact"]) self.assertNotIn("model/extra/forward_lower_with_comm.pt2", names) + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) + def test_freeze_accepts_native_spin_checkpoint_metadata(self) -> None: + """Native-spin checkpoints export the energy edge contract plus spin. + + The native scheme reuses the ``edge_vec`` lower ABI; the per-local-atom + spin is the only extra input, so the C++ backend builds the edge schema + exactly as for a non-spin model. The magnetic force and spin mask are + still emitted, and the type map / ntypes stay at the real-system sizes + (no virtual atoms). + """ + + def fake_compile(_exported: torch.export.ExportedProgram, package_path: str): + with zipfile.ZipFile(package_path, "w") as zf: + zf.writestr("model/data.pkl", b"") + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + params = _tiny_sezm_native_spin_model_params() + ckpt_path = _write_tiny_sezm_checkpoint(tmp_path, params) + out = tmp_path / "native_spin.pt2" + + with mock.patch( + "torch._inductor.aoti_compile_and_package", + side_effect=fake_compile, + ): + freeze_sezm_to_pt2(str(ckpt_path), str(out), device=_CPU) + + with zipfile.ZipFile(str(out), "r") as zf: + names = zf.namelist() + metadata = json.loads( + zf.read("model/extra/metadata.json").decode("utf-8") + ) + + self.assertTrue(metadata["is_spin"]) + # Native spin shares the energy edge ABI; only the deepspin scheme keeps + # the nlist contract. + self.assertEqual(metadata["lower_input_kind"], "edge_vec") + # Native spin keeps the real-system type map and count (no virtual atoms). + self.assertEqual(metadata["type_map"], params["type_map"]) + self.assertEqual(metadata["ntypes"], len(params["type_map"])) + self.assertEqual(metadata["use_spin"], params["spin"]["use_spin"]) + self.assertEqual(metadata["ntypes_spin"], 1) + # The magnetic force and spin mask are still exported. + self.assertIn("energy_derv_r_mag", metadata["output_keys"]) + self.assertIn("mask_mag", metadata["output_keys"]) + # Native spin reuses the edge_vec contract and is rank-decomposable, so + # the freeze embeds the multi-rank with-comm artifact (extended spin leaf + # plus the eight border_op communication tensors). + self.assertTrue(metadata["has_comm_artifact"]) + self.assertIn("model/extra/forward_lower_with_comm.pt2", names) + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) def test_freeze_embeds_with_comm_artifact(self) -> None: """A plain SeZM checkpoint ships the nested multi-rank with-comm artifact.""" diff --git a/source/tests/pt/model/test_sezm_model.py b/source/tests/pt/model/test_sezm_model.py index fcc5ae01ea..ba6cad98ec 100644 --- a/source/tests/pt/model/test_sezm_model.py +++ b/source/tests/pt/model/test_sezm_model.py @@ -18,6 +18,7 @@ from deepmd.pt.loss import ( DeNSLoss, + EnergySpinLoss, EnergyStdLoss, PropertyLoss, ) @@ -33,12 +34,16 @@ build_merged_state_dict, ) from deepmd.pt.model.model import ( + get_model, get_sezm_model, ) from deepmd.pt.model.model.sezm_model import ( InterPotential, SeZMModel, ) +from deepmd.pt.model.model.sezm_native_spin_model import ( + SeZMNativeSpinModel, +) from deepmd.pt.model.model.sezm_property_model import ( SeZMPropertyModel, ) @@ -48,6 +53,12 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.pt_expt.utils.edge_schema import ( + edge_schema_from_extended, +) from deepmd.utils.path import ( DPPath, ) @@ -1544,6 +1555,381 @@ def test_atom_virial_sums_to_global_virial(self) -> None: ) +class TestSeZMNativeSpinModel(unittest.TestCase): + """Validate the native (virtual-atom-free) spin SeZM model. + + The spin vector enters the descriptor as an equivariant feature, so the + magnetic force is the negative spin gradient of the energy. float64 + finite-difference checks pin ``force_mag = -dE/dspin`` and the conservative + ``force = -dE/dx``; a joint rotation of geometry and spin confirms SO(3) + equivariance of energy, force and magnetic force. + """ + + def setUp(self) -> None: + self.device = env.DEVICE + + def _build_model(self, *, use_compile: bool = False) -> SeZMNativeSpinModel: + """Build a tiny float64 native-spin model with randomized parameters.""" + params = { + "type": "dpa4", + "type_map": ["Ni", "O"], + "spin": {"use_spin": [True, False], "scheme": "native"}, + "descriptor": { + "type": "dpa4", + "sel": [12, 12], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": True, + "random_gamma": False, + "l_schedule": [1, 0], + "mmax": 1, + "ffn_neurons": 8, + "ffn_blocks": 1, + "mlp_bias": False, + "layer_scale": False, + "use_amp": False, + "activation_function": "silu", + "precision": "float64", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float64", + "seed": 7, + }, + "use_compile": use_compile, + } + model = get_model(params) + # Perturb away from the near-identity initialization so the spin + # embedding measurably shapes the output. + torch.manual_seed(1234) + with torch.no_grad(): + for p in model.parameters(): + p.copy_(torch.randn_like(p) * 0.1) + model.eval() + return model + + def _frame( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Periodic frame; only Ni (type 0) atoms carry spin.""" + coord = torch.tensor( + [ + [ + [0.10, 0.05, 0.00], + [1.05, 0.30, 0.10], + [0.20, 1.40, 0.35], + [1.60, 1.15, 0.20], + [2.20, 0.10, 1.05], + ] + ], + dtype=torch.float64, + device=self.device, + ) + atype = torch.tensor([[0, 1, 0, 1, 0]], dtype=torch.int64, device=self.device) + spin = torch.zeros(1, 5, 3, dtype=torch.float64, device=self.device) + is_mag = atype[0] == 0 + torch.manual_seed(99) + spin[0, is_mag] = torch.randn( + int(is_mag.sum()), 3, dtype=torch.float64, device=self.device + ) + box = torch.tensor( + [[6.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 6.0]], + dtype=torch.float64, + device=self.device, + ) + return coord, atype, spin, box + + @staticmethod + def _proper_rotation(device: torch.device) -> torch.Tensor: + """A deterministic proper rotation matrix (det = +1).""" + torch.manual_seed(0) + q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64, device=device)) + if torch.det(q) < 0: + q[:, 0] = -q[:, 0] + return q + + def test_finite_difference_forces(self) -> None: + """Force = -dE/dx and force_mag = -dE/dspin to finite-difference accuracy. + + The same frame validates both endpoints of the single backward, and the + per-type spin gate (non-magnetic atoms carry exactly zero magnetic force). + """ + model = self._build_model() + coord, atype, spin, box = self._frame() + out = model(coord, atype, spin, box=box) + force, force_mag = out["force"], out["force_mag"] + + def energy(c: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + return model(c, atype, s, box=box)["energy"].squeeze() + + eps = 1.0e-5 + nloc = coord.shape[1] + fd_force = torch.zeros_like(force) + fd_mag = torch.zeros_like(force_mag) + for a in range(nloc): + for d in range(3): + cp, cm = coord.clone(), coord.clone() + cp[0, a, d] += eps + cm[0, a, d] -= eps + fd_force[0, a, d] = -(energy(cp, spin) - energy(cm, spin)) / (2 * eps) + sp, sm = spin.clone(), spin.clone() + sp[0, a, d] += eps + sm[0, a, d] -= eps + fd_mag[0, a, d] = -(energy(coord, sp) - energy(coord, sm)) / (2 * eps) + torch.testing.assert_close( + force, fd_force, atol=1.0e-6, rtol=1.0e-4, msg="force != -dE/dx" + ) + torch.testing.assert_close( + fd_mag, force_mag, atol=1.0e-6, rtol=1.0e-4, msg="force_mag != -dE/dspin" + ) + + torch.testing.assert_close(out["mask_mag"], (atype == 0).reshape(1, -1, 1)) + self.assertEqual(force_mag[0, atype[0] == 1].abs().max().item(), 0.0) + + def test_joint_rotation_equivariance(self) -> None: + """Energy is invariant and force / force_mag rotate under joint rotation.""" + model = self._build_model() + coord, atype, spin, box = self._frame() + out = model(coord, atype, spin, box=box) + + rot = self._proper_rotation(self.device) + coord_r = torch.einsum("ij,nkj->nki", rot, coord) + spin_r = torch.einsum("ij,nkj->nki", rot, spin) + box_r = (box.view(1, 3, 3) @ rot.transpose(0, 1)).reshape(1, 9) + out_r = model(coord_r, atype, spin_r, box=box_r) + + torch.testing.assert_close(out_r["energy"], out["energy"], atol=1e-9, rtol=1e-7) + torch.testing.assert_close( + out_r["force"], + torch.einsum("ij,nkj->nki", rot, out["force"]), + atol=1e-8, + rtol=1e-6, + ) + torch.testing.assert_close( + out_r["force_mag"], + torch.einsum("ij,nkj->nki", rot, out["force_mag"]), + atol=1e-8, + rtol=1e-6, + ) + + def test_export_matches_forward(self) -> None: + """The traced ``.pt2`` export reduces to the public forward. + + The native scheme reuses the energy edge ABI plus the owned-atom spins, + so the C++ backend builds the edge schema exactly as for a non-spin + model. ``make_fx`` unfolds the single ``autograd.grad(energy, [edge_vec, + spin])``; the extended conservative force and the zero-padded magnetic + force reduce the LAMMPS way (``communicate_extended_output``) back to the + per-local-atom public force, while ``mask_mag`` is per-local-atom. + """ + model = self._build_model() + coord, atype, spin, box = self._frame() + nloc = coord.shape[1] + out = model(coord, atype, spin, box=box) + ext_coord, ext_atype, ext_spin, nlist, mapping = self._extended_spin_inputs( + model, coord, atype, spin, box + ) + # Guard the probe: the frame must carry ghosts (nall > nloc) so the + # magnetic-force path through ghost-image neighbours is actually + # exercised; otherwise the reduction would be trivial. + self.assertGreater(ext_coord.shape[1], nloc) + + edge = edge_schema_from_extended(ext_coord, ext_atype, nlist, mapping) + edge_inputs = ( + edge.coord, + edge.atype, + edge.edge_index, + edge.edge_vec, + edge.edge_scatter_index, + edge.edge_mask, + ext_spin[:, :nloc], + None, + None, + None, + ) + traced = model.forward_common_lower_exportable(*edge_inputs) + model_ret = traced(*edge_inputs) + + torch.testing.assert_close(model_ret["energy_redu"], out["energy"]) + torch.testing.assert_close( + self._reduce_extended( + model_ret["energy_derv_r"].squeeze(-2), mapping, nloc + ), + out["force"], + atol=1e-9, + rtol=1e-7, + ) + torch.testing.assert_close( + self._reduce_extended( + model_ret["energy_derv_r_mag"].squeeze(-2), mapping, nloc + ), + out["force_mag"], + atol=1e-9, + rtol=1e-7, + ) + torch.testing.assert_close(model_ret["mask_mag"], out["mask_mag"]) + + def test_serialization_roundtrip(self) -> None: + """Serialized native-spin model restores identical predictions.""" + model = self._build_model() + coord, atype, spin, box = self._frame() + out = model(coord, atype, spin, box=box) + + restored = SeZMNativeSpinModel.deserialize(model.serialize()) + restored.eval() + self.assertTrue(restored.has_spin()) + self.assertEqual(restored.get_type_map(), ["Ni", "O"]) + out2 = restored(coord, atype, spin, box=box) + for key in ["energy", "force", "force_mag"]: + torch.testing.assert_close( + out2[key], out[key], atol=1e-10, rtol=1e-8, msg=f"{key} mismatch" + ) + + def test_ener_spin_loss_smoke(self) -> None: + """The standard ``ener_spin`` loss runs unchanged on the native model.""" + model = self._build_model() + coord, atype, spin, box = self._frame() + nloc = coord.shape[1] + loss_fn = EnergySpinLoss( + starter_learning_rate=1.0e-3, + start_pref_e=1.0, + limit_pref_e=1.0, + start_pref_fr=1.0, + limit_pref_fr=1.0, + start_pref_fm=1.0, + limit_pref_fm=1.0, + ) + input_dict = {"coord": coord, "atype": atype, "spin": spin, "box": box} + label = { + "energy": torch.zeros(1, 1, dtype=torch.float64, device=self.device), + "force": torch.zeros(1, nloc, 3, dtype=torch.float64, device=self.device), + "force_mag": torch.zeros( + 1, nloc, 3, dtype=torch.float64, device=self.device + ), + "find_energy": 1.0, + "find_force": 1.0, + "find_force_mag": 1.0, + } + _, loss, more = loss_fn( + input_dict, model, label, natoms=nloc, learning_rate=1.0e-3 + ) + self.assertTrue(torch.isfinite(loss)) + self.assertIn("rmse_fm", more) + + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) + def test_compile_matches_eager(self) -> None: + """The compiled native-spin path matches eager force and magnetic force.""" + coord, atype, spin, box = self._frame() + model_eager = self._build_model(use_compile=False) + model_cmp = self._build_model(use_compile=True) + model_cmp.load_state_dict(model_eager.state_dict()) + model_eager.train() + model_cmp.train() + + out_e = model_eager(coord, atype, spin, box=box) + out_c = model_cmp(coord, atype, spin, box=box) + self.assertIn((True, False), model_cmp.compiled_core_compute_cache) + for key in ["energy", "force", "force_mag"]: + _assert_close_with_strict_warning( + out_c[key], + out_e[key], + atol=1.0e-6, + rtol=1.0e-6, + msg=f"native-spin compile mismatch on {key}", + ) + + @staticmethod + def _extended_spin_inputs( + model: SeZMNativeSpinModel, + coord: torch.Tensor, + atype: torch.Tensor, + spin: torch.Tensor, + box: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + """Build the 5-tuple lower inputs ``(ext_coord, ext_atype, ext_spin, nlist, mapping)``.""" + extended_coord, extended_atype, mapping, nlist = ( + extend_input_and_build_neighbor_list( + coord, + atype, + model.get_rcut(), + model.get_sel(), + mixed_types=model.mixed_types(), + box=box, + ) + ) + extended_spin = torch.gather(spin, 1, mapping.unsqueeze(-1).expand(-1, -1, 3)) + return extended_coord, extended_atype, extended_spin, nlist, mapping + + @staticmethod + def _reduce_extended( + extended: torch.Tensor, mapping: torch.Tensor, nloc: int + ) -> torch.Tensor: + """Scatter-sum an extended ``(nf, nall, 3)`` tensor onto local owners.""" + reduced = torch.zeros( + extended.shape[0], nloc, 3, dtype=extended.dtype, device=extended.device + ) + return reduced.scatter_reduce( + 1, mapping.unsqueeze(-1).expand(-1, -1, 3), extended, reduce="sum" + ) + + def test_allow_missing_label_relaxes_spin_data_requirement(self) -> None: + """``allow_missing_label`` relaxes the spin data requirement to optional with a + zero default, and the flag is excluded from serialization. + + ``use_spin`` is given as an element symbol here, so the test also covers the + symbol form being expanded against ``type_map`` into a per-type boolean list. + """ + from deepmd.pt.train.training import ( + get_additional_data_requirement, + ) + + def build(allow_missing_label: bool | None) -> SeZMNativeSpinModel: + params = { + "type": "dpa4", + "type_map": ["Ni", "O"], + # Element-symbol use_spin: expanded against type_map to [True, False]. + "spin": {"use_spin": ["Ni"], "scheme": "native"}, + "descriptor": { + "type": "dpa4", + "sel": [2, 2], + "rcut": 3.0, + "channels": 4, + "n_radial": 3, + "l_schedule": [1, 0], + "mmax": 1, + "use_env_seed": False, + "random_gamma": False, + "precision": "float64", + }, + "fitting_net": {"neuron": [4], "precision": "float64", "seed": 1}, + } + if allow_missing_label is not None: + params["spin"]["allow_missing_label"] = allow_missing_label + return get_model(params) + + for allow_missing_label, expected_must in ( + (None, True), + (False, True), + (True, False), + ): + model = build(allow_missing_label) + # The symbol ``["Ni"]`` expands to the per-type mask over ``["Ni", "O"]``. + self.assertEqual(model.spin.use_spin.tolist(), [True, False]) + self.assertEqual(model.spin.allow_missing_label, bool(allow_missing_label)) + spin_req = { + item.key: item for item in get_additional_data_requirement(model) + }["spin"] + self.assertEqual(spin_req.must, expected_must) + self.assertEqual(spin_req.default, 0.0) + + self.assertNotIn("allow_missing_label", build(True).spin.serialize()) + + class TestSeZMModelBridging(unittest.TestCase): """Test SeZM model with ZBL bridging enabled.""" @@ -2093,9 +2479,10 @@ def _build_base_and_lora( return base, lora def _random_input(self, lora: LoRASO2) -> torch.Tensor: + # Focus-major ``(F, E, D_m, C)`` contract; E=3 edges. return torch.randn( - 3, lora.n_focus, + 3, lora.reduced_dim, lora.in_channels, device=self.device, @@ -2128,8 +2515,8 @@ def test_z_rotation_equivariance(self) -> None: batch = 8 dtype = lora.dtype x = torch.randn( - batch, lora.n_focus, + batch, lora.reduced_dim, lora.in_channels, device=self.device, @@ -2137,9 +2524,9 @@ def test_z_rotation_equivariance(self) -> None: ) angles = torch.rand(batch, device=self.device, dtype=dtype) * 2 * math.pi z_mat = _build_m_major_z_rotation(angles, lmax, mmax, self.device) - x_rot = torch.einsum("bij,bfjc->bfic", z_mat, x) + x_rot = torch.einsum("eij,fejc->feic", z_mat, x) lhs = lora(x_rot) - rhs = torch.einsum("bij,bfjc->bfic", z_mat, lora(x)) + rhs = torch.einsum("eij,fejc->feic", z_mat, lora(x)) torch.testing.assert_close(lhs, rhs, atol=1e-5, rtol=1e-5) diff --git a/source/tests/pt/model/test_sezm_parallel.py b/source/tests/pt/model/test_sezm_parallel.py index 4b30b1a5e5..ceb53de9f5 100644 --- a/source/tests/pt/model/test_sezm_parallel.py +++ b/source/tests/pt/model/test_sezm_parallel.py @@ -355,16 +355,103 @@ def test_descriptor_parity_cpu(self) -> None: torch.testing.assert_close(par, ref, rtol=1e-8, atol=1e-9) +class TestSeZMNativeSpinParallelParity(unittest.TestCase): + """Parallel native-spin magnetic force must fold to the single-domain value. + + The local spin feeds the extended spin at the owner row and at every ghost + copy, so by the chain rule the single-domain magnetic force ``-dE/ds_local`` + equals the parallel per-node magnetic force ``-dE/ds_ext`` summed over each + owner and its ghost copies. This pins ``border_op``'s backward as the exact + Jacobian-vector product for a per-node leaf -- the regime the energy and + conservative-force parity tests cannot exercise, because ghost nodes carry + no edge-geometry leaf (they are never edge centres) and so a wrong ghost-row + gradient there dead-ends instead of corrupting the force. + """ + + @classmethod + def setUpClass(cls) -> None: + ensure_comm_registered() + + def _native_spin_model(self, device: torch.device) -> torch.nn.Module: + params = _tiny_parallel_model_params() + params["type_map"] = ["Ni", "O"] + params["spin"] = {"scheme": "native", "use_spin": [True, False]} + model = get_model(params) + model.eval() + model.to(device) + return model + + def test_native_spin_mag_force_fold_parity_cpu(self) -> None: + device = torch.device("cpu") + model = self._native_spin_model(device) + _perturb_descriptor(model.atomic_model.descriptor) + sysm = _build_extended_system(model, device) + nloc, nall, mapping = sysm["nloc"], sysm["nall"], sysm["mapping"] + comm = _self_comm_dict(mapping, nloc, nall) + + rng = np.random.default_rng(7) + atype_np = sysm["atype"][0].cpu().numpy() + magnetic = atype_np == 0 + local_spin = np.zeros((1, nloc, 3)) + local_spin[0, magnetic] = rng.standard_normal((int(magnetic.sum()), 3)) + ls = torch.tensor(local_spin, dtype=torch.float64, device=device) + es = torch.gather(ls, 1, mapping.unsqueeze(-1).expand(-1, -1, 3)) + + single = model.forward_common_lower( + sysm["coord"], + sysm["atype"], + sysm["edge_index"], + sysm["edge_vec"], + sysm["edge_scatter_index"], + sysm["edge_mask"], + spin=ls, + ) + par = model.forward_common_lower( + sysm["coord"], + sysm["atype"], + sysm["edge_scatter_index"], + sysm["edge_vec"], + sysm["edge_scatter_index"], + sysm["edge_mask"], + comm_dict=comm, + extended_atype=sysm["extended_atype"], + spin=es, + ) + m_single = single["energy_derv_r_mag"].squeeze(-2)[0] # (nloc, 3) + m_par = par["energy_derv_r_mag"].squeeze(-2)[0] # (nall, 3) + m_par_folded = torch.zeros(nloc, 3, dtype=m_par.dtype, device=device) + m_par_folded.index_add_(0, mapping[0], m_par) + + # Reject the degenerate zero-spin-force regime so the assertion below can + # never pass vacuously on a geometry-independent (untrained) model. + self.assertGreater(m_single.abs().max().item(), 1e-3) + torch.testing.assert_close(m_par_folded, m_single, rtol=1e-8, atol=1e-9) + + class TestSeZMEdgeParallelCapability(unittest.TestCase): - """The with-comm export predicate gates bridging and spin out.""" + """The with-comm export predicate admits the edge_vec contract. + + Plain energy and native spin both use the edge_vec lower interface and are + rank-decomposable, so they support the with-comm artifact; only analytical + bridging (Source Freeze Propagation) is gated out. + """ def test_plain_model_supports_edge_parallel(self) -> None: model = _build_model(torch.device("cpu")) self.assertTrue(model.supports_edge_parallel()) + self.assertEqual(model.export_lower_input_kind(), "edge_vec") self.assertTrue( model.atomic_model.descriptor.has_message_passing_across_ranks() ) + def test_native_spin_supports_edge_parallel(self) -> None: + params = _tiny_parallel_model_params() + params["type_map"] = ["Ni", "O"] + params["spin"] = {"scheme": "native", "use_spin": [True, False]} + model = get_model(params) + self.assertTrue(model.supports_edge_parallel()) + self.assertEqual(model.export_lower_input_kind(), "edge_vec") + def test_bridging_model_fails_fast(self) -> None: # ZBL needs real element symbols for its analytical pair potential. model = _build_model( diff --git a/source/tests/pt/model/test_sezm_spin_model.py b/source/tests/pt/model/test_sezm_spin_model.py index 207743ee8e..ba48c94889 100644 --- a/source/tests/pt/model/test_sezm_spin_model.py +++ b/source/tests/pt/model/test_sezm_spin_model.py @@ -157,6 +157,7 @@ def _build_model_params( "spin": { "use_spin": [True, False], "virtual_scale": 0.2, + "scheme": "deepspin", }, "descriptor": { "type": "SeZM", diff --git a/source/tests/pt/test_validation.py b/source/tests/pt/test_validation.py index 03ea41edca..c07e4652b4 100644 --- a/source/tests/pt/test_validation.py +++ b/source/tests/pt/test_validation.py @@ -20,21 +20,32 @@ ArgumentValueError, ) +from deepmd.pt.model.model import ( + get_model, +) from deepmd.pt.train.validation import ( BEST_METRIC_NAME_INFO_KEY, TOPK_RECORDS_INFO_KEY, FullValidator, resolve_full_validation_start_step, ) +from deepmd.pt.utils.env import ( + DEVICE, +) from deepmd.pt.utils.lmdb_dataset import ( LmdbDataset, ) from deepmd.utils.argcheck import ( normalize, ) +from deepmd.utils.eval_metrics import ( + SPIN_FULL_VALIDATION_PROFILE, + compute_full_validation_spin_metrics, +) from .model.test_permutation import ( model_se_e2_a, + model_spin, ) @@ -175,6 +186,21 @@ def _make_single_task_config() -> dict: } +def _make_spin_task_config() -> dict: + config = _make_single_task_config() + config["loss"] = { + "type": "ener_spin", + "start_pref_e": 1.0, + "limit_pref_e": 1.0, + "start_pref_fr": 1.0, + "limit_pref_fr": 1.0, + "start_pref_fm": 1.0, + "limit_pref_fm": 1.0, + } + config["validating"]["validation_metric"] = "FR:MAE" + return config + + class TestValidationHelpers(unittest.TestCase): def test_resolve_full_validation_start_step(self) -> None: self.assertEqual(resolve_full_validation_start_step(0, 2000000), 0) @@ -450,6 +476,15 @@ def test_full_validator_lmdb_snapshot_requires_type_map(self) -> None: class TestValidationArgcheck(unittest.TestCase): + def test_normalize_accepts_amp_infer(self) -> None: + config = _make_single_task_config() + normalized = normalize(config) + self.assertFalse(normalized["validating"]["amp_infer"]) + + config["validating"]["amp_infer"] = True + normalized = normalize(config) + self.assertTrue(normalized["validating"]["amp_infer"]) + def test_normalize_rejects_missing_validation_data(self) -> None: config = _make_single_task_config() del config["training"]["validation_data"] @@ -490,3 +525,140 @@ def test_normalize_rejects_nonpositive_max_best_ckpt(self) -> None: config["validating"]["max_best_ckpt"] = 0 with self.assertRaisesRegex(ArgumentValueError, "max_best_ckpt"): normalize(config) + + def test_normalize_accepts_spin_force_metric(self) -> None: + config = _make_spin_task_config() + normalized = normalize(config) + self.assertEqual(normalized["validating"]["validation_metric"], "FR:MAE") + + def test_normalize_rejects_energy_force_metric_for_spin(self) -> None: + config = _make_spin_task_config() + config["validating"]["validation_metric"] = "F:MAE" + with self.assertRaisesRegex(ValueError, "spin training"): + normalize(config) + + def test_normalize_rejects_spin_force_metric_for_energy(self) -> None: + config = _make_single_task_config() + config["validating"]["validation_metric"] = "FR:MAE" + with self.assertRaisesRegex(ValueError, "energy training"): + normalize(config) + + def test_normalize_rejects_inactive_spin_prefactor_metric(self) -> None: + config = _make_spin_task_config() + config["validating"]["validation_metric"] = "FM:RMSE" + config["loss"]["limit_pref_fm"] = 0.0 + with self.assertRaisesRegex(ValueError, "start_pref_fm"): + normalize(config) + + +class TestFullValidationMetricProfiles(unittest.TestCase): + def test_spin_profile_splits_real_and_magnetic_forces(self) -> None: + prediction = { + "energy": np.array([[6.0]]), + "force": np.array([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0]]), + "force_mag": np.array( + [[10.0, 10.0, 10.0, 99.0, 99.0, 99.0, 20.0, 20.0, 20.0]] + ), + "mask_mag": np.array([[True, False, True]]), + } + test_data = { + "find_energy": 1.0, + "find_force": 1.0, + "find_force_mag": 1.0, + "energy": np.array([[0.0]]), + "force": np.zeros((1, 9)), + "force_mag": np.zeros((1, 9)), + } + metrics = compute_full_validation_spin_metrics( + prediction, test_data, natoms=3, has_pbc=False + ) + # Energy is normalized per atom: |6| / 3 = 2. + self.assertAlmostEqual(metrics["mae_e_per_atom"][0], 2.0) + self.assertAlmostEqual(metrics["rmse_e_per_atom"][0], 2.0) + # Real force spans all three atoms (nine components). + self.assertAlmostEqual(metrics["mae_fr"][0], 2.0) + self.assertAlmostEqual(metrics["rmse_fr"][0], np.sqrt(42.0 / 9.0)) + self.assertEqual(metrics["mae_fr"][1], 9.0) + # Magnetic force only sees masked atoms 0 and 2 (six components). + self.assertAlmostEqual(metrics["mae_fm"][0], 15.0) + self.assertAlmostEqual(metrics["rmse_fm"][0], np.sqrt(250.0)) + self.assertEqual(metrics["mae_fm"][1], 6.0) + + def test_spin_profile_omits_magnetic_force_when_unavailable(self) -> None: + prediction = { + "energy": np.array([[3.0]]), + "force": np.zeros((1, 9)), + "force_mag": np.zeros((1, 9)), + "mask_mag": np.array([[True, False, True]]), + } + test_data = { + "find_energy": 1.0, + "find_force": 1.0, + "find_force_mag": 0.0, + "energy": np.array([[0.0]]), + "force": np.zeros((1, 9)), + } + metrics = compute_full_validation_spin_metrics( + prediction, test_data, natoms=3, has_pbc=False + ) + self.assertIn("mae_fr", metrics) + self.assertNotIn("mae_fm", metrics) + + def test_predict_outputs_emits_real_and_magnetic_forces(self) -> None: + model = get_model(deepcopy(model_spin)).to(DEVICE) + nframes = 2 + natoms = 5 + rng = np.random.default_rng(0) + coord = 3.0 * rng.random((nframes, natoms * 3)) + atom_types = np.tile(np.array([0, 0, 0, 1, 1]), (nframes, 1)) + box = np.tile((np.eye(3) * 6.0).reshape(9), (nframes, 1)) + spin = 0.5 * rng.random((nframes, natoms * 3)) + with tempfile.TemporaryDirectory() as tmpdir: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + validator = FullValidator( + validating_params={ + "full_validation": True, + "validation_freq": 1, + "save_best": False, + "max_best_ckpt": 1, + "validation_metric": "FR:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + }, + validation_data=_DummyValidationData(), + model=model, + state_store={}, + num_steps=10, + rank=0, + zero_stage=0, + restart_training=False, + ) + self.assertIs(validator.profile, SPIN_FULL_VALIDATION_PROFILE) + prediction = validator._predict_outputs( + coord=coord, + atom_types=atom_types, + box=box, + fparam=None, + aparam=None, + spin=spin, + include_virial=False, + natoms=natoms, + nframes=nframes, + ) + finally: + os.chdir(old_cwd) + + self.assertEqual(prediction["energy"].shape, (nframes, 1)) + self.assertEqual(prediction["force"].shape, (nframes, natoms * 3)) + self.assertEqual(prediction["force_mag"].shape, (nframes, natoms * 3)) + self.assertEqual(prediction["mask_mag"].shape, (nframes, natoms)) + self.assertNotIn("virial", prediction) + # use_spin=[True, False, False] makes only type-0 atoms magnetic. + expected_mask = np.tile( + np.array([True, True, True, False, False]), (nframes, 1) + ) + np.testing.assert_array_equal( + prediction["mask_mag"].astype(bool), expected_mask + ) diff --git a/source/tests/pt_expt/utils/test_border_op_backward.py b/source/tests/pt_expt/utils/test_border_op_backward.py index b33e575f1a..07a5ac67ab 100644 --- a/source/tests/pt_expt/utils/test_border_op_backward.py +++ b/source/tests/pt_expt/utils/test_border_op_backward.py @@ -116,10 +116,12 @@ def test_border_op_backward_direct(dtype: torch.dtype) -> None: def test_border_op_backward_accumulation_semantics() -> None: """Single-rank self-exchange backward: each ghost slot's grad is - accumulated into the local atom whose index sendlist points to. + accumulated into the local atom whose index sendlist points to, and the + ghost rows are zeroed. - Reference: for forward ``g_ext[nloc + i] = g[sendlist[i]]``, the - reverse is ``grad_g[sendlist[i]] += grad_g_ext[nloc + i]``. + Reference: for forward ``g_ext[nloc + i] = g[sendlist[i]]``, the reverse is + ``grad_g[sendlist[i]] += grad_g_ext[nloc + i]`` with ``grad_g[ghost] = 0`` + (the forward overwrites the ghost INPUT rows, so they carry no gradient). """ nloc, nghost = 4, 4 nall = nloc + nghost @@ -164,13 +166,15 @@ def test_border_op_backward_accumulation_semantics() -> None: comm[7], ) - # Expected: grad_g_local += grad_g_ext[nloc:] indexed by sendlist. - # Ghost rows pass through unchanged (the C++ backward does not - # zero them; the wrapper's autograd consumer is F.pad whose - # backward drops them anyway). + # Expected: grad_g_local += grad_g_ext[nloc:] indexed by sendlist, and the + # ghost rows are zero. The forward overwrites every ghost row + # (g_ext[ghost] = g[owner]), so a ghost INPUT never reaches the output and + # its gradient is exactly zero -- the backward zeros those rows to return + # the true Jacobian-vector product. expected = grad_g1_orig.clone() for i, src_local_idx in enumerate(sendlist_indices.tolist()): expected[src_local_idx] += grad_g1_orig[nloc + i] + expected[nloc:] = 0.0 np.testing.assert_allclose( grad_in.numpy(), expected.numpy(), @@ -241,8 +245,14 @@ def test_border_op_export_autograd(dtype: torch.dtype) -> None: atol=atol, rtol=rtol, ) - # Ghost rows of grad_in are not semantically meaningful: in - # production the wrapper's input is ``F.pad(node_ebd, value=0)`` - # so the ghost-row gradient is consumed by ``F.pad``'s backward - # (which drops it). The C++ backward leaves them as the upstream - # grad (here, ones), but we don't assert on it. + # Ghost INPUT rows are overwritten by the forward exchange, so the corrected + # C++ backward returns zero there (the true VJP). In production the DPA2/DPA3 + # wrappers feed ``F.pad(node_ebd, value=0)`` whose backward also drops these + # rows, so they are insensitive to the ghost gradient; SeZM feeds real + # per-node features through the exchange and is not. + np.testing.assert_allclose( + grad_in[nloc:].numpy(), + np.zeros((nghost, n_dim), dtype=grad_in.detach().numpy().dtype), + atol=atol, + rtol=rtol, + )