From a6ced9fb3f7ec8ba8648363b9a90d4abacd504a3 Mon Sep 17 00:00:00 2001 From: Peng Xingliang <91927439+pxlxingliang@users.noreply.github.com> Date: Tue, 15 Oct 2024 08:08:24 +0800 Subject: [PATCH] abacus: add checks on pp and orb in construction of STRU (#737) ## Summary by CodeRabbit - **New Features** - Introduced a mandatory parameter for pseudopotential files and an optional parameter to specify the destination directory for symbolic links. - **Improvements** - Enhanced clarity and functionality in handling input parameters for pseudopotential and orbital files. - Streamlined output construction for `pp_file` and `numerical_orbital`. - **Bug Fixes** - Added tests to validate error handling for mismatches between provided pseudopotential files and numerical orbitals, ensuring appropriate exceptions are raised. - **Documentation** - Added comments for better understanding of code functionality. --------- Co-authored-by: root Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- dpdata/abacus/scf.py | 90 +++++++++++++++++----------------- tests/test_abacus_stru_dump.py | 43 ++++++++++++++++ 2 files changed, 87 insertions(+), 46 deletions(-) diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index 9919e9128..b1b2cfed9 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -578,7 +578,7 @@ def get_frame_from_stru(fname): def make_unlabeled_stru( data, frame_idx, - pp_file=None, + pp_file, numerical_orbital=None, numerical_descriptor=None, mass=None, @@ -601,7 +601,7 @@ def make_unlabeled_stru( System data frame_idx : int The index of the frame to dump - pp_file : list of string or dict, optional + pp_file : list of string or dict List of pseudo potential files, or a dictionary of pseudo potential files for each atomnames numerical_orbital : list of string or dict, optional List of orbital files, or a dictionary of orbital files for each atomnames @@ -628,6 +628,8 @@ def make_unlabeled_stru( link_file : bool, optional Whether to link the pseudo potential files and orbital files in the STRU file. If True, then only filename will be written in the STRU file, and make a soft link to the real file. + dest_dir : str, optional + The destination directory to make the soft link of the pseudo potential files and orbital files. For velocity, mag, angle1, angle2, sc, and lambda_, if the value is None, then the corresponding information will not be written. ABACUS support defining "mag" and "angle1"/"angle2" at the same time, and in this case, the "mag" only define the norm of the magnetic moment, and "angle1" and "angle2" define the direction of the magnetic moment. If data has spins, then it will be written as mag to STRU file; while if mag is passed at the same time, then mag will be used. @@ -655,6 +657,23 @@ def ndarray2list(i): else: return i + def process_file_input(file_input, atom_names, input_name): + # For pp_file and numerical_orbital, process the file input, and return a list of file names + # file_input can be a list of file names, or a dictionary of file names for each atom names + if isinstance(file_input, (list, tuple)): + if len(file_input) != len(atom_names): + raise ValueError( + f"{input_name} length is not equal to the number of atom types" + ) + return file_input + elif isinstance(file_input, dict): + for element in atom_names: + if element not in file_input: + raise KeyError(f"{input_name} does not contain {element}") + return [file_input[element] for element in atom_names] + else: + raise ValueError(f"Invalid {input_name}: {file_input}") + if link_file and dest_dir is None: print( "WARNING: make_unlabeled_stru: link_file is True, but dest_dir is None. Will write the filename to STRU but not making soft link." @@ -680,8 +699,8 @@ def ndarray2list(i): # ATOMIC_SPECIES block out = "ATOMIC_SPECIES\n" - if pp_file is not None: - pp_file = ndarray2list(pp_file) + ppfiles = process_file_input(ndarray2list(pp_file), data["atom_names"], "pp_file") + for iele in range(len(data["atom_names"])): if data["atom_numbs"][iele] == 0: continue @@ -690,57 +709,36 @@ def ndarray2list(i): out += f"{mass[iele]:.3f} " else: out += "1 " - if pp_file is not None: - if isinstance(pp_file, (list, tuple)): - ipp_file = pp_file[iele] - elif isinstance(pp_file, dict): - if data["atom_names"][iele] not in pp_file: - print( - f"ERROR: make_unlabeled_stru: pp_file does not contain {data['atom_names'][iele]}" - ) - ipp_file = None - else: - ipp_file = pp_file[data["atom_names"][iele]] - else: - ipp_file = None - if ipp_file is not None: - if not link_file: - out += ipp_file - else: - out += os.path.basename(ipp_file.rstrip("/")) - if dest_dir is not None: - _link_file(dest_dir, ipp_file) + ipp_file = ppfiles[iele] + if not link_file: + out += ipp_file + else: + out += os.path.basename(ipp_file.rstrip("/")) + if dest_dir is not None: + _link_file(dest_dir, ipp_file) out += "\n" out += "\n" # NUMERICAL_ORBITAL block if numerical_orbital is not None: - assert len(numerical_orbital) == len(data["atom_names"]) numerical_orbital = ndarray2list(numerical_orbital) + orbfiles = process_file_input( + numerical_orbital, data["atom_names"], "numerical_orbital" + ) + orbfiles = [ + orbfiles[i] + for i in range(len(data["atom_names"])) + if data["atom_numbs"][i] != 0 + ] out += "NUMERICAL_ORBITAL\n" - for iele in range(len(data["atom_names"])): - if data["atom_numbs"][iele] == 0: - continue - if isinstance(numerical_orbital, (list, tuple)): - inum_orbital = numerical_orbital[iele] - elif isinstance(numerical_orbital, dict): - if data["atom_names"][iele] not in numerical_orbital: - print( - f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {data['atom_names'][iele]}" - ) - inum_orbital = None - else: - inum_orbital = numerical_orbital[data["atom_names"][iele]] + for iorb in orbfiles: + if not link_file: + out += iorb else: - inum_orbital = None - if inum_orbital is not None: - if not link_file: - out += inum_orbital - else: - out += os.path.basename(inum_orbital.rstrip("/")) - if dest_dir is not None: - _link_file(dest_dir, inum_orbital) + out += os.path.basename(iorb.rstrip("/")) + if dest_dir is not None: + _link_file(dest_dir, iorb) out += "\n" out += "\n" diff --git a/tests/test_abacus_stru_dump.py b/tests/test_abacus_stru_dump.py index 084a5473c..4549b6d16 100644 --- a/tests/test_abacus_stru_dump.py +++ b/tests/test_abacus_stru_dump.py @@ -51,6 +51,49 @@ def test_dumpStruLinkFile(self): if os.path.isdir("abacus.scf/tmp"): shutil.rmtree("abacus.scf/tmp") + def test_dump_stru_pporb_mismatch(self): + with self.assertRaises(KeyError, msg="pp_file is a dict and lack of pp for H"): + self.system_ch4.to( + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file={"C": "C.upf", "O": "O.upf"}, + numerical_orbital={"C": "C.orb", "H": "H.orb"}, + ) + + with self.assertRaises( + ValueError, msg="pp_file is a list and lack of pp for H" + ): + self.system_ch4.to( + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file=["C.upf"], + numerical_orbital={"C": "C.orb", "H": "H.orb"}, + ) + + with self.assertRaises( + KeyError, msg="numerical_orbital is a dict and lack of orbital for H" + ): + self.system_ch4.to( + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file={"C": "C.upf", "H": "H.upf"}, + numerical_orbital={"C": "C.orb", "O": "O.orb"}, + ) + + with self.assertRaises( + ValueError, msg="numerical_orbital is a list and lack of orbital for H" + ): + self.system_ch4.to( + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file=["C.upf", "H.upf"], + numerical_orbital=["C.orb"], + ) + def test_dump_spinconstrain(self): self.system_ch4.to( "stru",