
Implementation of DreamBooth using KerasCV and TensorFlow

This repository provides an implementation of DreamBooth using KerasCV and TensorFlow. The implementation draws heavily on Hugging Face's diffusers example.

DreamBooth is a way of quickly teaching (fine-tuning) Stable Diffusion about new visual concepts. For more details, refer to this document.

The code provided in this repository is for research purposes only. Please check out this section to learn more about the potential use cases and limitations.

By loading this model you accept the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE.

If you're just looking for the accompanying resources of this repository, here are the links:


Update 15/02/2023: Thanks to Soumik Rakshit, we now have better utilities to support Weights and Biases (see #22).

Steps to perform DreamBooth training using the codebase

  1. Install the prerequisites: pip install -r requirements.txt.

  2. First, choose a class to which a unique identifier is appended. This codebase was tested using sks as the unique identifier and dog as the class.

    Then two types of prompts are generated:

    (a) instance prompt: f"a photo of {self.unique_id} {self.class_category}"
    (b) class prompt: f"a photo of {self.class_category}"

  3. Instance images

    Get a few images (3 - 10) that are representative of the concept the model is going to be fine-tuned on. These images are associated with the instance_prompt and are referred to as the instance_images in the codebase. Archive these images and host the archive online so that it can be downloaded internally with the tf.keras.utils.get_file() function.

  4. Class images

    DreamBooth uses a prior-preservation loss to regularize training. In short, the prior-preservation loss helps the model adapt gradually to the new concept while retaining the prior knowledge it already has about the class. To use it, we need the class prompt shown above. The class prompt is used to generate a predefined number of images, which are then used in computing the final loss for DreamBooth training.

    As per this resource, 200 - 300 images generated using the class prompt work well for most cases.

    So, after you have decided on the instance_prompt and class_prompt, use this Colab Notebook to generate the images that will be used for training with the prior-preservation loss. Then archive the generated images into a single archive and host it online so that it can be downloaded internally with the tf.keras.utils.get_file() function. In the codebase, we simply refer to these images as class_images.
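The data-preparation steps above (building the two prompts, generating class images, and archiving a folder of images) can be sketched as follows. This is a minimal sketch under stated assumptions, not the linked Colab notebook's code: class images are generated with KerasCV's StableDiffusion.text_to_image(), and the helper names (num_batches, generate_class_images, archive_images), the output folder, and the .tar.gz layout are hypothetical choices. Check the codebase for the archive layout it actually expects before hosting anything.

```python
import math
import tarfile
from pathlib import Path

# Assumption: these mirror the identifier and class the codebase was tested with.
UNIQUE_ID = "sks"
CLASS_CATEGORY = "dog"

# The two prompts described in step 2.
instance_prompt = f"a photo of {UNIQUE_ID} {CLASS_CATEGORY}"
class_prompt = f"a photo of {CLASS_CATEGORY}"


def num_batches(num_images: int, batch_size: int) -> int:
    """Number of text_to_image() calls needed to produce num_images images."""
    return math.ceil(num_images / batch_size)


def generate_class_images(prompt: str, num_images: int, batch_size: int, out_dir: str) -> int:
    """Hypothetical helper: generate class images with KerasCV and save them as PNGs."""
    # Heavy dependencies are imported lazily so the other helpers work without them.
    import keras_cv
    from PIL import Image

    model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    saved = 0
    for _ in range(num_batches(num_images, batch_size)):
        # text_to_image() returns a uint8 array of shape (batch, height, width, 3).
        for image in model.text_to_image(prompt, batch_size=batch_size):
            if saved == num_images:
                break  # the last batch may overshoot; drop the extras
            Image.fromarray(image).save(Path(out_dir) / f"{saved:04d}.png")
            saved += 1
    return saved


def archive_images(image_dir: str, archive_path: str) -> list[str]:
    """Hypothetical helper: pack the .jpg/.jpeg/.png files in image_dir into a .tar.gz.

    Returns the archive member names so the contents can be checked.
    """
    folder = Path(image_dir)
    paths = sorted(p for p in folder.iterdir() if p.suffix.lower() in {".jpg", ".jpeg", ".png"})
    names = [f"{folder.name}/{p.name}" for p in paths]
    with tarfile.open(archive_path, "w:gz") as tar:
        for path, name in zip(paths, names):
            # Keep a single top-level folder inside the archive.
            tar.add(path, arcname=name)
    return names


# Example flow (needs keras_cv, Pillow and the Stable Diffusion weights):
# generate_class_images(class_prompt, num_images=200, batch_size=4, out_dir="class-images")
# archive_images("class-images", "class-images.tar.gz")
```

The same archive_images() call works for the instance images from step 3; the hosted URL of each archive is what ends up being passed to tf.keras.utils.get_file().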

It's possible to conduct DreamBooth training without a prior-preservation loss, but this repository always uses it. To make it easy to test this codebase, we have hosted the instance and class images here.
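The prior-preservation loss described in step 4 is commonly formulated as a weighted sum of two denoising errors. The NumPy sketch below assumes the layout used in Hugging Face's diffusers example, where each batch stacks instance samples first and class samples second; prior_loss_weight follows that example's naming and is an assumption, not a flag taken from this repository.

```python
import numpy as np


def prior_preservation_loss(
    noise_pred: np.ndarray,
    noise_target: np.ndarray,
    prior_loss_weight: float = 1.0,
) -> float:
    """DreamBooth loss for a batch laid out as [instance samples | class samples].

    Assumption: the first half of the batch holds instance samples and the
    second half holds class samples, as in the Hugging Face diffusers example.
    """
    pred_instance, pred_prior = np.split(noise_pred, 2, axis=0)
    target_instance, target_prior = np.split(noise_target, 2, axis=0)
    # Denoising MSE on the new concept (the instance images).
    instance_loss = np.mean((pred_instance - target_instance) ** 2)
    # Denoising MSE on the generated class images; this term keeps the model
    # close to what it already produces for the plain class prompt.
    prior_loss = np.mean((pred_prior - target_prior) ** 2)
    return float(instance_loss + prior_loss_weight * prior_loss)


# Toy check: the instance half has MSE 1.0 and the class half has MSE 4.0.
target = np.zeros((4, 2))
pred = np.array([[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]])
print(prior_preservation_loss(pred, target, prior_loss_weight=0.5))  # 3.0
```

Dropping the prior term (a weight of 0.0) gives plain DreamBooth fine-tuning without prior preservation.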

  5. Launch training! There are a number of hyperparameters you can play around with; refer to the train_dreambooth.py script to learn more about them. Here's a command that launches training with mixed precision and other default values:

    python train_dreambooth.py --mp

    You can also fine-tune the text encoder by specifying the --train_text_encoder option.

    Additionally, the script supports integration with Weights and Biases (wandb). If you specify --log_wandb,

    • it will automatically log the training metrics to your wandb dashboard using the WandbMetricsLogger callback.
    • it will also upload your model checkpoints to your wandb project as artifacts at the end of each epoch, for model versioning. This is done using the DreamBoothCheckpointCallback, which was built on the WandbModelCheckpoint callback.
    • it will also perform inference with the DreamBoothed model parameters at the end of each epoch and log the generated images to a wandb.Table in your wandb dashboard. This is done using the QualitativeValidationCallback, which also logs the generated images to a media panel on your wandb dashboard at the end of training.

    Here's a command that launches training and logs training metrics and generated images to your Weights & Biases workspace:

    python train_dreambooth.py \
      --log_wandb \
      --validation_prompts \
        "a photo of sks dog with a cat" \
        "a photo of sks dog riding a bicycle" \
        "a photo of sks dog peeing" \
        "a photo of sks dog playing cricket" \
        "a photo of sks dog as an astronaut"

    Here's an example wandb run where you can find the generated images as well as the model checkpoints.

Inference

Results

We tested our implementation using two different methods: (a) fine-tuning the diffusion model (the UNet) only, and (b) fine-tuning the diffusion model along with the text encoder. The experiments covered a wide range of hyperparameters: learning rate and number of training steps during training, and number of inference steps and unconditional guidance scale (ugs) during inference. However, only the most salient results (from our perspective) are included here. If you are curious about how different hyperparameters affect the generated image quality, find the links to the full reports in each section.

Note that our experiments were guided by this blog post from Hugging Face.

(a) Fine-tuning diffusion model

Here are a few selected results from the various experiments we conducted. Our experimental logs for this setting are available here. More visualization images (generated with the checkpoints from these experiments) are available here.

Inference steps  UGS  Training setting
50               30   LR: 1e-6, training steps: 800 (Weights)
25               15   LR: 1e-6, training steps: 1000 (Weights)
75               15   LR: 3e-6, training steps: 1200 (Weights)
Caption: "A photo of sks dog in a bucket"

(b) Fine-tuning text encoder + diffusion model

Inference steps  UGS
75               15
75               30
Caption: "A photo of sks dog in a bucket"

w/ learning rate=9e-06, max train steps=200 (weights | reports)


Inference steps  UGS
150              15
75               30
Caption: "A photo of sks person without mustache, handsome, ultra realistic, 4k, 8k"

w/ learning rate=9e-06, max train steps=200 (datasets | reports)


Using in Diffusers 🧨

The diffusers library provides state-of-the-art tooling for experimenting with different diffusion models, including Stable Diffusion. It includes various optimization techniques for efficient inference with large Stable Diffusion checkpoints. One particularly useful feature of diffusers is its support for different schedulers, which can be configured at runtime and used with any compatible diffusion model.

Once you have obtained the DreamBooth fine-tuned checkpoints using this codebase, you can export them into a StableDiffusionPipeline and use it directly from the diffusers library.

Consider this repository: chansung/dreambooth-dog. You can use its checkpoints in a StableDiffusionPipeline after a few small conversion steps:

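from diffusers import StableDiffusionPipeline

# Checkpoint of the Stable Diffusion model converted from KerasCV to diffusers
model_ckpt = "sayakpaul/text-unet-dogs-kerascv_sd_diffusers_pipeline"
pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt)
pipeline.to("cuda")  # requires a CUDA-capable GPU

# The prompt must use the same unique identifier and class as during training
unique_id = "sks"
class_label = "dog"
prompt = f"A photo of {unique_id} {class_label} in a bucket"
image = pipeline(prompt, num_inference_steps=50).images[0]
image.save("sks_dog_in_bucket.png")  # the pipeline returns PIL images by default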

Follow this guide to learn more.

Experimental results with various scheduler settings:

We converted the fine-tuned checkpoint for the dog images into a Diffusers-compatible StableDiffusionPipeline and ran various experiments with different scheduler settings. For example, the following DDIMScheduler parameters were tested across different values of guidance_scale and num_inference_steps.

num_inference_steps_list = [25, 50, 75, 100]
guidance_scale_list = [7.5, 15, 30]

scheduler_configs = {
  "DDIMScheduler": {
      "beta_value": [
          [0.000001, 0.02], 
          [0.000005, 0.02], 
          [0.00001, 0.02], 
          [0.00005, 0.02], 
          [0.0001, 0.02], 
          [0.0005, 0.02]
      ],
      "beta_schedule": [
          "linear", 
          "scaled_linear", 
          "squaredcos_cap_v2"
      ],
      "clip_sample": [True, False],
      "set_alpha_to_one": [True, False],
      "prediction_type": [
          "epsilon", 
          "sample", 
          "v_prediction"
      ]
  }
}
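The search space above can be expanded into concrete scheduler arguments as sketched below. This is not the authors' experiment script; it assumes each beta_value pair maps to DDIMScheduler's beta_start and beta_end arguments, and the helper names scheduler_kwargs and apply_ddim are hypothetical. apply_ddim() shows one way to swap the scheduler on an existing pipeline with diffusers' from_config(), whose keyword arguments override entries of the loaded config.

```python
import itertools

num_inference_steps_list = [25, 50, 75, 100]
guidance_scale_list = [7.5, 15, 30]

# Same DDIM search space as above, written compactly.
ddim_space = {
    "beta_value": [[1e-6, 0.02], [5e-6, 0.02], [1e-5, 0.02],
                   [5e-5, 0.02], [1e-4, 0.02], [5e-4, 0.02]],
    "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2"],
    "clip_sample": [True, False],
    "set_alpha_to_one": [True, False],
    "prediction_type": ["epsilon", "sample", "v_prediction"],
}


def scheduler_kwargs(space: dict) -> list[dict]:
    """Expand a search space into one kwargs dict per combination.

    Assumption: "beta_value" is split into DDIMScheduler's beta_start / beta_end.
    """
    keys = list(space)
    combos = []
    for values in itertools.product(*(space[k] for k in keys)):
        kwargs = dict(zip(keys, values))
        beta_start, beta_end = kwargs.pop("beta_value")
        combos.append({"beta_start": beta_start, "beta_end": beta_end, **kwargs})
    return combos


def apply_ddim(pipeline, kwargs: dict):
    """Replace the pipeline's scheduler with a DDIMScheduler using kwargs."""
    from diffusers import DDIMScheduler  # imported lazily; needs diffusers installed

    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config, **kwargs)
    return pipeline


configs = scheduler_kwargs(ddim_space)
# Every scheduler config is paired with every (num_inference_steps, guidance_scale).
runs = list(itertools.product(configs, num_inference_steps_list, guidance_scale_list))
print(len(configs), len(runs))
```

For DDIM alone this gives 6 × 3 × 2 × 2 × 3 = 216 scheduler configurations, each evaluated at 4 × 3 = 12 sampling settings.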

Below is a comparison of different beta_schedule values, with the other parameters fixed to their default values. Take a look at the original report, which also includes results from other schedulers such as PNDMScheduler and LMSDiscreteScheduler.

We often observed that the default settings do not guarantee better-quality images. For example, the default values of guidance_scale and beta_schedule are 7.5 and linear, respectively. However, when guidance_scale is set to 7.5, the scaled_linear beta_schedule seems to work better. Likewise, when beta_schedule is set to linear, a higher guidance_scale seems to work better.
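For context on the guidance_scale knob: the diffusers StableDiffusionPipeline (like KerasCV's unconditional_guidance_scale) applies classifier-free guidance, which combines an unconditional and a text-conditioned noise prediction at every denoising step. A minimal NumPy sketch of that combination, with toy arrays standing in for UNet outputs (the function name is hypothetical):

```python
import numpy as np


def classifier_free_guidance(
    noise_uncond: np.ndarray, noise_text: np.ndarray, guidance_scale: float
) -> np.ndarray:
    """Combine unconditional and text-conditioned noise predictions.

    guidance_scale == 1.0 returns the text-conditioned prediction; larger values
    extrapolate further along the direction (noise_text - noise_uncond).
    """
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


noise_uncond = np.array([0.0, 1.0])
noise_text = np.array([1.0, 1.0])
for scale in (1.0, 7.5, 15.0, 30.0):
    # Only the first component differs between the predictions, so only it grows.
    print(scale, classifier_free_guidance(noise_uncond, noise_text, scale))
```

Because the scale multiplies the gap between the two predictions, values such as 15 or 30 push each step much further from the unconditional estimate than the default 7.5.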

We ran 4,800 experiments which generated 38,400 images in total. Those experiments are logged in Weights and Biases. If you are curious, do check them out here as well as the script that was used to run the experiments.

Notes on preparing data for DreamBooth training of faces

In addition to the tips and tricks shared in this blog post, we followed these practices when preparing the instance images for DreamBooth training on human faces:

  • Instead of 3 - 5 images, use 20 - 25 images of the same person with varied angles, backgrounds, and poses.
  • Don't use images containing multiple people.
  • If the person wears glasses, don't use only images with glasses; combine images with and without glasses.

Thanks to Abhishek Thakur for sharing these tips.

Acknowledgements

  • Thanks to Hugging Face for providing the original example. It's very readable and easy to understand.
  • Thanks to the ML Developer Programs team at Google for providing GCP credits.