forked from PaddlePaddle/PaddleMIX
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_pretrain_dist.py
264 lines (226 loc) · 9.38 KB
/
run_pretrain_dist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# coding:utf-8
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
parent_path = os.path.abspath(os.path.join(__file__, *([".."] * 4)))
sys.path.insert(0, parent_path)
import pprint
import socket
from dataclasses import dataclass, field
import paddle
from paddlemix.datasets import load_dataset
from paddlemix.datasets.dataset import ImageFolder
from paddlemix.metrics.clip_zero_shot import ClipZeroShot
from paddlemix.models.clip.clip_model import CLIP, CLIPConfig
from paddlemix.optimization import create_optimizer_simple
from paddlemix.processors.clip_processing import (
CLIPImageProcessor,
CLIPProcessor,
CLIPTextProcessor,
)
from paddlemix.processors.tokenizer import SimpleTokenizer
from paddlemix.trainer import CLIPTrainer
from paddlemix.utils.env import setdistenv
from paddlenlp.trainer import PdArgumentParser, TrainingArguments
@dataclass
class DataArguments:
    """
    Command-line options describing the training and evaluation data.

    `PdArgumentParser` converts every field of this dataclass into an argparse
    option, so each value can be overridden from the command line.
    """

    # Dataset selector. `main_worker` builds a `LaionDataset` from this value when it
    # contains "laion"; otherwise it is passed to `load_dataset` as a dataset name.
    task_name: str = field(
        default="coco_clip",
        metadata={
            "help": "The name of the task to use (via the datasets library), coco or laion-aes"
            " is support, if set to laion-aes, this should be the path to filelist file. "
            "option: [coco_clip/[path to laion-aes.filelist]], default: coco_clip"
        },
    )

    # Root of the classification eval data; `main_worker` reads images from "<this>/images".
    classification_eval: str = field(default="", metadata={"help": "Path to IN1K data."})
@dataclass
class ModelArguments:
    """
    Command-line options selecting which model is built and trained.
    """

    # Model identifier. `main_worker` hands it to `CLIPConfig.from_pretrained`, to
    # `model.load_pretrained`, and uses it as the base directory for "processor/train".
    model: str = field(
        default="paddlemix/EVA/EVA02-CLIP-L-14",
        metadata={
            "help": "model name to create, for example [EVA02-CLIP-B-16/coca_EVA02-B-16]",
        },
    )
@dataclass
class PreTrainingArguments(TrainingArguments):
    """
    Pretraining-specific options layered on top of `TrainingArguments`.

    Every field becomes a command-line flag through `PdArgumentParser`.
    """

    # When true, `main_worker` calls `model.load_pretrained` after building the model.
    pretrained: bool = field(default=False, metadata={"help": "Whether to use pretrained model."})

    # Per-tower weight decay and learning rates.
    # NOTE(review): presumably consumed by `create_optimizer_simple` - confirm there.
    text_wd: float = field(
        default=0.05,
        metadata={"help": "Weight decay for text tower"},
    )
    visual_wd: float = field(
        default=0.05,
        metadata={"help": "Weight decay for visual tower"},
    )
    text_lr: float = field(
        default=2e-5,
        metadata={"help": "The initial learning rate of text tower."},
    )
    visual_lr: float = field(
        default=2e-4,
        metadata={"help": "The initial learning rate of visual tower."},
    )

    # Layer-wise decay factors (global and per tower).
    layer_decay: float = field(
        default=1.0,
        metadata={"help": "The basic layer decay."},
    )
    text_ld: float = field(
        default=0.75,
        metadata={"help": "The layer decay of text tower."},
    )
    visual_ld: float = field(
        default=0.75,
        metadata={"help": "The layer decay of visual tower."},
    )

    start_epoch: int = field(default=0, metadata={"help": " manual epoch number (useful on restarts)"})
    context_length: int = field(default=77, metadata={"help": " context length for text."})
    optimizer: str = field(
        default="lamb",
        metadata={"help": "optimizer setting, [lamb/adamw]"},
    )

    # Forwarded as `last_epoch` to the LR schedulers built in `SelfTrainer`.
    last_epoch: int = field(
        default=-1,
        metadata={"help": "the last epoch to resume"},
    )

    # Both loss flags are passed straight to the `CLIP` constructor in `main_worker`.
    gather_with_grad: bool = field(default=False, metadata={"help": "Whether to use gather_with_grad in loss."})
    local_loss: bool = field(default=False, metadata={"help": "Whether to use local loss in loss."})

    tensorboard: bool = field(default=False, metadata={"help": "Whether to use tensorboard to record loss."})
    pretrained_text_model: str = field(
        default="openclip",
        metadata={"help": "the model to pre-extract text feats"},
    )
    tensor_fusion: bool = field(default=False, metadata={"help": "Whether to use tensor fusion."})
