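"""DreamBooth fine-tuning of Stable Diffusion, built on KerasCV.

The script downloads instance and class images, prepares a DreamBooth
dataset, fine-tunes the Stable Diffusion diffusion model (and optionally
the text encoder), and checkpoints the trained weights. A minimal example
invocation (all flags are optional; see `parse_args` for the defaults):

    python train_dreambooth.py --mp --log_wandb \
        --validation_prompts "a photo of sks dog in a bucket"
"""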
import warnings

warnings.filterwarnings("ignore")

import os

# Silence TensorFlow's C++ log messages; this must be set before
# TensorFlow is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import argparse
import math

import tensorflow as tf
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
from tensorflow.keras import mixed_precision

import wandb
from wandb.keras import WandbMetricsLogger

from src.constants import MAX_PROMPT_LENGTH
from src.datasets import DatasetUtils
from src.dreambooth_trainer import DreamBoothTrainer
from src.utils import DreamBoothCheckpointCallback, QualitativeValidationCallback


# These hyperparameters come from this tutorial by Hugging Face:
# https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
def get_optimizer(
    lr=5e-6, beta_1=0.9, beta_2=0.999, weight_decay=1e-2, epsilon=1e-08
):
    """Instantiates the AdamW optimizer."""
    optimizer = tf.keras.optimizers.experimental.AdamW(
        learning_rate=lr,
        weight_decay=weight_decay,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon,
    )
    return optimizer


def prepare_trainer(
    img_resolution: int,
    train_text_encoder: bool,
    use_mp: bool,
    optimizer_kwargs=None,
    **kwargs,
):
    """Instantiates and compiles a `DreamBoothTrainer` for training."""
    image_encoder = ImageEncoder(img_resolution, img_resolution)
    dreambooth_trainer = DreamBoothTrainer(
        diffusion_model=DiffusionModel(
            img_resolution, img_resolution, MAX_PROMPT_LENGTH
        ),
        # Remove the top layer from the encoder, which cuts off the
        # variance and only returns the mean.
        vae=tf.keras.Model(
            image_encoder.input,
            image_encoder.layers[-2].output,
        ),
        noise_scheduler=NoiseScheduler(),
        train_text_encoder=train_text_encoder,
        use_mixed_precision=use_mp,
        **kwargs,
    )
    optimizer = get_optimizer(**(optimizer_kwargs or {}))
    dreambooth_trainer.compile(optimizer=optimizer, loss="mse")
    print("DreamBooth trainer initialized and compiled.")
    return dreambooth_trainer


def train(dreambooth_trainer, train_dataset, max_train_steps, callbacks):
    """Trains the `DreamBoothTrainer` on the given `train_dataset`."""
    # `cardinality()` is the number of batches per epoch, so the requested
    # number of update steps translates into this many epochs.
    num_update_steps_per_epoch = train_dataset.cardinality()
    epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
    print(f"Training for {epochs} epochs.")
    dreambooth_trainer.fit(train_dataset, epochs=epochs, callbacks=callbacks)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Script to perform DreamBooth training using Stable Diffusion."
    )
    # Dataset related. The instance images show the subject being taught;
    # the class images are generic examples of the class, used for prior
    # preservation.
    parser.add_argument(
        "--instance_images_url",
        default="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/instance-images.tar.gz",
        type=str,
    )
    parser.add_argument(
        "--class_images_url",
        default="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/class-images.tar.gz",
        type=str,
    )
    parser.add_argument("--unique_id", default="sks", type=str)
    parser.add_argument("--class_category", default="dog", type=str)
    parser.add_argument("--img_resolution", default=512, type=int)
    # Optimization hyperparameters.
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--lr", default=5e-6, type=float)
    parser.add_argument("--wd", default=1e-2, type=float)
    parser.add_argument("--beta_1", default=0.9, type=float)
    parser.add_argument("--beta_2", default=0.999, type=float)
    parser.add_argument("--epsilon", default=1e-08, type=float)
    # Training hyperparameters.
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--max_train_steps", default=800, type=int)
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to also fine-tune the text encoder.",
    )
    parser.add_argument(
        "--mp", action="store_true", help="Whether to use mixed-precision."
    )
    # Misc.
    parser.add_argument(
        "--log_wandb",
        action="store_true",
        help="Whether to use Weights & Biases for experiment tracking.",
    )
    parser.add_argument(
        "--validation_prompts",
        nargs="+",
        default=None,
        type=str,
        help="Prompts to generate validation samples for logging to Weights & Biases.",
    )
    parser.add_argument(
        "--num_images_to_generate",
        default=5,
        type=int,
        help="Number of validation images to generate per prompt.",
    )
    return parser.parse_args()


def run(args):
    # Set the random seed for reproducibility.
    tf.keras.utils.set_random_seed(args.seed)

    validation_prompts = [
        f"A photo of {args.unique_id} {args.class_category} in a bucket"
    ]
    if args.validation_prompts is not None:
        validation_prompts = args.validation_prompts

    run_name = (
        f"lr@{args.lr}-max_train_steps@{args.max_train_steps}"
        f"-train_text_encoder@{args.train_text_encoder}"
    )
    if args.log_wandb:
        wandb.init(project="dreambooth-keras", name=run_name, config=vars(args))

    if args.mp:
        print("Enabling mixed-precision...")
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
        assert policy.compute_dtype == "float16"
        assert policy.variable_dtype == "float32"

    print("Initializing dataset...")
    data_util = DatasetUtils(
        instance_images_url=args.instance_images_url,
        class_images_url=args.class_images_url,
        unique_id=args.unique_id,
        class_category=args.class_category,
        train_text_encoder=args.train_text_encoder,
        batch_size=args.batch_size,
    )
    train_dataset = data_util.prepare_datasets()

    print("Initializing trainer...")
    ckpt_path_prefix = run_name
    dreambooth_trainer = prepare_trainer(
        args.img_resolution,
        args.train_text_encoder,
        args.mp,
        # Optimizer hyperparameters parsed on the CLI.
        optimizer_kwargs={
            "lr": args.lr,
            "weight_decay": args.wd,
            "beta_1": args.beta_1,
            "beta_2": args.beta_2,
            "epsilon": args.epsilon,
        },
    )

    callbacks = [
        # Save model checkpoints and optionally log them to
        # Weights & Biases as artifacts.
        DreamBoothCheckpointCallback(ckpt_path_prefix, save_weights_only=True)
    ]
    if args.log_wandb:
        # Log training metrics to Weights & Biases.
        callbacks.append(WandbMetricsLogger(log_freq="batch"))
        # Perform inference on the validation prompts at the end of every
        # epoch and log the results to a Weights & Biases table.
        callbacks.append(
            QualitativeValidationCallback(
                # Note: `img_heigth` mirrors the spelling of the
                # callback's parameter.
                img_heigth=args.img_resolution,
                img_width=args.img_resolution,
                prompts=validation_prompts,
                num_imgs_to_gen=args.num_images_to_generate,
            )
        )

    train(dreambooth_trainer, train_dataset, args.max_train_steps, callbacks)

    if args.log_wandb:
        wandb.finish()


if __name__ == "__main__":
    args = parse_args()
    run(args)