Skip to content

Commit

Permalink
check skip logic and def name
Browse files Browse the repository at this point in the history
  • Loading branch information
SumGuo-88 committed Jan 10, 2025
1 parent 10ef768 commit c2dc7ef
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def main_parser() -> argparse.ArgumentParser:
)
parser_change_bias.add_argument(
"--skip-elementcheck",
action="store_false",
action="store_true",
help="Enable this option to skip element checks if any error occurs while retrieving statistical data.",
)
parser_change_bias.add_argument(
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def change_bias(
numb_batch: int = 0,
model_branch: Optional[str] = None,
output: Optional[str] = None,
elem_check_stat: bool = True,
skip_elem_check: bool = True,
min_frames: int = 10,
) -> None:
if input_file.endswith(".pt"):
Expand Down Expand Up @@ -474,8 +474,8 @@ def change_bias(
data_single.systems,
data_single.dataloaders,
nbatches,
min_frames_per_element_forstat=min_frames,
enable_element_completion=elem_check_stat,
min_frames_per_element_forstat = min_frames,
enable_element_completion = not skip_elem_check,
)
updated_model = training.model_change_out_bias(
model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
Expand Down Expand Up @@ -559,7 +559,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:
numb_batch=FLAGS.numb_batch,
model_branch=FLAGS.model_branch,
output=FLAGS.output,
elem_check_stat=FLAGS.skip_elementcheck,
skip_elem_check=FLAGS.skip_elementcheck,
min_frames=FLAGS.min_frames,
)
elif FLAGS.command == "compress":
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __getitem__(self, index):
b_data["natoms"] = self._natoms_vec
return b_data

def get_frame_index(self):
def get_frame_index_for_elements(self):
"""
Get the frame index and the number of frames with all the elements in the system.
This function is only used in the mixed type.
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def finalize_stats(sys_stat):

# get frame index
if datasets[0].mixed_type and enable_element_completion:
element_counts = dataset.get_frame_index()
element_counts = dataset.get_frame_index_for_elements()
for elem, data in element_counts.items():
indices = data["indices"]
count = data["frames"]
Expand Down
10 changes: 8 additions & 2 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2787,6 +2787,12 @@ def training_args(
"If the gradient norm exceeds this value, it will be clipped to this limit. "
"No gradient clipping will occur if set to 0."
)
doc_min_frames_per_element_forstat = (
"The minimum number of frames per element used for statistics when using the mixed type."
)
doc_enable_element_completion = (
"Whether to check elements when using the mixed type"
)
doc_stat_file = (
"The file path for saving the data statistics results. "
"If set, the results will be saved and directly loaded during the next training session, "
Expand Down Expand Up @@ -2898,14 +2904,14 @@ def training_args(
int,
default=10,
optional=True,
doc="The minimum number of frames per element used for statistics when using the mixed type.",
doc=doc_only_pt_supported + doc_min_frames_per_element_forstat,
),
Argument(
"enable_element_completion",
bool,
optional=True,
default=True,
doc="Whether to check elements when using the mixed type",
doc=doc_only_pt_supported + doc_enable_element_completion,
),
]
variants = [
Expand Down

0 comments on commit c2dc7ef

Please sign in to comment.