-
Notifications
You must be signed in to change notification settings - Fork 129
/
train.py
executable file
·112 lines (96 loc) · 4.42 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python3
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from pathlib import Path
import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities.model_summary import summarize
from strhub.data.module import SceneTextDataModule
from strhub.models.base import BaseSystem
from strhub.models.utils import get_pretrained_weights
# Copied from OneCycleLR
def _annealing_cos(start, end, pct):
'Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.'
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out
def get_swa_lr_factor(warmup_pct, swa_epoch_start, div_factor=25, final_div_factor=1e4) -> float:
"""Get the SWA LR factor for the given `swa_epoch_start`. Assumes OneCycleLR Scheduler."""
total_steps = 1000 # Can be anything. We use 1000 for convenience.
start_step = int(total_steps * warmup_pct) - 1
end_step = total_steps - 1
step_num = int(total_steps * swa_epoch_start) - 1
pct = (step_num - start_step) / (end_step - start_step)
return _annealing_cos(1, 1 / (div_factor * final_div_factor), pct)
def _num_devices(devices) -> int:
    """Return how many GPUs a Lightning ``trainer.devices`` value selects.

    Lightning accepts several forms for ``devices``:
    - an int (``-1`` meaning all GPUs);
    - a sequence of device indices;
    - a string (``'auto'``, ``'-1'``, ``'2'``, ``'0,1'``).
    Comparing the raw value with ``> 1`` would raise TypeError for the non-int forms.
    """
    if devices is None:
        return 0
    if isinstance(devices, str):
        spec = devices.strip()
        if spec in ('auto', '-1'):
            return torch.cuda.device_count()
        if ',' in spec:
            return len([idx for idx in spec.split(',') if idx.strip()])
        return int(spec)
    if isinstance(devices, int):
        return torch.cuda.device_count() if devices == -1 else devices
    # A sequence of device indices (list, tuple, or OmegaConf ListConfig).
    return len(devices)


@hydra.main(config_path='configs', config_name='main', version_base='1.2')
def main(config: DictConfig):
    """Train a model from the Hydra ``config``.

    Steps:
    - resolve the data path and GPU-dependent trainer settings;
    - instantiate the model (optionally loading pretrained weights), the data module and
      the trainer from the config;
    - run ``trainer.fit``, resuming from ``config.ckpt_path`` when it is set.

    Raises:
        ValueError: If ``model.perm_mirrored`` is enabled but ``model.perm_num`` is odd.
    """
    trainer_strategy = 'auto'
    with open_dict(config):
        # Resolve data.root_dir against the original working directory.
        config.data.root_dir = hydra.utils.to_absolute_path(config.data.root_dir)
        # Special handling for GPU-affected config
        gpu = config.trainer.get('accelerator') == 'gpu'
        # A missing `devices` key is treated as 0 (no DDP setup here); note that
        # Lightning's own default is 'auto'.
        num_devices = _num_devices(config.trainer.get('devices', 0)) if gpu else 0
        if gpu:
            # Use mixed-precision training.
            # NOTE(review): torch.get_autocast_gpu_dtype() is deprecated in recent PyTorch
            # releases in favour of torch.get_autocast_dtype('cuda') - confirm against the
            # pinned torch version.
            config.trainer.precision = 'bf16-mixed' if torch.get_autocast_gpu_dtype() is torch.bfloat16 else '16-mixed'
        if num_devices > 1:
            # Use DDP with optimizations
            trainer_strategy = DDPStrategy(find_unused_parameters=False, gradient_as_bucket_view=True)
            # Scale steps-based config by the number of devices.
            # A float val_check_interval is a fraction of an epoch in Lightning, not a step
            # count, so only integer values are scaled.
            if isinstance(config.trainer.get('val_check_interval'), int):
                config.trainer.val_check_interval //= num_devices
            if config.trainer.get('max_steps', -1) > 0:
                config.trainer.max_steps //= num_devices

    # Special handling for PARSeq.
    # Explicit check instead of `assert`, which is stripped when running with `python -O`.
    if config.model.get('perm_mirrored', False) and config.model.perm_num % 2 != 0:
        raise ValueError('perm_num should be even if perm_mirrored = True')

    model: BaseSystem = hydra.utils.instantiate(config.model)
    # If specified, use pretrained weights to initialize the model
    if config.pretrained is not None:
        # For PARSeq the weights are loaded into the inner `.model`; otherwise into the system itself.
        m = model.model if config.model._target_.endswith('PARSeq') else model
        m.load_state_dict(get_pretrained_weights(config.pretrained))
    print(summarize(model, max_depth=2))

    datamodule: SceneTextDataModule = hydra.utils.instantiate(config.data)

    checkpoint = ModelCheckpoint(
        monitor='val_accuracy',
        mode='max',
        save_top_k=3,
        save_last=True,
        filename='{epoch}-{step}-{val_accuracy:.4f}-{val_NED:.4f}',
    )
    swa_epoch_start = 0.75
    swa_lr = config.model.lr * get_swa_lr_factor(config.model.warmup_pct, swa_epoch_start)
    swa = StochasticWeightAveraging(swa_lr, swa_epoch_start)

    # Logs go to the Hydra output dir for a fresh run. When resuming, they go to the
    # directory two levels above the checkpoint file.
    # (assumed layout: <run_dir>/<subdir>/<name>.ckpt - confirm)
    cwd = (
        HydraConfig.get().runtime.output_dir
        if config.ckpt_path is None
        else str(Path(config.ckpt_path).parents[1].absolute())
    )
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer,
        logger=TensorBoardLogger(cwd, '', '.'),
        strategy=trainer_strategy,
        enable_model_summary=False,  # the summary is already printed above
        callbacks=[checkpoint, swa],
    )
    trainer.fit(model, datamodule=datamodule, ckpt_path=config.ckpt_path)
if __name__ == '__main__':
    # `main` is decorated with `@hydra.main`, which provides its `config` argument,
    # so it is invoked here with no arguments.
    main()