Skip to content

Commit

Permalink
use list concat over extend/concat, better var names
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Mar 12, 2024
1 parent 266cc0f commit 93bd2d6
Show file tree
Hide file tree
Showing 22 changed files with 202 additions and 218 deletions.
32 changes: 16 additions & 16 deletions custodian/ansible/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ def push_all(input_dict, settings, directory=None):
settings (dict): The specification of the modification to be made.
directory (None): dummy parameter for compatibility with FileActions
"""
for k, v in settings.items():
(d, key) = get_nested_dict(input_dict, k)
if key in d:
d[key].extend(v)
for k1, val in settings.items():
dct, k2 = get_nested_dict(input_dict, k1)
if k2 in dct:
dct[k2] += val
else:
d[key] = v
dct[k2] = val

@staticmethod
def inc(input_dict, settings, directory=None):
Expand Down Expand Up @@ -167,12 +167,12 @@ def pull(input_dict, settings, directory=None):
settings (dict): The specification of the modification to be made.
directory (None): dummy parameter for compatibility with FileActions
"""
for k, v in settings.items():
(d, key) = get_nested_dict(input_dict, k)
if key in d and (not isinstance(d[key], list)):
raise ValueError(f"Keyword {k} does not refer to an array.")
if key in d:
d[key] = [i for i in d[key] if i != v]
for k1, val in settings.items():
dct, k2 = get_nested_dict(input_dict, k1)
if k2 in dct and (not isinstance(dct[k2], list)):
raise ValueError(f"Keyword {k1} does not refer to an array.")
if k2 in dct:
dct[k2] = [itm for itm in dct[k2] if itm != val]

