forked from zhengbw0324/LC-Rec
-
Notifications
You must be signed in to change notification settings - Fork 5
/
finetune.py
executable file
·119 lines (93 loc) · 3.4 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import os
import sys
from typing import List
import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
from utils import *
from collator import Collator
def _build_training_arguments(args, ddp, gradient_checkpointing):
    """Translate the parsed CLI options into a ``transformers.TrainingArguments``.

    Args:
        args: Parsed command-line namespace (presumably populated by the
            ``parse_*_args`` helpers star-imported from ``utils`` — verify).
        ddp: True when the script runs under distributed data parallel
            (``WORLD_SIZE`` != 1).
        gradient_checkpointing: Whether to enable gradient checkpointing.

    Returns:
        The configured ``transformers.TrainingArguments`` instance.
    """
    return transformers.TrainingArguments(
        seed=args.seed,
        per_device_train_batch_size=args.per_device_batch_size,
        per_device_eval_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        lr_scheduler_type=args.lr_scheduler_type,
        fp16=args.fp16,
        bf16=args.bf16,
        logging_steps=args.logging_step,
        optim=args.optim,
        gradient_checkpointing=gradient_checkpointing,
        # Saving and evaluation share one schedule; load_best_model_at_end
        # below requires the two strategies to match.
        evaluation_strategy=args.save_and_eval_strategy,
        save_strategy=args.save_and_eval_strategy,
        eval_steps=args.save_and_eval_steps,
        save_steps=args.save_and_eval_steps,
        output_dir=args.output_dir,
        save_total_limit=5,
        load_best_model_at_end=True,
        deepspeed=args.deepspeed,
        # Under DDP every parameter is expected to receive a gradient, so skip
        # the (costly) unused-parameter search; leave the default otherwise.
        ddp_find_unused_parameters=False if ddp else None,
        # "none" disables all logging integrations. Passing None instead makes
        # transformers 4.x fall back to "all" installed integrations (e.g.
        # wandb), which is what the original code unintentionally did.
        report_to="none",
        # Postpone the first evaluation: 1 epoch for per-epoch evaluation,
        # otherwise 2000 steps.
        eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
    )


def train(args):
    """Fine-tune a LLaMA causal language model and save the result.

    The tokenizer is extended with the tokens returned by the first training
    dataset's ``get_new_tokens()``, the model's embedding matrix is resized to
    match, and training runs through ``transformers.Trainer``. The tokenizer
    and config (written by local rank 0), the trainer state and the final model
    all end up in ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. Read here: ``seed``,
            ``output_dir``, ``base_model``, ``model_max_length`` and
            ``resume_from_checkpoint``, plus the hyper-parameters consumed by
            ``_build_training_arguments``; ``load_datasets`` and ``Collator``
            receive the whole namespace.
    """
    set_seed(args.seed)
    ensure_dir(args.output_dir)

    # Distributed launchers export WORLD_SIZE / LOCAL_RANK; absent => single process.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    # Under DDP each process places the whole model on its own device;
    # otherwise let transformers place it automatically.
    device_map = {"": local_rank} if ddp else "auto"
    if local_rank == 0:
        print(vars(args))

    config = LlamaConfig.from_pretrained(args.base_model)
    tokenizer = LlamaTokenizer.from_pretrained(
        args.base_model,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    tokenizer.pad_token_id = 0
    gradient_checkpointing = True

    train_data, valid_data = load_datasets(args)
    # NOTE(review): train_data appears to be a ConcatDataset-like wrapper (it
    # exposes `.datasets`); only the first sub-dataset's tokens are added.
    add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
    # Keep the saved config consistent with the enlarged vocabulary.
    config.vocab_size = len(tokenizer)
    if local_rank == 0:
        print("add {} new token.".format(add_num))
        print("data num:", len(train_data))
        tokenizer.save_pretrained(args.output_dir)
        config.save_pretrained(args.output_dir)

    collator = Collator(args, tokenizer)

    model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        device_map=device_map,
    )
    # Grow the embedding (and tied output) matrices to cover the added tokens.
    model.resize_token_embeddings(len(tokenizer))

    if not ddp and torch.cuda.device_count() > 1:
        # Without DDP on a multi-GPU host, device_map="auto" may shard the model;
        # these flags tell Trainer it is already parallelised so it does not
        # additionally wrap it in DataParallel.
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=_build_training_arguments(args, ddp, gradient_checkpointing),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    # The KV cache only helps generation and conflicts with gradient
    # checkpointing during training.
    model.config.use_cache = False

    trainer.train(
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    trainer.save_state()
    trainer.save_model(output_dir=args.output_dir)
if __name__ == "__main__":
    # Assemble the CLI by letting each argument group register its options,
    # in order: global, training, then dataset options.
    cli = argparse.ArgumentParser(description='LLMRec')
    for register_options in (parse_global_args, parse_train_args, parse_dataset_args):
        cli = register_options(cli)
    train(cli.parse_args())