Skip to content

Commit

Permalink
Feat: set force label optional (#772)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Improvements**
- Enhanced flexibility in data handling by making forces data optional
in the system configuration.
- Added a method to check for the presence of forces data in the system.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
3 people authored Jan 13, 2025
1 parent 2905792 commit 46251a7
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,11 @@ class LabeledSystem(System):
DTYPES: tuple[DataType, ...] = System.DTYPES + (
DataType("energies", np.ndarray, (Axis.NFRAMES,), deepmd_name="energy"),
DataType(
"forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="force"
"forces",
np.ndarray,
(Axis.NFRAMES, Axis.NATOMS, 3),
required=False,
deepmd_name="force",
),
DataType(
"virials",
Expand Down Expand Up @@ -1269,13 +1273,17 @@ def __add__(self, others):
raise RuntimeError("Unspported data structure")
return self.__class__.from_dict({"data": self_copy.data})

def has_forces(self) -> bool:
return "forces" in self.data

def has_virial(self) -> bool:
# return ('virials' in self.data) and (len(self.data['virials']) > 0)
return "virials" in self.data

def affine_map_fv(self, trans, f_idx: int | numbers.Integral):
assert np.linalg.det(trans) != 0
self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans)
if self.has_forces():
self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans)
if self.has_virial():
self.data["virials"][f_idx] = np.matmul(
trans.T, np.matmul(self.data["virials"][f_idx], trans)
Expand Down Expand Up @@ -1308,7 +1316,8 @@ def correction(self, hl_sys: LabeledSystem) -> LabeledSystem:
raise RuntimeError("high_sys should be LabeledSystem")
corrected_sys = self.copy()
corrected_sys.data["energies"] = hl_sys.data["energies"] - self.data["energies"]
corrected_sys.data["forces"] = hl_sys.data["forces"] - self.data["forces"]
if "forces" in self.data and "forces" in hl_sys.data:
corrected_sys.data["forces"] = hl_sys.data["forces"] - self.data["forces"]
if "virials" in self.data and "virials" in hl_sys.data:
corrected_sys.data["virials"] = (
hl_sys.data["virials"] - self.data["virials"]
Expand Down

0 comments on commit 46251a7

Please sign in to comment.