@staticmethod
def pull_all(input_dict, settings, directory=None):
Expand All @@ -184,11 +184,11 @@ def pull_all(input_dict, settings, directory=None):
settings (dict): The specification of the modification to be made.
directory (None): dummy parameter for compatibility with FileActions
"""
for k, v in settings.items():
if k in input_dict and (not isinstance(input_dict[k], list)):
raise ValueError(f"Keyword {k} does not refer to an array.")
for i in v:
DictActions.pull(input_dict, {k: i})
for key, val in settings.items():
if key in input_dict and (not isinstance(input_dict[key], list)):
raise ValueError(f"Keyword {key} does not refer to an array.")
for itm in val:
DictActions.pull(input_dict, {key: itm})

@staticmethod
def pop(input_dict, settings, directory=None):
Expand Down
6 changes: 3 additions & 3 deletions custodian/ansible/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def __init__(self, actions=None, strict=True, directory="./"):
self.supported_actions = {}
actions = actions if actions is not None else [DictActions]
for action in actions:
for i in dir(action):
if (not re.match(r"__\w+__", i)) and callable(getattr(action, i)):
self.supported_actions["_" + i] = getattr(action, i)
for attr in dir(action):
if (not re.match(r"__\w+__", attr)) and callable(getattr(action, attr)):
self.supported_actions[f"_{attr}"] = getattr(action, attr)
self.strict = strict
self.directory = directory

Expand Down
10 changes: 5 additions & 5 deletions custodian/cli/converge_kpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
def get_runs(vasp_command, target=1e-3, max_steps=10, mode="linear"):
"""Generate the runs using a generator until convergence is achieved."""
energy = 0
vinput = VaspInput.from_directory(".")
kpoints = vinput["KPOINTS"].kpts[0]
for i in range(max_steps):
m = [(k * (i + 1)) for k in kpoints] if mode == "linear" else [(k + 1) for k in kpoints]
if i == 0:
vasp_input = VaspInput.from_directory(".")
kpoints = vasp_input["KPOINTS"].kpts[0]
for step in range(max_steps):
m = [(kpt * (step + 1)) for kpt in kpoints] if mode == "linear" else [(kpt + 1) for kpt in kpoints]
if step == 0:
settings = None
backup = True
else:
Expand Down
82 changes: 35 additions & 47 deletions custodian/cli/run_vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def load_class(mod, name):
toks = name.split("?")
params = {}
if len(toks) == 2:
for p in toks[-1].split(","):
ptoks = p.split("=")
params[ptoks[0]] = YAML(typ="rt").load(ptoks[1])
for tok in toks[-1].split(","):
p_toks = tok.split("=")
params[p_toks[0]] = YAML(typ="rt").load(p_toks[1])
elif len(toks) > 2:
print("Bad handler specification")
sys.exit(-1)
Expand All @@ -43,7 +43,7 @@ def get_jobs(args):
post_settings = [] # append to this list to have settings applied on next job
for i, job in enumerate(args.jobs):
final = i == n_jobs - 1
suffix = "." + job if any(c.isdigit() for c in job) else f".{job}{i + 1}"
suffix = "." + job if any(char.isdigit() for char in job) else f".{job}{i + 1}"
settings = post_settings
post_settings = []
backup = i == 0
Expand All @@ -66,12 +66,10 @@ def get_jobs(args):
user_incar_settings={"LWAVE": True, "EDIFF": 1e-6},
ediff_per_atom=False,
)
settings.extend(
[
{"dict": "INCAR", "action": {"_set": dict(vis.incar)}},
{"dict": "KPOINTS", "action": {"_set": vis.kpoints.as_dict()}},
]
)
settings += [
{"dict": "INCAR", "action": {"_set": dict(vis.incar)}},
{"dict": "KPOINTS", "action": {"_set": vis.kpoints.as_dict()}},
]

if job_type.startswith("static_dielectric_derived"):
from pymatgen.io.vasp.sets import MPStaticDielectricDFPTVaspInputSet, MPStaticSet
Expand All @@ -92,37 +90,31 @@ def get_jobs(args):
vis = MPStaticDielectricDFPTVaspInputSet()
incar = vis.get_incar(vinput["POSCAR"].structure)
unset = {}
for k in ["NPAR", "KPOINT_BSE", "LAECHG", "LCHARG", "LVHAR", "NSW"]:
incar.pop(k, None)
if k in vinput["INCAR"]:
unset[k] = 1
for key in ("NPAR", "KPOINT_BSE", "LAECHG", "LCHARG", "LVHAR", "NSW"):
incar.pop(key, None)
if key in vinput["INCAR"]:
unset[key] = 1
kpoints = vis.get_kpoints(vinput["POSCAR"].structure)
settings.extend(
[
{"dict": "INCAR", "action": {"_set": dict(incar), "_unset": unset}},
{"dict": "KPOINTS", "action": {"_set": kpoints.as_dict()}},
]
)
settings += [
{"dict": "INCAR", "action": {"_set": dict(incar), "_unset": unset}},
{"dict": "KPOINTS", "action": {"_set": kpoints.as_dict()}},
]
auto_npar = False
elif job_type.startswith("static") and vinput["KPOINTS"]:
m = [i * args.static_kpoint for i in vinput["KPOINTS"].kpts[0]]
settings.extend(
[
{"dict": "INCAR", "action": {"_set": {"NSW": 0}}},
{"dict": "KPOINTS", "action": {"_set": {"kpoints": [m]}}},
]
)
m = [kpt * args.static_kpoint for kpt in vinput["KPOINTS"].kpts[0]]
settings += [
{"dict": "INCAR", "action": {"_set": {"NSW": 0}}},
{"dict": "KPOINTS", "action": {"_set": {"kpoints": [m]}}},
]

elif job_type.startswith("nonscf_derived"):
from pymatgen.io.vasp.sets import MPNonSCFSet

vis = MPNonSCFSet.from_prev_calc(".", copy_chgcar=False, user_incar_settings={"LWAVE": True})
settings.extend(
[
{"dict": "INCAR", "action": {"_set": dict(vis.incar)}},
{"dict": "KPOINTS", "action": {"_set": vis.kpoints.as_dict()}},
]
)
settings += [
{"dict": "INCAR", "action": {"_set": dict(vis.incar)}},
{"dict": "KPOINTS", "action": {"_set": vis.kpoints.as_dict()}},
]

elif job_type.startswith("optics_derived"):
from pymatgen.io.vasp.sets import MPNonSCFSet
Expand All @@ -142,12 +134,10 @@ def get_jobs(args):
},
ediff_per_atom=False,
)
settings.extend(
[
{"dict": "INCAR", "action": {"_set": dict(vis.incar)}},
{"dict": "KPOINTS", "action": {"_set": vis.kpoints.as_dict()}},
]
)
settings += [
{"dict": "INCAR", "action": {"_set": dict(vis.incar)}},
{"dict": "KPOINTS", "action": {"_set": vis.kpoints.as_dict()}},
]

elif job_type.startswith("rampu"):
f = ramps / (n_ramp_u - 1)
Expand Down Expand Up @@ -175,12 +165,10 @@ def get_jobs(args):
post_settings.append({"dict": "KPOINTS", "action": {"_set": kpoints.as_dict()}})
# lattice vectors with length < 9 will get >1 KPOINT
low_kpoints = Kpoints.gamma_automatic([max(int(18 / length), 1) for length in structure.lattice.abc])
settings.extend(
[
{"dict": "INCAR", "action": {"_set": {"ISMEAR": 0}}},
{"dict": "KPOINTS", "action": {"_set": low_kpoints.as_dict()}},
]
)
settings += [
{"dict": "INCAR", "action": {"_set": {"ISMEAR": 0}}},
{"dict": "KPOINTS", "action": {"_set": low_kpoints.as_dict()}},
]

# let vasp determine encut (will be lower than
# needed for compatibility with other runs)
Expand Down Expand Up @@ -213,8 +201,8 @@ def do_run(args):
FORMAT = "%(asctime)s %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO, filename="run.log")
logging.info(f"Handlers used are {args.handlers}")
handlers = [load_class("custodian.vasp.handlers", n) for n in args.handlers]
validators = [load_class("custodian.vasp.validators", n) for n in args.validators]
handlers = [load_class("custodian.vasp.handlers", handler) for handler in args.handlers]
validators = [load_class("custodian.vasp.validators", validator) for validator in args.validators]

c = Custodian(
handlers,
Expand Down
24 changes: 11 additions & 13 deletions custodian/cp2k/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ def __correct_ot(self, ci):

# Try going from DIIS -> CG (slower, but more robust)
if minimizer == "DIIS":
actions.append(
actions += [
{
"dict": self.input_file,
"action": {"_set": {"FORCE_EVAL": {"DFT": {"SCF": {"OT": {"MINIMIZER": "CG"}}}}}},
}
)
]

# Try going from 2pnt to 3pnt line search (slower, but more robust)
elif minimizer == "CG":
Expand All @@ -200,14 +200,14 @@ def __correct_ot(self, ci):
!= "3PNT"
and not rotate
):
actions.append(
actions += [
{
"dict": self.input_file,
"action": {"_set": {"FORCE_EVAL": {"DFT": {"SCF": {"OT": {"LINESEARCH": "3PNT"}}}}}},
}
)
]
elif ci["FORCE_EVAL"]["DFT"]["SCF"].get("MAX_SCF", Keyword("MAX_SCF", 50)).values[0] < 50:
actions.append(
actions += [
{
"dict": self.input_file,
"action": {
Expand All @@ -220,7 +220,7 @@ def __correct_ot(self, ci):
}
},
}
)
]

"""
Switch to more robust OT framework.
Expand Down Expand Up @@ -415,9 +415,7 @@ def check(self, directory="./"):
"""Check for diverging SCF."""
conv = get_conv(os.path.join(directory, self.output_file))
tmp = np.diff(conv[-10:])
if len(conv) > 10 and all(_ > 0 for _ in tmp) and any(_ > 1 for _ in conv):
return True
return False
return len(conv) > 10 and all(_ > 0 for _ in tmp) and any(_ > 1 for _ in conv)

