Skip to content

Commit

Permalink
fix: add more checks
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Jan 15, 2025
1 parent 944af9d commit f05f9ca
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion dpdata/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def label(self, data: dict) -> dict:
labeled_data = lb_data.copy()
else:
labeled_data["energies"] += lb_data["energies"]
labeled_data["forces"] += lb_data["forces"]
if "forces" in labeled_data and "forces" in lb_data:
labeled_data["forces"] += lb_data["forces"]
if "virials" in labeled_data and "virials" in lb_data:
labeled_data["virials"] += lb_data["virials"]
return labeled_data
Expand Down
6 changes: 4 additions & 2 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,8 @@ def from_labeled_system(
dict_frames["energies"] = np.append(
dict_frames["energies"], tmp["energies"][0]
)
dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0])
if "forces" in tmp.keys() and "forces" in dict_frames.keys():
dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0])
if "virials" in tmp.keys() and "virials" in dict_frames.keys():
dict_frames["virials"] = np.append(
dict_frames["virials"], tmp["virials"][0]
Expand All @@ -307,7 +308,8 @@ def from_labeled_system(
## Correct the shape of numpy arrays
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
if "forces" in dict_frames.keys():
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
if "virials" in dict_frames.keys():
dict_frames["virials"] = dict_frames["virials"].reshape(-1, 3, 3)

Expand Down

0 comments on commit f05f9ca

Please sign in to comment.