Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew-S-Rosen committed May 21, 2024
1 parent 4270f3f commit 197177e
Show file tree
Hide file tree
Showing 4 changed files with 1,942 additions and 21 deletions.
39 changes: 19 additions & 20 deletions custodian/vasp/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,21 +456,7 @@ def correct(self, directory="./"):

if "too_few_bands" in self.errors:
nbands = None
if "NBANDS" in vi["INCAR"]:
nbands = vi["INCAR"]["NBANDS"]
else:
with open(os.path.join(directory, "OUTCAR")) as file:
for line in file:
# Have to take the last NBANDS line since sometimes VASP
# updates it automatically even if the user specifies it.
# The last one is marked by NBANDS= (no space).
if "NBANDS=" in line:
try:
d = line.split("=")
nbands = int(d[-1].strip())
break
except (IndexError, ValueError):
pass
nbands = vi["INCAR"]["NBANDS"] if "NBANDS" in vi["INCAR"] else self._get_nbands_from_outcar(directory)
if nbands:
new_nbands = max(int(1.1 * nbands), nbands + 1) # This handles the case when nbands is too low (< 8).
actions.append({"dict": "INCAR", "action": {"_set": {"NBANDS": new_nbands}}})
Expand Down Expand Up @@ -677,11 +663,8 @@ def correct(self, directory="./"):
)
self.error_count["algo_tet"] += 1

if "auto_nbands" in self.errors and (nbands := vi["INCAR"].get("NBANDS")):
try:
nelect = load_outcar(os.path.join(directory, "OUTCAR")).nelect
except Exception:
nelect = None # dummy value
if "auto_nbands" in self.errors and not (nbands := self._get_nbands_from_outcar(directory)):
nelect = load_outcar(os.path.join(directory, "OUTCAR")).nelect
if nelect and nbands > 2 * nelect:
self.error_count["auto_nbands"] += 1
warnings.warn(
Expand All @@ -693,6 +676,22 @@ def correct(self, directory="./"):
VaspModder(vi=vi, directory=directory).apply_actions(actions)
return {"errors": list(self.errors), "actions": actions}

@staticmethod
def _get_nbands_from_outcar(directory: str) -> int | None:
with open(os.path.join(directory, "OUTCAR")) as file:
nbands = None
for line in file:
# Have to take the last NBANDS line since sometimes VASP
# updates it automatically even if the user specifies it.
# The last one is marked by NBANDS= (no space).
if "NBANDS=" in line:
try:
d = line.split("=")
nbands = int(d[-1].strip())
break
except (IndexError, ValueError):
pass
return nbands

class LrfCommutatorHandler(ErrorHandler):
"""
Expand Down
Loading

0 comments on commit 197177e

Please sign in to comment.