Skip to content

Commit

Permalink
abacus: add checks on pp and orb in construction of STRU
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Oct 9, 2024
1 parent b797acb commit ec6c29d
Showing 1 changed file with 38 additions and 36 deletions.
74 changes: 38 additions & 36 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,9 @@ def make_unlabeled_stru(
link_file : bool, optional
Whether to link the pseudo potential files and orbital files in the STRU file.
If True, then only filename will be written in the STRU file, and make a soft link to the real file.
dest_dir : str, optional
The destination directory to make the soft link of the pseudo potential files and orbital files.
For velocity, mag, angle1, angle2, sc, and lambda_, if the value is None, then the corresponding information will not be written.
ABACUS support defining "mag" and "angle1"/"angle2" at the same time, and in this case, the "mag" only define the norm of the magnetic moment, and "angle1" and "angle2" define the direction of the magnetic moment.
If data has spins, then it will be written as mag to STRU file; while if mag is passed at the same time, then mag will be used.
Expand Down Expand Up @@ -682,6 +685,20 @@ def ndarray2list(i):
out = "ATOMIC_SPECIES\n"
if pp_file is not None:
pp_file = ndarray2list(pp_file)
ppfiles = None
if isinstance(pp_file,(list, tuple)):
assert len(pp_file) == len(data["atom_names"]), "ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types"
ppfiles = pp_file
elif isinstance(pp_file, dict):
for iele in data["atom_names"]:
if iele not in pp_file:
raise RuntimeError(f"ERROR: make_unlabeled_stru: pp_file does not contain {iele}")
ppfiles = [pp_file[data["atom_names"][i]] for i in range(len(data["atom_names"]))]
else:
raise RuntimeError(f"ERROR: invalid pp_file: {pp_file}")
else:
ppfiles = None

for iele in range(len(data["atom_names"])):
if data["atom_numbs"][iele] == 0:
continue
Expand All @@ -690,57 +707,42 @@ def ndarray2list(i):
out += f"{mass[iele]:.3f} "
else:
out += "1 "
if pp_file is not None:
if isinstance(pp_file, (list, tuple)):
ipp_file = pp_file[iele]
elif isinstance(pp_file, dict):
if data["atom_names"][iele] not in pp_file:
print(
f"ERROR: make_unlabeled_stru: pp_file does not contain {data['atom_names'][iele]}"
)
ipp_file = None
else:
ipp_file = pp_file[data["atom_names"][iele]]
else:
ipp_file = None
if ppfiles is not None:
ipp_file = ppfiles[iele]
if ipp_file is not None:
if not link_file:
out += ipp_file
else:
out += os.path.basename(ipp_file.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, ipp_file)

out += "\n"
out += "\n"

# NUMERICAL_ORBITAL block
if numerical_orbital is not None:
assert len(numerical_orbital) == len(data["atom_names"])
numerical_orbital = ndarray2list(numerical_orbital)
orbfiles = []
if isinstance(numerical_orbital, (list, tuple)):
assert len(numerical_orbital) == len(data["atom_names"]), "ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types"
orbfiles = [numerical_orbital[i] for i in range(len(data["atom_names"])) if data["atom_numbs"][i] != 0]
elif isinstance(numerical_orbital, dict):
for iele in data["atom_names"]:
if iele not in numerical_orbital:
raise RuntimeError(f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {iele}")
orbfiles = [numerical_orbital[data["atom_names"][i]] for i in range(len(data["atom_names"])) if data["atom_numbs"][i] != 0]
else:
raise RuntimeError(f"ERROR: invalid numerical_orbital: {numerical_orbital}")


out += "NUMERICAL_ORBITAL\n"
for iele in range(len(data["atom_names"])):
if data["atom_numbs"][iele] == 0:
continue
if isinstance(numerical_orbital, (list, tuple)):
inum_orbital = numerical_orbital[iele]
elif isinstance(numerical_orbital, dict):
if data["atom_names"][iele] not in numerical_orbital:
print(
f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {data['atom_names'][iele]}"
)
inum_orbital = None
else:
inum_orbital = numerical_orbital[data["atom_names"][iele]]
for iorb in orbfiles:
if not link_file:
out += iorb
else:
inum_orbital = None
if inum_orbital is not None:
if not link_file:
out += inum_orbital
else:
out += os.path.basename(inum_orbital.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, inum_orbital)
out += os.path.basename(iorb.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, iorb)
out += "\n"
out += "\n"

Expand Down

0 comments on commit ec6c29d

Please sign in to comment.