def correct(self, directory="./"):
"""Correct issue if possible."""
Expand Down Expand Up @@ -631,8 +629,8 @@ def check(self, directory="./"):
terminate_on_match=True,
postprocess=str,
)
for m in matches:
self.responses.append(m)
for match in matches:
self.responses.append(match)
return True
return False

Expand Down Expand Up @@ -826,8 +824,8 @@ def __init__(
def check(self, directory="./"):
"""Check for stuck SCF convergence."""
conv = get_conv(os.path.join(directory, self.output_file))
counts = [sum(1 for i in g) for k, g in itertools.groupby(conv)]
if any(c > self.max_same for c in counts):
counts = [len([*group]) for _k, group in itertools.groupby(conv)]
if any(cnt > self.max_same for cnt in counts):
return True
return False

Expand Down
14 changes: 7 additions & 7 deletions custodian/cp2k/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ def apply_actions(self, actions):
'action': moddermodification}.
"""
modified = []
for a in actions:
if "dict" in a:
k = a["dict"]
for action in actions:
if "dict" in action:
k = action["dict"]
modified.append(k)
Cp2kModder._modify(a["action"], self.ci)
elif "file" in a:
self.modify(a["action"], a["file"])
Cp2kModder._modify(action["action"], self.ci)
elif "file" in action:
self.modify(action["action"], action["file"])
self.ci = Cp2kInput.from_file(os.path.join(self.directory, self.filename))
else:
raise ValueError(f"Unrecognized format: {a}")
raise ValueError(f"Unrecognized format: {action}")
cleanup_input(self.ci)
self.ci.write_file(os.path.join(self.directory, self.filename))

Expand Down
2 changes: 1 addition & 1 deletion custodian/cp2k/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def run(self, directory="./"):
"""
# TODO: cp2k has bizarre in/out streams. Some errors that should go to std_err are not sent anywhere...
cmd = list(self.cp2k_cmd)
cmd.extend(["-i", self.input_file])
cmd += ["-i", self.input_file]
cmd_str = " ".join(cmd)
logger.info(f"Running {cmd_str}")
with (
Expand Down
8 changes: 4 additions & 4 deletions custodian/cp2k/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def cleanup_input(ci):
return
if any(k.upper() == "POTENTIAL" for k in ci.subsections):
ci.subsections.pop("POTENTIAL")
for v in ci.subsections.values():
cleanup_input(v)
for val in ci.subsections.values():
cleanup_input(val)


def activate_ot(actions, ci):
Expand Down Expand Up @@ -115,7 +115,7 @@ def activate_ot(actions, ci):
],
},
]
actions.extend(ot_actions)
actions += ot_actions


def activate_diag(actions):
Expand Down Expand Up @@ -149,7 +149,7 @@ def activate_diag(actions):
),
},
]
actions.extend(diag_actions)
actions += diag_actions


def can_use_ot(output, ci, minimum_band_gap=0.1):
Expand Down
Loading

0 comments on commit 93bd2d6

Please sign in to comment.