Skip to content

Commit

Permalink
Skip train when the init model is provided (#116)
Browse files Browse the repository at this point in the history
At iteration 0, it is not necessary to train the model if an init_model
is provided.

Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Jan 25, 2023
1 parent 21bad63 commit dc21dfe
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 20 deletions.
2 changes: 2 additions & 0 deletions dpgen2/entrypoint/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ def dp_train_args():
doc_numb_models = "Number of models trained for evaluating the model deviation"
doc_config = "Configuration of training"
doc_template_script = "File names of the template training script. It can be a `List[Dict]`, the length of which is the same as `numb_models`. Each template script in the list is used to train a model. Can be a `Dict`, the models share the same template training script. "
doc_init_models_paths = "the paths to initial models"

return [
Argument("config", dict, RunDPTrain.training_args(), optional=True, default=RunDPTrain.normalize_config({}), doc=doc_numb_models),
Argument("numb_models", int, optional=True, default=4, doc=doc_numb_models),
Argument("template_script", [list,str], optional=False, doc=doc_template_script),
Argument("init_models_paths", list, optional=True, doc=doc_init_models_paths, alias=['training_iter0_model_path']),
]

def variant_train():
Expand Down
2 changes: 1 addition & 1 deletion dpgen2/entrypoint/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def workflow_concurrent_learning(
collect_data_config = normalize_step_dict(config.get('collect_data_config', default_config)) if old_style else config['step_configs']['collect_data_config']
cl_step_config = normalize_step_dict(config.get('cl_step_config', default_config)) if old_style else config['step_configs']['cl_step_config']
upload_python_packages = config.get('upload_python_packages', None)
init_models_paths = config.get('training_iter0_model_path', None) if old_style else config['train'].get('training_iter0_model_path', None)
init_models_paths = config.get('training_iter0_model_path', None) if old_style else config['train'].get('init_models_paths', None)
if upload_python_packages is not None and isinstance(upload_python_packages, str):
upload_python_packages = [upload_python_packages]
if upload_python_packages is not None:
Expand Down
34 changes: 33 additions & 1 deletion dpgen2/op/run_dp_train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os, json, dpdata, glob
import os, json, dpdata, glob, shutil
from pathlib import Path
from dpgen2.utils.run_command import run_command
from dpgen2.utils.chdir import set_directory
Expand Down Expand Up @@ -125,6 +125,14 @@ def execute(
train_dict = RunDPTrain.write_other_to_input_script(
train_dict, config, do_init_model, major_version)

if RunDPTrain.skip_training(work_dir, train_dict, init_model, iter_data):
return OPIO({
"script" : work_dir / train_script_name,
"model" : work_dir / "frozen_model.pb",
"lcurve" : work_dir / "lcurve.out",
"log" : work_dir / "train.log",
})

with set_directory(work_dir):
# open log
fplog = open('train.log', 'w')
Expand Down Expand Up @@ -224,6 +232,30 @@ def write_other_to_input_script(
raise RuntimeError('unsupported DeePMD-kit major version', major_version)
return odict

@staticmethod
def skip_training(
    work_dir,
    train_dict,
    init_model,
    iter_data,
):
    """Skip training when an init model exists and there is no iteration data.

    When the training is skipped, the files a real training run would
    leave behind are created inside ``work_dir``: the training script
    (``train_script_name``), ``train.log``, an empty ``lcurve.out`` and
    ``frozen_model.pb``, which is a copy of ``init_model``.

    Parameters
    ----------
    work_dir
        Directory in which the output files are written.
    train_dict
        Training input; dumped as JSON to the training script.
    init_model
        Path to the initial model, or ``None`` if there is none.
    iter_data
        Collection of iteration training data; ``None`` or an empty
        collection means there is no iteration data.

    Returns
    -------
    bool
        ``True`` if the training was skipped (outputs written),
        ``False`` if the training has to be carried out.
    """
    has_iter_data = iter_data is not None and len(iter_data) > 0
    if init_model is None or has_iter_data:
        return False

    # Make the source path absolute *before* entering work_dir.
    # NOTE(review): set_directory presumably changes the current working
    # directory, in which case a relative init_model would otherwise be
    # looked up relative to work_dir and the copy below would fail.
    init_model_src = Path(init_model).absolute()

    with set_directory(work_dir):
        with open(train_script_name, 'w') as fp:
            json.dump(train_dict, fp, indent=4)
        # The message deliberately reports init_model as given by the caller.
        Path('train.log').write_text(
            f'We have init model {init_model} and '
            f'no iteration training data. '
            f'The training is skipped.\n'
        )
        # No learning curve exists without training; create an empty file.
        Path('lcurve.out').touch()
        shutil.copy(init_model_src, 'frozen_model.pb')
    return True

@staticmethod
def decide_init_model(
config,
Expand Down
30 changes: 12 additions & 18 deletions tests/op/test_run_dp_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def setUp(self):


def tearDown(self):
for ii in ['init', self.task_path, self.task_name, 'foo' ]:
for ii in ['init', self.task_path, self.task_name, 'foo']:
if Path(ii).exists():
shutil.rmtree(str(ii))

Expand All @@ -592,10 +592,7 @@ def test_update_input_dict_v2_empty_list(self):
self.assertDictEqual(odict, self.expected_odict_v2)


@patch('dpgen2.op.run_dp_train.run_command')
def test_exec_v2_empty_list(self, mocked_run):
mocked_run.side_effect = [ (0, 'foo\n', ''), (0, 'bar\n', '') ]

def test_exec_v2_empty_list(self):
config = self.config.copy()
config['init_model_policy'] = 'no'

Expand All @@ -606,6 +603,9 @@ def test_exec_v2_empty_list(self, mocked_run):
task_name = self.task_name
work_dir = Path(task_name)

self.init_model = self.init_model.absolute()
self.init_model.write_text('this is init model')

ptrain = RunDPTrain()
out = ptrain.execute(
OPIO({
Expand All @@ -621,26 +621,20 @@ def test_exec_v2_empty_list(self, mocked_run):
self.assertEqual(out['model'], work_dir/'frozen_model.pb')
self.assertEqual(out['lcurve'], work_dir/'lcurve.out')
self.assertEqual(out['log'], work_dir/'train.log')

calls = [
call(['dp', 'train', train_script_name]),
call(['dp', 'freeze', '-o', 'frozen_model.pb']),
]
mocked_run.assert_has_calls(calls)


self.assertTrue(work_dir.is_dir())
self.assertTrue(out['log'].is_file())
self.assertEqual(out['log'].read_text(),
'#=================== train std out ===================\n'
'foo\n'
'#=================== train std err ===================\n'
'#=================== freeze std out ===================\n'
'bar\n'
'#=================== freeze std err ===================\n'
f'We have init model {self.init_model} and '
f'no iteration training data. '
f'The training is skipped.\n'
)
with open(out['script']) as fp:
jdata = json.load(fp)
self.assertDictEqual(jdata, self.expected_odict_v2)
self.assertEqual(Path(out['model']).read_text(), "this is init model")

os.remove(self.init_model)


@patch('dpgen2.op.run_dp_train.run_command')
Expand Down

0 comments on commit dc21dfe

Please sign in to comment.