-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathfinetune.py
136 lines (115 loc) · 4.14 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Adapted from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
# Usage
python finetune.py
"""
import warnings
warnings.filterwarnings("ignore")
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import argparse
import tensorflow as tf
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
from tensorflow.keras import mixed_precision
from datasets import DatasetUtils
from trainer import Trainer
MAX_PROMPT_LENGTH = 77
CKPT_PREFIX = "ckpt"
def parse_args():
parser = argparse.ArgumentParser(
description="Script to fine-tune a Stable Diffusion model."
)
# Dataset related.
parser.add_argument("--dataset_archive", default=None, type=str)
parser.add_argument("--img_height", default=256, type=int)
parser.add_argument("--img_width", default=256, type=int)
# Optimization hyperparameters.
parser.add_argument("--lr", default=1e-5, type=float)
parser.add_argument("--wd", default=1e-2, type=float)
parser.add_argument("--beta_1", default=0.9, type=float)
parser.add_argument("--beta_2", default=0.999, type=float)
parser.add_argument("--epsilon", default=1e-08, type=float)
parser.add_argument("--ema", default=0.9999, type=float)
parser.add_argument("--max_grad_norm", default=1.0, type=float)
# Training hyperparameters.
parser.add_argument("--batch_size", default=4, type=int)
parser.add_argument("--num_epochs", default=100, type=int)
# Others.
parser.add_argument(
"--mp", action="store_true", help="Whether to use mixed-precision."
)
parser.add_argument(
"--pretrained_ckpt",
default=None,
type=str,
help="Provide a local path to a diffusion model checkpoint in the `h5`"
" format if you want to start over fine-tuning from this checkpoint.",
)
return parser.parse_args()
def run(args):
if args.mp:
print("Enabling mixed-precision...")
policy = mixed_precision.Policy("mixed_float16")
mixed_precision.set_global_policy(policy)
assert policy.compute_dtype == "float16"
assert policy.variable_dtype == "float32"
print("Initializing dataset...")
data_utils = DatasetUtils(
dataset_archive=args.dataset_archive,
batch_size=args.batch_size,
img_height=args.img_height,
img_width=args.img_width,
)
training_dataset = data_utils.prepare_dataset()
print("Initializing trainer...")
ckpt_path = (
CKPT_PREFIX
+ f"_epochs_{args.num_epochs}"
+ f"_res_{args.img_height}"
+ f"_mp_{args.mp}"
+ ".h5"
)
image_encoder = ImageEncoder(args.img_height, args.img_width)
diffusion_ft_trainer = Trainer(
diffusion_model=DiffusionModel(
args.img_height, args.img_width, MAX_PROMPT_LENGTH
),
# Remove the top layer from the encoder, which cuts off the variance and only returns
# the mean
vae=tf.keras.Model(
image_encoder.input,
image_encoder.layers[-2].output,
),
noise_scheduler=NoiseScheduler(),
pretrained_ckpt=args.pretrained_ckpt,
mp=args.mp,
ema=args.ema,
max_grad_norm=args.max_grad_norm,
)
print("Initializing optimizer...")
optimizer = tf.keras.optimizers.experimental.AdamW(
learning_rate=args.lr,
weight_decay=args.wd,
beta_1=args.beta_1,
beta_2=args.beta_2,
epsilon=args.epsilon,
)
if args.mp:
optimizer = mixed_precision.LossScaleOptimizer(optimizer)
print("Compiling trainer...")
diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")
print("Training...")
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
ckpt_path,
save_weights_only=True,
monitor="loss",
mode="min",
)
diffusion_ft_trainer.fit(
training_dataset, epochs=args.num_epochs, callbacks=[ckpt_callback]
)
if __name__ == "__main__":
args = parse_args()
run(args)