class SelfTrainer(CLIPTrainer):
    """
    `CLIPTrainer` subclass that builds its own LR schedule and optimizer.
    """

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Build the learning-rate schedule and the optimizer for this run.

        The schedule is a cosine annealing decay spanning the steps left after warmup.
        When `args.warmup_steps` is positive, that decay is wrapped in a linear warmup.
        The result is stored on `self.lr_scheduler`, and the optimizer returned by
        `create_optimizer_simple` is stored on `self.optimizer`.

        Args:
            num_training_steps (int): Total number of training steps, warmup included.
        """
        args = self.args
        warmup_steps = args.warmup_steps

        # Cosine decay only covers the post-warmup part of training.
        schedule = paddle.optimizer.lr.CosineAnnealingDecay(
            learning_rate=args.learning_rate,
            T_max=num_training_steps - warmup_steps,
            last_epoch=args.last_epoch,
        )

        if warmup_steps > 0:
            # NOTE(review): warmup ramps from 0 to 1.0 rather than to args.learning_rate -
            # confirm this matches how create_optimizer_simple scales per-group rates.
            schedule = paddle.optimizer.lr.LinearWarmup(
                learning_rate=schedule,
                warmup_steps=warmup_steps,
                start_lr=0,
                end_lr=1.0,
                last_epoch=args.last_epoch,
            )

        self.lr_scheduler = schedule
        self.optimizer = create_optimizer_simple(args, self.model, schedule)
class Collator:
    """
    Data collator that turns a list of samples into one processor-built batch.

    Two sample layouts are supported, chosen by the type of the first sample:

    * dict samples with "image" and "text" keys are processed in "train" mode;
    * any other sample is indexed as (image, label) and processed in "eval" mode
      with resizing and cropping enabled; the labels are attached to the batch
      under the "labels" key.

    Args:
        processor (`paddlemix.processors.ProcessorMixin`):
            The processor used for pre-process the data.
        max_length (int, optional):
            Value forwarded to the processor as `max_length`. Defaults to 77,
            the value that was previously hard-coded.
    """

    def __init__(self, processor, max_length=77):
        self.processor = processor
        self.max_length = max_length

    def __call__(self, data_list):
        """
        Collate `data_list` into a batch.

        Args:
            data_list (list): Non-empty list of samples, all of the same layout.

        Returns:
            Whatever the processor returns; in the eval layout it additionally
            carries a "labels" tensor.

        Raises:
            IndexError: If `data_list` is empty.
        """
        # Options shared by the train and eval calls.
        shared = {
            "max_length": self.max_length,
            "return_tensors": "pd",
            "return_attention_mask": False,
            "padding_zero": True,
        }

        if isinstance(data_list[0], dict):
            return self.processor(
                images=[sample["image"] for sample in data_list],
                text=[sample["text"] for sample in data_list],
                mode="train",
                **shared,
            )

        images = [sample[0] for sample in data_list]
        labels = [sample[1] for sample in data_list]
        batch = self.processor(
            images=images,
            text=None,
            mode="eval",
            do_resize=True,
            do_crop=True,
            **shared,
        )
        batch["labels"] = paddle.to_tensor(np.array(labels))
        return batch
def main_worker(training_args, model_args, data_args):
    """
    Build the CLIP model, datasets and trainer, then run training.

    Args:
        training_args: Parsed `PreTrainingArguments`; must already carry
            `data_world_rank` and `data_world_size`.
        model_args: Parsed `ModelArguments`.
        data_args: Parsed `DataArguments`.
    """
    # For bf16 at AMP level O2 the model is created (and pretrained weights loaded)
    # while the default dtype is bfloat16; the default is reset to float32 afterwards.
    bf16_o2 = training_args.bf16 and training_args.fp16_opt_level == "O2"
    if bf16_o2:
        paddle.set_default_dtype("bfloat16")

    clip_config = CLIPConfig.from_pretrained(model_args.model)
    model = CLIP(
        clip_config,
        disable_text=False,
        local_loss=training_args.local_loss,
        gather_with_grad=training_args.gather_with_grad,
        data_world_rank=training_args.data_world_rank,
        data_world_size=training_args.data_world_size,
    )
    if training_args.pretrained:
        model.load_pretrained(model_args.model)

    if bf16_o2:
        paddle.set_default_dtype("float32")

    # Training data: a LAION file list or a named dataset.
    task = data_args.task_name
    if "laion" in task:
        from paddlemix.datasets.laiondata import LaionDataset

        train_dataset = LaionDataset(task)
    else:
        train_dataset = load_dataset(task, splits="train")

    # Image and text processors are both loaded from "<model>/processor/train".
    processor_dir = os.path.join(model_args.model, "processor", "train")
    image_processor = CLIPImageProcessor.from_pretrained(processor_dir)
    text_processor = CLIPTextProcessor.from_pretrained(processor_dir)
    processor = CLIPProcessor(image_processor, text_processor, SimpleTokenizer())

    # Zero-shot classification is used as the evaluation metric.
    eval_dataset = ImageFolder(f"{data_args.classification_eval}/images")
    zeroshot = ClipZeroShot(model, training_args)

    trainer = SelfTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=Collator(processor),
        compute_metrics=zeroshot.zero_shot_eval,
    )

    # Training
    if training_args.do_train:
        # `resume_from_checkpoint` is None when no checkpoint was requested.
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()
        trainer.save_state()
if __name__ == "__main__":
    # Parse the three argument groups from the command line.
    parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    training_args.hostname = socket.gethostname()

    # Echo the effective configuration before any distributed setup happens.
    for parsed_args in (data_args, model_args, training_args):
        pprint.pprint(parsed_args)

    setdistenv(training_args)

    # Copy the data-parallel rank and size onto the model arguments; both are read
    # from training_args after setdistenv has run.
    model_args.data_world_rank = training_args.data_world_rank
    model_args.data_world_size = training_args.data_world_size

    # Make the evaluation data path available on training_args as well.
    training_args.classification_eval = data_args.classification_eval

    main_worker(training_args, model_args, data_args)