From 956e4bbcc4490ed1e731ec1d8b9673a37947f018 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 4 Jul 2026 01:49:57 +0800 Subject: [PATCH] fix(tf): reject non-prefix use_spin layouts in the spin helper The legacy TensorFlow spin implementation assumes spin-enabled types form a contiguous prefix of the type map. The SE-A sel extension takes the first ntypes_spin selections (sel_a[:ntypes_spin]), and the coordinate splitting (se_a.py), force splitting (model/ener.py), and bias merging (fit/ener.py) all address the virtual block with a dense real->virtual offset (i + len(use_spin)). For a non-prefix layout such as use_spin=[False, True], these read the wrong real/virtual type ranges or raise deep inside the graph, and nothing rejected the configuration up front. Rather than refactor all four sites in this legacy backend (the maintained PyTorch backend already supports the sparse layout via Spin.spin_type), guard against it: Spin now rejects a use_spin where a non-spin type precedes a spin-enabled one, with a clear message telling the user to list spin-enabled types first. This turns a silent-wrong result or an obscure crash into an actionable error and documents the invariant. Adds source/tests/tf/test_spin_prefix_guard.py: non-prefix layouts ([False, True] and [True, False, True]) raise ValueError, while prefix layouts ([True, False], [True, True]) and an all-non-spin list ([False, False]) are accepted. The prefix requirement previously had no test. Fix #5680 --- deepmd/tf/utils/spin.py | 19 ++++++++++ source/tests/tf/test_spin_prefix_guard.py | 44 +++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 source/tests/tf/test_spin_prefix_guard.py diff --git a/deepmd/tf/utils/spin.py b/deepmd/tf/utils/spin.py index 9d1d7018fd..6e2933be17 100644 --- a/deepmd/tf/utils/spin.py +++ b/deepmd/tf/utils/spin.py @@ -26,6 +26,25 @@ def __init__( virtual_len: list[float] | None = None, ) -> None: """Constructor.""" + # The TensorFlow spin implementation assumes spin-enabled types form a + # contiguous prefix of the type map: the SE-A ``sel`` extension takes the + # first ``ntypes_spin`` selections, and the coordinate/force splitting + # and bias merging address the virtual block with a dense real->virtual + # offset. Reject a layout where a non-spin type precedes a spin type, + # which would silently read the wrong real/virtual type ranges. + if use_spin is not None: + seen_non_spin = False + for flag in use_spin: + if flag: + if seen_non_spin: + raise ValueError( + "The TensorFlow spin implementation requires " + "spin-enabled types (use_spin=True) to form a prefix " + f"of the type map; got use_spin={use_spin}. List all " + "spin-enabled types first in the type map." + ) + else: + seen_non_spin = True self.use_spin = use_spin self.spin_norm = spin_norm self.virtual_len = virtual_len diff --git a/source/tests/tf/test_spin_prefix_guard.py b/source/tests/tf/test_spin_prefix_guard.py new file mode 100644 index 0000000000..31a8ea0727 --- /dev/null +++ b/source/tests/tf/test_spin_prefix_guard.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""The TF Spin helper must reject non-prefix use_spin layouts. + +The legacy TensorFlow spin implementation assumes spin-enabled types form a +contiguous prefix of the type map: the SE-A ``sel`` extension takes the first +``ntypes_spin`` selections, and the coordinate/force splitting and bias merging +address the virtual block with a dense real->virtual offset. A non-prefix layout +such as ``use_spin=[False, True]`` silently reads the wrong real/virtual type +ranges (or raises deep inside the graph), so it must be rejected up front with a +clear error. +""" + +import unittest + +from deepmd.tf.utils.spin import ( + Spin, +) + + +class TestSpinPrefixGuard(unittest.TestCase): + def test_non_prefix_rejected(self) -> None: + with self.assertRaises(ValueError): + Spin(use_spin=[False, True], spin_norm=[1.0], virtual_len=[0.4]) + + def test_non_prefix_rejected_middle(self) -> None: + with self.assertRaises(ValueError): + Spin( + use_spin=[True, False, True], + spin_norm=[1.0, 1.0], + virtual_len=[0.4, 0.4], + ) + + def test_prefix_accepted(self) -> None: + # spin-enabled types first: the supported layout + Spin(use_spin=[True, False], spin_norm=[1.0], virtual_len=[0.4]) + self.assertEqual(Spin(use_spin=[True, True]).ntypes_spin, 2) + + def test_all_non_spin_accepted(self) -> None: + # no spin types at all is not a non-prefix violation + self.assertEqual(Spin(use_spin=[False, False]).ntypes_spin, 0) + + +if __name__ == "__main__": + unittest.main()