-
Notifications
You must be signed in to change notification settings - Fork 82
/
Copy pathtrain.py
46 lines (35 loc) · 1.27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import logging
import traceback
from finetrainers import Trainer, parse_arguments
from finetrainers.constants import FINETRAINERS_LOG_LEVEL
# Logger named "finetrainers"; its level is taken from the
# FINETRAINERS_LOG_LEVEL constant imported from finetrainers.constants.
logger = logging.getLogger("finetrainers")
logger.setLevel(FINETRAINERS_LOG_LEVEL)
def main():
    """Run the training pipeline from the command line.

    First tries to set the multiprocessing start method to "fork"; a failure
    there is logged but does not stop the run. Then parses the command-line
    arguments, builds a ``Trainer`` from them, calls its ``prepare_*`` steps
    in order and finally ``train()``.

    A keyboard interrupt is logged and the function returns normally. Any
    other exception is logged together with its traceback, after which the
    process exits with status 1 so that launchers and shell scripts can tell
    the run failed.

    Raises:
        SystemExit: With exit code 1 when the pipeline raises an exception.
    """
    import sys

    try:
        import multiprocessing

        multiprocessing.set_start_method("fork")
    except Exception as e:
        # Best effort: report the problem and carry on with whatever start
        # method is currently in effect.
        logger.error(
            f'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. '
            f"See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n"
            f"Error: {e}"
        )

    try:
        args = parse_arguments()
        trainer = Trainer(args)

        trainer.prepare_dataset()
        trainer.prepare_models()
        trainer.prepare_precomputations()
        trainer.prepare_trainable_parameters()
        trainer.prepare_optimizer()
        trainer.prepare_for_training()
        trainer.prepare_trackers()

        trainer.train()
        # trainer.evaluate()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt. Exiting...")
    except Exception as e:
        logger.error(f"An error occurred during training: {e}")
        logger.error(traceback.format_exc())
        # Without a non-zero exit status a failed run would look successful
        # to whatever started this script.
        sys.exit(1)
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when
    # it is imported as a module.
    main()