-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathcommander.py
38 lines (28 loc) · 1.18 KB
/
commander.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import subprocess
import sys

import fire
class Commander(object):
    """Command-line helper that launches a training run for a named model.

    A per-model starter module is generated from
    ``training/train_start_template.py`` and then executed with
    ``python -m training.train_start_<model_name>``. All paths are relative
    to the current working directory.
    """

    @staticmethod
    def _starter_path(model_name):
        """Return the relative path of the starter file for *model_name*."""
        return os.path.join("training", "train_start_{}.py".format(model_name))

    @staticmethod
    def _visible_devices(gpu):
        """Return the ``CUDA_VISIBLE_DEVICES`` value for *gpu*.

        ``-1`` maps to an empty string (no device exposed), a list/tuple of
        ids is joined with commas, and anything else is formatted as-is.
        """
        if isinstance(gpu, (list, tuple)):
            return ",".join(str(device) for device in gpu)
        if gpu == -1:
            return ""
        return "{}".format(gpu)

    @staticmethod
    def _create_starter(model_name):
        """Write the starter module for *model_name* from the template.

        Every occurrence of ``STUB_MODEL_NAME`` in the template is replaced
        by the model name. An existing starter file is overwritten.

        Raises:
            IOError/OSError: if the template cannot be read or the starter
                cannot be written.
        """
        print("Creating starter according to model name ....")
        template_path = os.path.join("training", "train_start_template.py")
        # Work on raw bytes so the template's encoding is preserved, and
        # substitute literally: no shell or sed metacharacters are involved,
        # so names containing '|', '&' or quotes cannot corrupt the output.
        with open(template_path, "rb") as template:
            source = template.read()
        name_bytes = "{}".format(model_name).encode("utf-8")
        with open(Commander._starter_path(model_name), "wb") as starter:
            starter.write(source.replace(b"STUB_MODEL_NAME", name_bytes))
        print("Starter created!")

    @staticmethod
    def _clean_up(model_name):
        """Delete the generated starter file; a missing file is not an error."""
        print("Clean up!")
        starter_path = Commander._starter_path(model_name)
        if os.path.exists(starter_path):
            os.remove(starter_path)

    def train(self, model_name, gpu=0, clean_up=False):
        """Generate the starter for *model_name* and run it as a module.

        Args:
            model_name: name substituted into the template and used in the
                starter module's file name.
            gpu: device id exported as ``CUDA_VISIBLE_DEVICES``; ``-1``
                exports an empty value, a list/tuple exports a
                comma-separated list.
            clean_up: when true, delete the generated starter afterwards,
                even if the run is interrupted.

        The exit status of the training process is not inspected.
        """
        self._create_starter(model_name)

        devices = self._visible_devices(gpu)
        # Pass the variable through the child's environment. The previous
        # ``export ... & python ...`` shell string backgrounded the export,
        # so the training process never saw the requested device.
        env = dict(os.environ)
        env["CUDA_VISIBLE_DEVICES"] = devices

        # Run with the current interpreter; fall back to PATH lookup only if
        # the interpreter path is unavailable.
        argv = [
            sys.executable or "python",
            "-m",
            "training.train_start_{}".format(model_name),
        ]
        print("Commands: CUDA_VISIBLE_DEVICES={} {}".format(devices, " ".join(argv)))

        try:
            subprocess.call(argv, env=env)
        finally:
            if clean_up:
                self._clean_up(model_name=model_name)
if __name__ == "__main__":
    # Hand the Commander class to Fire when this file is executed as a script.
    fire.Fire(component=Commander)