diff --git a/dpdata/formats/deepmd/comp.py b/dpdata/formats/deepmd/comp.py index 410d789e..1cc04305 100644 --- a/dpdata/formats/deepmd/comp.py +++ b/dpdata/formats/deepmd/comp.py @@ -152,6 +152,12 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True): f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/npy format." ) continue + if nframes > 0 and np.asarray(data[dtype.name]).size == 0: + # an optional frame property (e.g. forces/virials when + # cal_force/cal_stress is disabled) may be empty while the + # system still has frames. Skip it instead of writing a + # meaningless (nframes, 0) array that cannot be reshaped on load. + continue ddata = np.reshape(data[dtype.name], [nframes, -1]) if np.issubdtype(ddata.dtype, np.floating): ddata = ddata.astype(comp_prec) diff --git a/dpdata/formats/deepmd/raw.py b/dpdata/formats/deepmd/raw.py index 50dc5afd..8a1479e1 100644 --- a/dpdata/formats/deepmd/raw.py +++ b/dpdata/formats/deepmd/raw.py @@ -136,5 +136,11 @@ def dump(folder, data): f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/raw format." ) continue + if nframes > 0 and np.asarray(data[dtype.name]).size == 0: + # an optional frame property (e.g. forces/virials when + # cal_force/cal_stress is disabled) may be empty while the + # system still has frames. Skip it instead of writing a + # meaningless (nframes, 0) array that cannot be reshaped on load. + continue ddata = np.reshape(data[dtype.name], [nframes, -1]) np.savetxt(os.path.join(folder, f"{dtype.deepmd_name}.raw"), ddata) diff --git a/tests/test_abacus_pw_scf.py b/tests/test_abacus_pw_scf.py index 0d2bdef5..6383f3f6 100644 --- a/tests/test_abacus_pw_scf.py +++ b/tests/test_abacus_pw_scf.py @@ -2,6 +2,7 @@ import os import shutil +import tempfile import unittest import numpy as np @@ -163,6 +164,34 @@ def test_noforcestress_job(self): # test append self system_ch4.append(system_ch4) + def test_noforcestress_deepmd_roundtrip(self): + # a converged scf without force/stress should survive a + # round-trip through deepmd/npy without raising a reshape error + system_ch4 = dpdata.LabeledSystem("abacus.scf", fmt="abacus/scf") + tmp_dir = tempfile.mkdtemp() + try: + system_ch4.to("deepmd/npy", tmp_dir) + reloaded = dpdata.LabeledSystem(tmp_dir, fmt="deepmd/npy") + self.assertEqual(reloaded.get_nframes(), system_ch4.get_nframes()) + # empty force/virial should not be written as bogus data + self.assertFalse(reloaded.data.get("forces", np.empty(0)).size) + self.assertTrue("virials" not in reloaded.data) + finally: + shutil.rmtree(tmp_dir) + + def test_noforcestress_deepmd_raw_roundtrip(self): + # same as above but for the deepmd/raw format + system_ch4 = dpdata.LabeledSystem("abacus.scf", fmt="abacus/scf") + tmp_dir = tempfile.mkdtemp() + try: + system_ch4.to("deepmd/raw", tmp_dir) + reloaded = dpdata.LabeledSystem(tmp_dir, fmt="deepmd/raw") + self.assertEqual(reloaded.get_nframes(), system_ch4.get_nframes()) + self.assertFalse(reloaded.data.get("forces", np.empty(0)).size) + self.assertTrue("virials" not in reloaded.data) + finally: + shutil.rmtree(tmp_dir) + if __name__ == "__main__": unittest.main()