Skip to content

Commit

Permalink
abacus: add checks on pp and orb in construction of STRU (#737)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a mandatory parameter for pseudopotential files and an
optional parameter to specify the destination directory for symbolic
links.

- **Improvements**
- Enhanced clarity and functionality in handling input parameters for
pseudopotential and orbital files.
- Streamlined output construction for `pp_file` and `numerical_orbital`.

- **Bug Fixes**
- Added tests to validate error handling for mismatches between provided
pseudopotential files and numerical orbitals, ensuring appropriate
exceptions are raised.

- **Documentation**
	- Added comments for better understanding of code functionality.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: root <pxlxingliang>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
pxlxingliang and pre-commit-ci[bot] authored Oct 15, 2024
1 parent b797acb commit a6ced9f
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 46 deletions.
90 changes: 44 additions & 46 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def get_frame_from_stru(fname):
def make_unlabeled_stru(
data,
frame_idx,
pp_file=None,
pp_file,
numerical_orbital=None,
numerical_descriptor=None,
mass=None,
Expand All @@ -601,7 +601,7 @@ def make_unlabeled_stru(
System data
frame_idx : int
The index of the frame to dump
pp_file : list of string or dict, optional
pp_file : list of string or dict
List of pseudo potential files, or a dictionary of pseudo potential files for each atomnames
numerical_orbital : list of string or dict, optional
List of orbital files, or a dictionary of orbital files for each atomnames
Expand All @@ -628,6 +628,8 @@ def make_unlabeled_stru(
link_file : bool, optional
Whether to link the pseudo potential files and orbital files in the STRU file.
If True, then only filename will be written in the STRU file, and make a soft link to the real file.
dest_dir : str, optional
The destination directory to make the soft link of the pseudo potential files and orbital files.
For velocity, mag, angle1, angle2, sc, and lambda_, if the value is None, then the corresponding information will not be written.
ABACUS support defining "mag" and "angle1"/"angle2" at the same time, and in this case, the "mag" only define the norm of the magnetic moment, and "angle1" and "angle2" define the direction of the magnetic moment.
If data has spins, then it will be written as mag to STRU file; while if mag is passed at the same time, then mag will be used.
Expand Down Expand Up @@ -655,6 +657,23 @@ def ndarray2list(i):
else:
return i

def process_file_input(file_input, atom_names, input_name):
# For pp_file and numerical_orbital, process the file input, and return a list of file names
# file_input can be a list of file names, or a dictionary of file names for each atom names
if isinstance(file_input, (list, tuple)):
if len(file_input) != len(atom_names):
raise ValueError(
f"{input_name} length is not equal to the number of atom types"
)
return file_input
elif isinstance(file_input, dict):
for element in atom_names:
if element not in file_input:
raise KeyError(f"{input_name} does not contain {element}")
return [file_input[element] for element in atom_names]
else:
raise ValueError(f"Invalid {input_name}: {file_input}")

if link_file and dest_dir is None:
print(
"WARNING: make_unlabeled_stru: link_file is True, but dest_dir is None. Will write the filename to STRU but not making soft link."
Expand All @@ -680,8 +699,8 @@ def ndarray2list(i):

# ATOMIC_SPECIES block
out = "ATOMIC_SPECIES\n"
if pp_file is not None:
pp_file = ndarray2list(pp_file)
ppfiles = process_file_input(ndarray2list(pp_file), data["atom_names"], "pp_file")

for iele in range(len(data["atom_names"])):
if data["atom_numbs"][iele] == 0:
continue
Expand All @@ -690,57 +709,36 @@ def ndarray2list(i):
out += f"{mass[iele]:.3f} "
else:
out += "1 "
if pp_file is not None:
if isinstance(pp_file, (list, tuple)):
ipp_file = pp_file[iele]
elif isinstance(pp_file, dict):
if data["atom_names"][iele] not in pp_file:
print(
f"ERROR: make_unlabeled_stru: pp_file does not contain {data['atom_names'][iele]}"
)
ipp_file = None
else:
ipp_file = pp_file[data["atom_names"][iele]]
else:
ipp_file = None
if ipp_file is not None:
if not link_file:
out += ipp_file
else:
out += os.path.basename(ipp_file.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, ipp_file)

ipp_file = ppfiles[iele]
if not link_file:
out += ipp_file
else:
out += os.path.basename(ipp_file.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, ipp_file)
out += "\n"
out += "\n"

# NUMERICAL_ORBITAL block
if numerical_orbital is not None:
assert len(numerical_orbital) == len(data["atom_names"])
numerical_orbital = ndarray2list(numerical_orbital)
orbfiles = process_file_input(
numerical_orbital, data["atom_names"], "numerical_orbital"
)
orbfiles = [
orbfiles[i]
for i in range(len(data["atom_names"]))
if data["atom_numbs"][i] != 0
]
out += "NUMERICAL_ORBITAL\n"
for iele in range(len(data["atom_names"])):
if data["atom_numbs"][iele] == 0:
continue
if isinstance(numerical_orbital, (list, tuple)):
inum_orbital = numerical_orbital[iele]
elif isinstance(numerical_orbital, dict):
if data["atom_names"][iele] not in numerical_orbital:
print(
f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {data['atom_names'][iele]}"
)
inum_orbital = None
else:
inum_orbital = numerical_orbital[data["atom_names"][iele]]
for iorb in orbfiles:
if not link_file:
out += iorb
else:
inum_orbital = None
if inum_orbital is not None:
if not link_file:
out += inum_orbital
else:
out += os.path.basename(inum_orbital.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, inum_orbital)
out += os.path.basename(iorb.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, iorb)
out += "\n"
out += "\n"

Expand Down
43 changes: 43 additions & 0 deletions tests/test_abacus_stru_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,49 @@ def test_dumpStruLinkFile(self):
if os.path.isdir("abacus.scf/tmp"):
shutil.rmtree("abacus.scf/tmp")

def test_dump_stru_pporb_mismatch(self):
with self.assertRaises(KeyError, msg="pp_file is a dict and lack of pp for H"):
self.system_ch4.to(
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file={"C": "C.upf", "O": "O.upf"},
numerical_orbital={"C": "C.orb", "H": "H.orb"},
)

with self.assertRaises(
ValueError, msg="pp_file is a list and lack of pp for H"
):
self.system_ch4.to(
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file=["C.upf"],
numerical_orbital={"C": "C.orb", "H": "H.orb"},
)

with self.assertRaises(
KeyError, msg="numerical_orbital is a dict and lack of orbital for H"
):
self.system_ch4.to(
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file={"C": "C.upf", "H": "H.upf"},
numerical_orbital={"C": "C.orb", "O": "O.orb"},
)

with self.assertRaises(
ValueError, msg="numerical_orbital is a list and lack of orbital for H"
):
self.system_ch4.to(
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file=["C.upf", "H.upf"],
numerical_orbital=["C.orb"],
)

def test_dump_spinconstrain(self):
self.system_ch4.to(
"stru",
Expand Down

0 comments on commit a6ced9f

Please sign in to comment.