import torch

import colossalai
from colossalai.core import global_context as gpc
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer
from colossalai.logging import disable_existing_loggers, get_dist_logger

import wandb

from lamda_pytorch.config.config import CFG
from lamda_pytorch.build_dataloader import build_dataloaders
from lamda_pytorch.lamda_pytorch import lamda_model
from lamda_pytorch.utils.utils import LaMDA_Loss, AutoregressiveWrapper

from transformers import AutoTokenizer
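
# NOTE (assumption, inferred from the code below): CFG (lamda_pytorch/config/config.py)
# is expected to expose at least `seed`, `use_zero`, `use_huggingface`, `tokenizer_name`,
# `use_wandb`, and `project_name`, and the ColossalAI config file is expected to provide
# EPOCHS, LEARNING_RATE, and WEIGHT_DECAY, since all of these are read in this script.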

def LaMDA_Trainer(cfg: CFG):
    assert torch.cuda.is_available(), "Training requires at least one CUDA device"

    disable_existing_loggers()

    parser = colossalai.get_default_parser()
    parser.add_argument(
        '--use_trainer',
        action='store_true',
        help='whether to use the ColossalAI Trainer instead of the manual training loop'
    )
    args = parser.parse_args()

    if cfg.use_zero:
        # ZeRO initialization is not implemented in this script yet
        pass
    else:
        colossalai.launch_from_torch(
            config='./lamda_pytorch/config/colossal_config.py',
            seed=cfg.seed
        )

    assert hasattr(gpc.config, "EPOCHS"), "Please provide EPOCHS in your configuration"

    # ColossalAI distributed logger
    logger = get_dist_logger()
    logger.info("Initialized environment", ranks=[0])

    # LaMDA model, wrapped for autoregressive (next-token) training
    model = lamda_model()
    model = AutoregressiveWrapper(model)

    # setup dataloaders (this script assumes cfg.use_huggingface is enabled;
    # otherwise train_dataloader / eval_dataloader are never defined)
    if cfg.use_huggingface:
        tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)
        train_dataloader, eval_dataloader = build_dataloaders(cfg, tokenizer)

    # loss function
    loss_fn = LaMDA_Loss()

    # optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=gpc.config.LEARNING_RATE,
        weight_decay=gpc.config.WEIGHT_DECAY
    )

    # initialize model, optimizer, criterion, and data loaders with ColossalAI
    engine, train_dataloader, _, _ = colossalai.initialize(
        model,
        optimizer,
        loss_fn,
        train_dataloader=train_dataloader
    )

    # unpack HuggingFace-style dict batches into (data, label) pairs for the engine schedule
    def batch_data_process_func(batch_data):
        data = batch_data["input_ids"]
        labels = batch_data["labels"]
        return data, labels

    engine.schedule.batch_data_process_func = batch_data_process_func

    if cfg.use_wandb:
        # manual training loop with Weights & Biases logging
        wandb.init(project=cfg.project_name)

        engine.train()
        for step, batch in enumerate(train_dataloader):
            # batches are HuggingFace-style dicts, matching batch_data_process_func above
            inputs, labels = batch['input_ids'].cuda(), batch['labels'].cuda()

            engine.zero_grad()
            outputs = engine(inputs)

            train_loss = engine.criterion(outputs, labels)
            wandb.log({"train_loss": train_loss.item()})

            engine.backward(train_loss)
            engine.step()
            wandb.log({"step": step})

        engine.eval()
        for step, batch in enumerate(eval_dataloader):
            inputs, labels = batch['input_ids'].cuda(), batch['labels'].cuda()

            with torch.no_grad():
                outputs = engine(inputs)
                test_loss = engine.criterion(outputs, labels)
                wandb.log({"test_loss": test_loss.item()})
            # no backward pass or optimizer step during evaluation

        wandb.alert(
            title='Training Complete',
            text='Training complete.'
        )

    else:
        # Time session with ColossalAI
        timer = MultiTimer()

        # trainer
        trainer = Trainer(
            engine=engine,
            timer=timer,
            logger=logger
        )

        hook_list = [
            hooks.LogMetricByStepHook(),
            hooks.LossHook(),
            hooks.LogMetricByEpochHook(logger)
        ]

        trainer.fit(
            train_dataloader=train_dataloader,
            epochs=gpc.config.EPOCHS,
            hooks=hook_list,
            display_progress=True
        )


if __name__ == "__main__":
    cfg = CFG()
    LaMDA_Trainer(cfg)
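
# Example launch (an assumption, not taken from the original repository docs):
# colossalai.launch_from_torch reads the distributed environment variables set by the
# torchrun / torch.distributed launcher, so a single-node, single-GPU run would look like:
#
#   torchrun --standalone --nproc_per_node=1 train.py
#
# If cfg.use_wandb is enabled, `wandb login` must have been run beforehand.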