train.py
import random

import hydra
import numpy as np
import torch
from loguru import logger
from pytorch_lightning import Trainer

from model import Classifier


@hydra.main(config_path="config.yaml")
def train(config):
    logger.info(config)
    # Seed every RNG in play (Python, NumPy, and torch) for reproducibility.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        # These switches live on torch.backends.cudnn, not torch.backends.cuda.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    model = Classifier(config)
    # Trainer arguments follow the 0.x-era pytorch-lightning API
    # (use_amp, distributed_backend, etc. were renamed in later releases).
    trainer = Trainer(
        gradient_clip_val=0,
        num_nodes=1,
        # Use every visible GPU, or fall back to CPU.
        gpus=list(range(torch.cuda.device_count())) if torch.cuda.is_available() else None,
        log_gpu_memory=True,
        show_progress_bar=True,
        accumulate_grad_batches=config["accumulate_grad_batches"],
        max_epochs=config["max_epochs"],
        min_epochs=1,
        # Run validation ten times per training epoch.
        val_check_interval=0.1,
        log_save_interval=100,
        row_log_interval=10,
        distributed_backend="ddp",
        use_amp=config["use_amp"],
        amp_level='O2',
        weights_summary='top',
        num_sanity_val_steps=5,
        resume_from_checkpoint=None,
    )
    trainer.fit(model)


if __name__ == "__main__":
    train()
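

train.py imports Classifier from model, which is not included here. The sketch below is a hypothetical stand-in written against the same 0.x-era pytorch-lightning API the Trainer call above targets; the MLP layers, the random placeholder dataset, and the "lr" config key are illustration-only assumptions, not the repository's actual model. For train.py itself to run, config.yaml must at least define the keys it reads: accumulate_grad_batches, max_epochs, and use_amp.

# model.py (hypothetical sketch, not the repository's actual Classifier)
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class Classifier(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Assumed architecture: a small MLP over 32-dim features, 4 classes.
        self.net = torch.nn.Sequential(
            torch.nn.Linear(32, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 4),
        )

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {"val_loss": F.cross_entropy(self(x), y)}

    def validation_epoch_end(self, outputs):
        avg = torch.stack([o["val_loss"] for o in outputs]).mean()
        return {"val_loss": avg}

    def configure_optimizers(self):
        # "lr" is an assumed config key; train.py never confirms it exists.
        return torch.optim.Adam(self.parameters(), lr=self.config.get("lr", 1e-3))

    def _random_loader(self):
        # Placeholder data so the sketch is self-contained; a real Classifier
        # would build its dataset from the Hydra config instead.
        x, y = torch.randn(256, 32), torch.randint(0, 4, (256,))
        return DataLoader(TensorDataset(x, y), batch_size=32)

    def train_dataloader(self):
        return self._random_loader()

    def val_dataloader(self):
        return self._random_loader()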