diff --git a/dpdata/system.py b/dpdata/system.py index 001726020..645628ccd 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1209,7 +1209,11 @@ class LabeledSystem(System): DTYPES: tuple[DataType, ...] = System.DTYPES + ( DataType("energies", np.ndarray, (Axis.NFRAMES,), deepmd_name="energy"), DataType( - "forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="force" + "forces", + np.ndarray, + (Axis.NFRAMES, Axis.NATOMS, 3), + required=False, + deepmd_name="force", ), DataType( "virials", @@ -1269,13 +1273,17 @@ def __add__(self, others): raise RuntimeError("Unspported data structure") return self.__class__.from_dict({"data": self_copy.data}) + def has_forces(self) -> bool: + return "forces" in self.data + def has_virial(self) -> bool: # return ('virials' in self.data) and (len(self.data['virials']) > 0) return "virials" in self.data def affine_map_fv(self, trans, f_idx: int | numbers.Integral): assert np.linalg.det(trans) != 0 - self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans) + if self.has_forces(): + self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans) if self.has_virial(): self.data["virials"][f_idx] = np.matmul( trans.T, np.matmul(self.data["virials"][f_idx], trans) @@ -1308,7 +1316,8 @@ def correction(self, hl_sys: LabeledSystem) -> LabeledSystem: raise RuntimeError("high_sys should be LabeledSystem") corrected_sys = self.copy() corrected_sys.data["energies"] = hl_sys.data["energies"] - self.data["energies"] - corrected_sys.data["forces"] = hl_sys.data["forces"] - self.data["forces"] + if "forces" in self.data and "forces" in hl_sys.data: + corrected_sys.data["forces"] = hl_sys.data["forces"] - self.data["forces"] if "virials" in self.data and "virials" in hl_sys.data: corrected_sys.data["virials"] = ( hl_sys.data["virials"] - self.data["virials"]