From f05f9ca8eab87ffe38d3798e5188bd8db3276b77 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 15 Jan 2025 16:55:21 +0800 Subject: [PATCH] fix: add more checks --- dpdata/driver.py | 3 ++- dpdata/plugins/ase.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dpdata/driver.py b/dpdata/driver.py index b5ff53403..b63c417af 100644 --- a/dpdata/driver.py +++ b/dpdata/driver.py @@ -166,7 +166,8 @@ def label(self, data: dict) -> dict: labeled_data = lb_data.copy() else: labeled_data["energies"] += lb_data["energies"] - labeled_data["forces"] += lb_data["forces"] + if "forces" in labeled_data and "forces" in lb_data: + labeled_data["forces"] += lb_data["forces"] if "virials" in labeled_data and "virials" in lb_data: labeled_data["virials"] += lb_data["virials"] return labeled_data diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index bafd9e7e7..1361494d0 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -298,7 +298,8 @@ def from_labeled_system( dict_frames["energies"] = np.append( dict_frames["energies"], tmp["energies"][0] ) - dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0]) + if "forces" in tmp.keys() and "forces" in dict_frames.keys(): + dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0]) if "virials" in tmp.keys() and "virials" in dict_frames.keys(): dict_frames["virials"] = np.append( dict_frames["virials"], tmp["virials"][0] @@ -307,7 +308,8 @@ def from_labeled_system( ## Correct the shape of numpy arrays dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3) dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3) - dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3) + if "forces" in dict_frames.keys(): + dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3) if "virials" in dict_frames.keys(): dict_frames["virials"] = dict_frames["virials"].reshape(-1, 3, 3)