diff --git a/dpdata/vasp/poscar.py b/dpdata/vasp/poscar.py index 075d8f2be..30073e2b3 100644 --- a/dpdata/vasp/poscar.py +++ b/dpdata/vasp/poscar.py @@ -121,12 +121,11 @@ def from_system_data(system, f_idx=0, skip_zeros=True): line = f"{ii_posi[0]:15.10f} {ii_posi[1]:15.10f} {ii_posi[2]:15.10f}" if move is not None and len(move) > 0: move_flags = move[idx] - if isinstance(move_flags, list) and len(move_flags) == 3: - line += " " + " ".join(["T" if flag else "F" for flag in move_flags]) - else: + if not isinstance(move_flags, list) or len(move_flags) != 3: raise RuntimeError( f"Invalid move flags: {move_flags}, should be a list of 3 bools" ) + line += " " + " ".join("T" if flag else "F" for flag in move_flags) posi_list.append(line) diff --git a/tests/test_vasp_poscar_to_system.py b/tests/test_vasp_poscar_to_system.py index 23b1e982b..16cd64b83 100644 --- a/tests/test_vasp_poscar_to_system.py +++ b/tests/test_vasp_poscar_to_system.py @@ -19,7 +19,7 @@ def test_move_flags(self): self.assertTrue(np.array_equal(self.system["move"], expected)) -class TestPOSCARCart(unittest.TestCase): +class TestPOSCARMoveFlags(unittest.TestCase): def test_move_flags_error1(self): with self.assertRaisesRegex(RuntimeError, "Invalid move flags.*?"): dpdata.System().from_vasp_poscar(os.path.join("poscars", "POSCAR.oh.err1"))