diff --git a/dpgen2/exploration/task/lmp_template_task_group.py b/dpgen2/exploration/task/lmp_template_task_group.py index aeecde92..b82e1695 100644 --- a/dpgen2/exploration/task/lmp_template_task_group.py +++ b/dpgen2/exploration/task/lmp_template_task_group.py @@ -43,14 +43,19 @@ def set_lmp( plm_template_fname: Optional[str] = None, revisions: dict = {}, traj_freq: int = 10, + extra_pair_style_args: str = "", ) -> None: self.lmp_template = Path(lmp_template_fname).read_text().split("\n") self.revisions = revisions self.traj_freq = traj_freq + self.extra_pair_style_args = extra_pair_style_args self.lmp_set = True self.model_list = sorted([model_name_pattern % ii for ii in range(numb_models)]) self.lmp_template = revise_lmp_input_model( - self.lmp_template, self.model_list, self.traj_freq + self.lmp_template, + self.model_list, + self.traj_freq, + self.extra_pair_style_args, ) self.lmp_template = revise_lmp_input_dump(self.lmp_template, self.traj_freq) if plm_template_fname is not None: @@ -138,12 +143,20 @@ def find_only_one_key(lmp_lines, key): return found[0] -def revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version="1"): +def revise_lmp_input_model( + lmp_lines, task_model_list, trj_freq, extra_pair_style_args="", deepmd_version="1" +): idx = find_only_one_key(lmp_lines, ["pair_style", "deepmd"]) + if extra_pair_style_args: + extra_pair_style_args = " " + extra_pair_style_args graph_list = " ".join(task_model_list) - lmp_lines[idx] = "pair_style deepmd %s out_freq %d out_file model_devi.out" % ( - graph_list, - trj_freq, + lmp_lines[idx] = ( + "pair_style deepmd %s out_freq %d out_file model_devi.out%s" + % ( + graph_list, + trj_freq, + extra_pair_style_args, + ) ) return lmp_lines diff --git a/dpgen2/exploration/task/make_task_group_from_config.py b/dpgen2/exploration/task/make_task_group_from_config.py index 37b7f8b4..c467fd8e 100644 --- a/dpgen2/exploration/task/make_task_group_from_config.py +++ b/dpgen2/exploration/task/make_task_group_from_config.py @@ -116,6 +116,7 @@ def lmp_template_task_group_args(): doc_plm_template_fname = "The file name of plumed input template" doc_revisions = "The revisions. Should be a dict providing the key - list of desired values pair. Key is the word to be replaced in the templates, and it may appear in both the lammps and plumed input templates. All values in the value list will be enmerated." doc_traj_freq = "The frequency of dumping configurations and thermodynamic states" + doc_extra_pair_style_args = "The extra arguments for pair_style" return [ Argument("conf_idx", list, optional=False, doc=doc_conf_idx, alias=["sys_idx"]), @@ -141,7 +142,7 @@ def lmp_template_task_group_args(): doc=doc_plm_template_fname, alias=["plm_template", "plm"], ), - Argument("revisions", dict, optional=True, default={}), + Argument("revisions", dict, optional=True, default={}, doc=doc_revisions), Argument( "traj_freq", int, @@ -150,6 +151,13 @@ def lmp_template_task_group_args(): doc=doc_traj_freq, alias=["t_freq", "trj_freq", "trj_freq"], ), + Argument( + "extra_pair_style_args", + str, + optional=True, + default="", + doc=doc_extra_pair_style_args, + ), ]