From 61778c7e5042bad16cd858fc9a3313ee99745e4f Mon Sep 17 00:00:00 2001 From: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com> Date: Thu, 29 Dec 2022 22:58:20 +0900 Subject: [PATCH 1/2] Revert "Other Acceleration tricks (#93)" This reverts commit eacf5014ddc58a20fda895e0a5126ae4b7adbae9. --- lora_diffusion/__init__.py | 1 - lora_diffusion/cli_lora_pti.py | 52 +++------------------- lora_diffusion/dataset.py | 79 ++++------------------------------ lora_diffusion/utils.py | 12 ------ 4 files changed, 16 insertions(+), 128 deletions(-) delete mode 100644 lora_diffusion/utils.py diff --git a/lora_diffusion/__init__.py b/lora_diffusion/__init__.py index 6434b23..0efbb94 100644 --- a/lora_diffusion/__init__.py +++ b/lora_diffusion/__init__.py @@ -1,3 +1,2 @@ from .lora import * from .dataset import * -from .utils import * diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py index 8e9aa98..8bb91dc 100644 --- a/lora_diffusion/cli_lora_pti.py +++ b/lora_diffusion/cli_lora_pti.py @@ -142,10 +142,6 @@ def collate_fn(examples): "input_ids": input_ids, "pixel_values": pixel_values, } - - if examples[0].get("mask", None) is not None: - batch["mask"] = torch.stack([example["mask"] for example in examples]) - return batch train_dataloader = torch.utils.data.DataLoader( @@ -153,15 +149,14 @@ def collate_fn(examples): batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn, + num_workers=2, ) return train_dataloader @torch.autocast("cuda") -def loss_step( - batch, unet, vae, text_encoder, scheduler, weight_dtype, t_mutliplier=1.0 -): +def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype): latents = vae.encode( batch["pixel_values"].to(dtype=weight_dtype).to(unet.device) ).latent_dist.sample() @@ -172,7 +167,7 @@ def loss_step( timesteps = torch.randint( 0, - int(scheduler.config.num_train_timesteps * t_mutliplier), + scheduler.config.num_train_timesteps, (bsz,), device=latents.device, ) @@ -191,31 +186,6 @@ def loss_step( else: raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}") - if batch.get("mask", None) is not None: - - mask = ( - batch["mask"] - .to(model_pred.device) - .reshape( - model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8 - ) - ) - # resize to match model_pred - mask = ( - F.interpolate( - mask.float(), - size=model_pred.shape[-2:], - mode="nearest", - ) - + 0.1 - ) - - mask = mask / mask.mean() - - model_pred = model_pred * mask - - target = target * mask - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") return loss @@ -303,15 +273,7 @@ def perform_tuning( for batch in dataloader: optimizer.zero_grad() - loss = loss_step( - batch, - unet, - vae, - text_encoder, - scheduler, - weight_dtype, - t_mutliplier=0.8, - ) + loss = loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype) loss.backward() torch.nn.utils.clip_grad_norm_( itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0 @@ -360,7 +322,7 @@ def train( class_data_dir: Optional[str] = None, stochastic_attribute: Optional[str] = None, perform_inversion: bool = True, - use_template: Literal[None, "object", "style"] = None, + use_template: Optional[str] = Literal[None, "object", "style"], placeholder_tokens: str = "", placeholder_token_at_data: Optional[str] = None, initializer_tokens: str = "dog", @@ -370,6 +332,7 @@ def train( num_class_images: int = 100, seed: int = 42, resolution: int = 512, + center_crop: bool = False, color_jitter: bool = True, train_batch_size: int = 1, sample_batch_size: int = 1, @@ -387,7 +350,6 @@ def train( learning_rate_ti: float = 5e-4, continue_inversion: bool = True, continue_inversion_lr: Optional[float] = None, - use_face_segmentation_condition: bool = False, scale_lr: bool = False, lr_scheduler: str = "constant", lr_warmup_steps: int = 100, @@ -451,8 +413,8 @@ def train( class_prompt=class_prompt, tokenizer=tokenizer, size=resolution, + center_crop=center_crop, color_jitter=color_jitter, - use_face_segmentation_condition=use_face_segmentation_condition, ) train_dataloader = text2img_dataloader( diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py index 37de046..6d2d712 100644 --- a/lora_diffusion/dataset.py +++ b/lora_diffusion/dataset.py @@ -1,12 +1,11 @@ from torch.utils.data import Dataset from typing import List, Tuple, Dict, Union, Optional -from PIL import Image, ImageFilter +from PIL import Image from torchvision import transforms from pathlib import Path -import cv2 + import random -import numpy as np OBJECT_TEMPLATE = [ "a photo of a {}", @@ -91,12 +90,12 @@ def __init__( class_prompt=None, size=512, h_flip=True, + center_crop=False, color_jitter=False, resize=True, - use_face_segmentation_condition=False, - blur_amount: int = 70, ): self.size = size + self.center_crop = center_crop self.tokenizer = tokenizer self.resize = resize @@ -122,7 +121,7 @@ def __init__( self.class_prompt = class_prompt else: self.class_data_root = None - self.h_flip = h_flip + self.image_transforms = transforms.Compose( [ transforms.Resize( @@ -130,24 +129,17 @@ def __init__( ) if resize else transforms.Lambda(lambda x: x), - transforms.ColorJitter(0.1, 0.1) + transforms.ColorJitter(0.2, 0.1) if color_jitter else transforms.Lambda(lambda x: x), + transforms.RandomHorizontalFlip() + if h_flip + else transforms.Lambda(lambda x: x), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] ) - self.use_face_segmentation_condition = use_face_segmentation_condition - if self.use_face_segmentation_condition: - import mediapipe as mp - - mp_face_detection = mp.solutions.face_detection - self.face_detection = mp_face_detection.FaceDetection( - model_selection=1, min_detection_confidence=0.5 - ) - self.blur_amount = blur_amount - def __len__(self): return self._length @@ -171,59 +163,6 @@ def __getitem__(self, index): for token, value in self.token_map.items(): text = text.replace(token, value) - if self.use_face_segmentation_condition: - image = cv2.imread( - str(self.instance_images_path[index % self.num_instance_images]) - ) - results = self.face_detection.process( - cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - ) - black_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) - - if results.detections: - - for detection in results.detections: - - x_min = int( - detection.location_data.relative_bounding_box.xmin - * image.shape[1] - ) - y_min = int( - detection.location_data.relative_bounding_box.ymin - * image.shape[0] - ) - width = int( - detection.location_data.relative_bounding_box.width - * image.shape[1] - ) - height = int( - detection.location_data.relative_bounding_box.height - * image.shape[0] - ) - - # draw the colored rectangle - black_image[y_min : y_min + height, x_min : x_min + width] = 255 - - # blur the image - black_image = Image.fromarray(black_image, mode="L").filter( - ImageFilter.GaussianBlur(radius=self.blur_amount) - ) - # to tensor - black_image = transforms.ToTensor()(black_image) - # resize as the instance image - black_image = transforms.Resize( - self.size, interpolation=transforms.InterpolationMode.BILINEAR - )(black_image) - - example["mask"] = black_image - - if self.h_flip and random.random() > 0.5: - hflip = transforms.RandomHorizontalFlip(p=1) - - example["instance_images"] = hflip(example["instance_images"]) - if self.use_face_segmentation_condition: - example["mask"] = hflip(example["mask"]) - example["instance_prompt_ids"] = self.tokenizer( text, padding="do_not_pad", diff --git a/lora_diffusion/utils.py b/lora_diffusion/utils.py deleted file mode 100644 index cde4ba4..0000000 --- a/lora_diffusion/utils.py +++ /dev/null @@ -1,12 +0,0 @@ -from PIL import Image - - -def image_grid(_imgs, rows, cols): - - w, h = _imgs[0].size - grid = Image.new("RGB", size=(cols * w, rows * h)) - grid_w, grid_h = grid.size - - for i, img in enumerate(_imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - return grid From 9624138c24cada3fffe38820b5cf290d2e76f5f7 Mon Sep 17 00:00:00 2001 From: Vincent Lordier Date: Fri, 30 Dec 2022 00:38:19 +0100 Subject: [PATCH 2/2] pep8 and comments --- lora_diffusion/__init__.py | 2 +- lora_diffusion/cli_lora_add.py | 45 +++++++----- lora_diffusion/cli_lora_pti.py | 89 +++++++++++++----------- lora_diffusion/dataset.py | 23 ++++--- lora_diffusion/lora.py | 116 +++++++++++++++---------------- lora_diffusion/to_ckpt_v2.py | 21 +++--- setup.py | 3 +- train_lora_dreambooth.py | 95 +++++++++++++++---------- train_lora_pt_caption.py | 121 +++++++++++++++++++------------- train_lora_w_ti.py | 122 ++++++++++++++++++++------------- 10 files changed, 373 insertions(+), 264 deletions(-) diff --git a/lora_diffusion/__init__.py b/lora_diffusion/__init__.py index 0efbb94..ad400cc 100644 --- a/lora_diffusion/__init__.py +++ b/lora_diffusion/__init__.py @@ -1,2 +1,2 @@ -from .lora import * from .dataset import * +from .lora import * diff --git a/lora_diffusion/cli_lora_add.py b/lora_diffusion/cli_lora_add.py index 3a416af..10992b9 100644 --- a/lora_diffusion/cli_lora_add.py +++ b/lora_diffusion/cli_lora_add.py @@ -1,23 +1,25 @@ -from typing import Literal, Union, Dict import os import shutil +from pathlib import Path +from typing import Literal, Union + import fire +import torch from diffusers import StableDiffusionPipeline -import torch -from .lora import tune_lora_scale, weight_apply_lora +from .lora import weight_apply_lora from .to_ckpt_v2 import convert_to_ckpt -def _text_lora_path(path: str) -> str: +def _text_lora_path(path: Union[str, Path]) -> Union[str, Path]: assert path.endswith(".pt"), "Only .pt files are supported" return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) def add( - path_1: str, - path_2: str, - output_path: str, + path_1: Union[str, Path], + path_2: Union[str, Path], + output_path: Union[str, Path], alpha: float = 0.5, mode: Literal[ "lpl", @@ -26,14 +28,14 @@ def add( ] = "lpl", with_text_lora: bool = False, ): - print("Lora Add, mode " + mode) + print(f"Lora Add, mode {mode}") if mode == "lpl": for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + ( [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")] if with_text_lora else [] ): - print("Loading", _path_1, _path_2) + print(f"Loading {_path_1} {_path_2}") out_list = [] if opt == "text_encoder": if not os.path.exists(_path_1): @@ -50,7 +52,7 @@ def add( l2pairs = zip(l2[::2], l2[1::2]) for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs): - # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape) + # print(f'Merging {x1.shape} {y1.shape} {x2.shape} {y2.shape}') x1.data = alpha * x1.data + (1 - alpha) * x2.data y1.data = alpha * y1.data + (1 - alpha) * y2.data @@ -59,11 +61,15 @@ def add( if opt == "unet": - print("Saving merged UNET to", output_path) + print(f"Saving merged UNET to {output_path}") torch.save(out_list, output_path) elif opt == "text_encoder": - print("Saving merged text encoder to", _text_lora_path(output_path)) + print( + f"Saving merged text encoder to \ + {_text_lora_path(output_path)}" + ) + torch.save( out_list, _text_lora_path(output_path), @@ -75,7 +81,10 @@ def add( path_1, ).to("cpu") - weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha) + weight_apply_lora( + loaded_pipeline.unet, + torch.load(path_2), + alpha=alpha) if with_text_lora: weight_apply_lora( @@ -93,7 +102,10 @@ def add( path_1, ).to("cpu") - weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha) + weight_apply_lora( + loaded_pipeline.unet, + torch.load(path_2), + alpha=alpha) if with_text_lora: weight_apply_lora( loaded_pipeline.text_encoder, @@ -102,7 +114,8 @@ def add( target_replace_module=["CLIPAttention"], ) - _tmp_output = output_path + ".tmp" + new_suffix = ".tmp" + _tmp_output = output_path.with_suffix(output_path.suffix + new_suffix) loaded_pipeline.save_pretrained(_tmp_output) convert_to_ckpt(_tmp_output, output_path, as_half=True) @@ -110,7 +123,7 @@ def add( shutil.rmtree(_tmp_output) else: - print("Unknown mode", mode) + print(f"Unknown mode {mode}") raise ValueError(f"Unknown mode {mode}") diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py index 8bb91dc..1dbefd6 100644 --- a/lora_diffusion/cli_lora_pti.py +++ b/lora_diffusion/cli_lora_pti.py @@ -1,43 +1,24 @@ # Bootstrapped from: # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py -import argparse -import hashlib -import inspect import itertools import math import os -import random -import re -from pathlib import Path -from typing import Optional, List, Literal +from typing import List, Literal, Optional +import fire import torch import torch.nn.functional as F import torch.optim as optim import torch.utils.checkpoint -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) -from diffusers.optimization import get_scheduler -from huggingface_hub import HfFolder, Repository, whoami -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer -import fire - from lora_diffusion import ( PivotalTuningDatasetCapation, - extract_lora_ups_down, inject_trainable_lora, inspect_lora, - save_lora_weight, save_all, ) @@ -69,8 +50,9 @@ def get_models( num_added_tokens = tokenizer.add_tokens(token) if num_added_tokens == 0: raise ValueError( - f"The tokenizer already contains the token {token}. Please pass a different" - " `placeholder_token` that is not already in the tokenizer." + f"The tokenizer already contains the token {token}. \ + Please pass a different `placeholder_token` that is not \ + already in the tokenizer." ) placeholder_token_id = tokenizer.convert_tokens_to_ids(token) @@ -87,12 +69,15 @@ def get_models( token_embeds[0] ).clamp(-0.5, 0.5) elif init_tok == "": - token_embeds[placeholder_token_id] = torch.zeros_like(token_embeds[0]) + token_embeds[placeholder_token_id] = torch.zeros_like( + token_embeds[0]) else: token_ids = tokenizer.encode(init_tok, add_special_tokens=False) - # Check if initializer_token is a single token or a sequence of tokens + # Check if initializer_token is a single token or a sequence of + # tokens if len(token_ids) > 1: - raise ValueError("The initializer token must be a single token.") + raise ValueError( + "The initializer token must be a single token.") initializer_token_id = token_ids[0] token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] @@ -117,7 +102,8 @@ def get_models( ) -def text2img_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder): +def text2img_dataloader(train_dataset, train_batch_size, + tokenizer, vae, text_encoder): def collate_fn(examples): input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -129,7 +115,8 @@ def collate_fn(examples): pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + pixel_values = pixel_values.to( + memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( {"input_ids": input_ids}, @@ -175,7 +162,8 @@ def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype): noisy_latents = scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = text_encoder(batch["input_ids"].to(text_encoder.device))[0] + encoder_hidden_states = text_encoder( + batch["input_ids"].to(text_encoder.device))[0] model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -184,7 +172,8 @@ def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype): elif scheduler.config.prediction_type == "v_prediction": target = scheduler.get_velocity(latents, noise, timesteps) else: - raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}") + raise ValueError( + f"Unknown prediction type {scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") return loss @@ -219,7 +208,13 @@ def train_inversion( text_encoder.train() for batch in dataloader: - loss = loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype) + loss = loss_step( + batch, + unet, + vae, + text_encoder, + scheduler, + weight_dtype) loss.backward() optimizer.step() progress_bar.update(1) @@ -238,7 +233,8 @@ def train_inversion( text_encoder=text_encoder, placeholder_token_ids=placeholder_token_ids, placeholder_tokens=placeholder_tokens, - save_path=os.path.join(save_path, f"step_inv_{global_step}.pt"), + save_path=os.path.join( + save_path, f"step_inv_{global_step}.pt"), save_lora=False, ) @@ -273,10 +269,17 @@ def perform_tuning( for batch in dataloader: optimizer.zero_grad() - loss = loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype) + loss = loss_step( + batch, + unet, + vae, + text_encoder, + scheduler, + weight_dtype) loss.backward() torch.nn.utils.clip_grad_norm_( - itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0 + itertools.chain( + unet.parameters(), text_encoder.parameters()), 1.0 ) optimizer.step() progress_bar.update(1) @@ -289,10 +292,14 @@ def perform_tuning( text_encoder, placeholder_token_ids=placeholder_token_ids, placeholder_tokens=placeholder_tokens, - save_path=os.path.join(save_path, f"step_{global_step}.pt"), + save_path=os.path.join( + save_path, f"step_{global_step}.pt"), ) moved = ( - torch.tensor(list(itertools.chain(*inspect_lora(unet).values()))) + torch.tensor( + list( + itertools.chain( + *inspect_lora(unet).values()))) .mean() .item() ) @@ -300,7 +307,9 @@ def perform_tuning( print("LORA Unet Moved", moved) moved = ( torch.tensor( - list(itertools.chain(*inspect_lora(text_encoder).values())) + list( + itertools.chain( + *inspect_lora(text_encoder).values())) ) .mean() .item() @@ -508,7 +517,9 @@ def train( ] inspect_lora(text_encoder) - lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora) + lora_optimizers = optim.AdamW( + params_to_optimize, + weight_decay=weight_decay_lora) unet.train() if train_text_encoder: diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py index 6d2d712..fb31630 100644 --- a/lora_diffusion/dataset.py +++ b/lora_diffusion/dataset.py @@ -1,11 +1,10 @@ -from torch.utils.data import Dataset +import random +from pathlib import Path +from typing import Optional -from typing import List, Tuple, Dict, Union, Optional from PIL import Image +from torch.utils.data import Dataset from torchvision import transforms -from pathlib import Path - -import random OBJECT_TEMPLATE = [ "a photo of a {}", @@ -75,7 +74,8 @@ def _shuffle(lis): class PivotalTuningDatasetCapation(Dataset): """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + A dataset to prepare the instance and class images + with the prompts for fine-tuning the model. It pre-processes the images and the tokenizes prompts. """ @@ -89,7 +89,7 @@ def __init__( class_data_root=None, class_prompt=None, size=512, - h_flip=True, + horizontal_flip=True, center_crop=False, color_jitter=False, resize=True, @@ -101,7 +101,7 @@ def __init__( self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): - raise ValueError("Instance images root doesn't exists.") + raise ValueError("Instance images path doesn't exist.") self.instance_images_path = list(Path(instance_data_root).iterdir()) self.num_instance_images = len(self.instance_images_path) @@ -129,11 +129,11 @@ def __init__( ) if resize else transforms.Lambda(lambda x: x), - transforms.ColorJitter(0.2, 0.1) + transforms.ColorJitter(brightness=0.2, contrast=0.1) if color_jitter else transforms.Lambda(lambda x: x), transforms.RandomHorizontalFlip() - if h_flip + if horizontal_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), @@ -158,7 +158,8 @@ def __getitem__(self, index): text = random.choice(self.templates).format(input_tok) else: - text = self.instance_images_path[index % self.num_instance_images].stem + text = self.instance_images_path[index % + self.num_instance_images].stem if self.token_map is not None: for token, value in self.token_map.items(): text = text.replace(token, value) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 0e4e3af..a254cdb 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -1,36 +1,42 @@ -import math from itertools import groupby -from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple, Type, Union -import numpy as np -import PIL import torch import torch.nn as nn -import torch.nn.functional as F class LoraInjectedLinear(nn.Module): - def __init__(self, in_features, out_features, bias=False, r=4): + """ + Injects and initialises lora_rank Linear layers + """ + + def __init__(self, in_features, out_features, bias=False, lora_rank=4): super().__init__() - if r > min(in_features, out_features): + if lora_rank > min(in_features, out_features): raise ValueError( - f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" + f"LoRA rank {lora_rank} must be less or equal than \ + {min(in_features, out_features)}" ) self.linear = nn.Linear(in_features, out_features, bias) - self.lora_down = nn.Linear(in_features, r, bias=False) - self.lora_up = nn.Linear(r, out_features, bias=False) + self.lora_down = nn.Linear(in_features, lora_rank, bias=False) + self.lora_up = nn.Linear(lora_rank, out_features, bias=False) self.scale = 1.0 - nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.normal_(self.lora_down.weight, std=1 / lora_rank) nn.init.zeros_(self.lora_up.weight) def forward(self, input): - return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale + return self.linear(input) + \ + self.lora_up(self.lora_down(input)) * self.scale +# Default layers to replace in UNet UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} + +# Default layers to replace in Text Encoder TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"} DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE @@ -46,18 +52,20 @@ def _find_children( Returns all matching modules, along with the parent of those moduless and the names they are referenced by. """ - # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + # For each target find every linear_class module that isn't a child of a + # LoraInjectedLinear for parent in model.modules(): for name, module in parent.named_children(): if any([isinstance(module, _class) for _class in search_class]): yield parent, name, module -def _find_modules_v2( +def _find_modules( model, ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE, search_class: List[Type[nn.Module]] = [nn.Linear], - exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [ + LoraInjectedLinear], ): """ Find all modules of a certain class (or union of classes) that are direct or @@ -74,44 +82,27 @@ def _find_modules_v2( if module.__class__.__name__ in ancestor_class ) - # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + # For each target find every linear_class module that isn't a child of a + # LoraInjectedLinear for ancestor in ancestors: for fullname, module in ancestor.named_modules(): if any([isinstance(module, _class) for _class in search_class]): - # Find the direct parent if this is a descendant, not a child, of target + # Find the direct parent if this is a descendant, not a child, + # of target *path, name = fullname.split(".") parent = ancestor while path: parent = parent.get_submodule(path.pop(0)) # Skip this linear if it's a child of a LoraInjectedLinear if exclude_children_of and any( - [isinstance(parent, _class) for _class in exclude_children_of] + [isinstance(parent, _class) + for _class in exclude_children_of] ): continue # Otherwise, yield it yield parent, name, module -def _find_modules_old( - model, - ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE, - search_class: List[Type[nn.Module]] = [nn.Linear], - exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear], -): - ret = [] - for _module in model.modules(): - if _module.__class__.__name__ in ancestor_class: - - for name, _child_module in _module.named_modules(): - if _child_module.__class__ in search_class: - ret.append((_module, name, _child_module)) - print(ret) - return ret - - -_find_modules = _find_modules_v2 - - def inject_trainable_lora( model: nn.Module, target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, @@ -125,7 +116,7 @@ def inject_trainable_lora( require_grad_params = [] names = [] - if loras != None: + if loras is not None: loras = torch.load(loras) for _module, name, _child_module in _find_modules( @@ -148,9 +139,10 @@ def inject_trainable_lora( _module._modules[name] = _tmp require_grad_params.append(_module._modules[name].lora_up.parameters()) - require_grad_params.append(_module._modules[name].lora_down.parameters()) + require_grad_params.append( + _module._modules[name].lora_down.parameters()) - if loras != None: + if loras is not None: _module._modules[name].lora_up.weight = loras.pop(0) _module._modules[name].lora_down.weight = loras.pop(0) @@ -178,7 +170,7 @@ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE): def save_lora_weight( model, - path="./lora.pt", + path: Union[str, Path] = "./lora.pt", target_replace_module=DEFAULT_TARGET_REPLACE, ): weights = [] @@ -237,7 +229,7 @@ def save_safeloras( def convert_loras_to_safeloras( modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, - outpath="./lora.safetensors", + outpath: Union[str, Path] = "./lora.safetensors", ): """ Converts the Lora from multiple pytorch .pt files into a single safetensor file. @@ -293,7 +285,8 @@ def parse_safeloras( metadata = safeloras.metadata() - get_name = lambda k: k.split(":")[0] + def get_name(k): + return k.split(":")[0] keys = list(safeloras.keys()) keys.sort(key=get_name) @@ -302,7 +295,8 @@ def parse_safeloras( # Extract the targets target = json.loads(metadata[name]) - # Build the result lists - Python needs us to preallocate lists to insert into them + # Build the result lists - Python needs us to preallocate lists to + # insert into them module_keys = list(module_keys) ranks = [None] * (len(module_keys) // 2) weights = [None] * len(module_keys) @@ -325,7 +319,9 @@ def parse_safeloras( def load_safeloras(path, device="cpu"): - + """ + Load a LoRA using safetensors + """ from safetensors.torch import safe_open safeloras = safe_open(path, framework="pt", device=device) @@ -426,7 +422,8 @@ def monkeypatch_or_replace_lora( r: Union[int, List[int]] = 4, ): for _module, name, _child_module in _find_modules( - model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear] + model, target_replace_module, search_class=[ + nn.Linear, LoraInjectedLinear] ): _source = ( _child_module.linear @@ -483,7 +480,10 @@ def monkeypatch_remove_lora(model): _source = _child_module.linear weight, bias = _source.weight, _source.bias - _tmp = nn.Linear(_source.in_features, _source.out_features, bias is not None) + _tmp = nn.Linear( + _source.in_features, + _source.out_features, + bias is not None) _tmp.weight = weight if bias is not None: @@ -525,18 +525,18 @@ def tune_lora_scale(model, alpha: float = 1.0): _module.scale = alpha -def _text_lora_path(path: str) -> str: +def _text_lora_path(path: Union[str, Path]) -> str: assert path.endswith(".pt"), "Only .pt files are supported" return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) -def _ti_lora_path(path: str) -> str: +def _ti_lora_path(path: Union[str, Path]) -> str: assert path.endswith(".pt"), "Only .pt files are supported" return ".".join(path.split(".")[:-1] + ["ti", "pt"]) def load_learned_embed_in_clip( - learned_embeds_path, + learned_embeds_path: Union[str, Path], text_encoder, tokenizer, token: Union[str, List[str]] = None, @@ -585,7 +585,7 @@ def load_learned_embed_in_clip( def patch_pipe( pipe, - unet_path, + unet_path: Union[str, Path], token: Optional[str] = None, r: int = 4, patch_unet=True, @@ -652,24 +652,24 @@ def save_all( text_encoder, placeholder_token_ids, placeholder_tokens, - save_path, + save_path: Union[str, Path], save_lora=True, target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, target_replace_module_unet=DEFAULT_TARGET_REPLACE, ): - # save ti + # save textual inversion ti_path = _ti_lora_path(save_path) learned_embeds_dict = {} for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] print( - f"Current Learned Embeddings for {tok}:, id {tok_id} ", learned_embeds[:4] + f"Current Learned Embeddings for {tok}:, id {tok_id} {learned_embeds[:4]}" ) learned_embeds_dict[tok] = learned_embeds.detach().cpu() torch.save(learned_embeds_dict, ti_path) - print("Ti saved to ", ti_path) + print(f"Textual Inversion saved to {ti_path}") # save text encoder if save_lora: @@ -677,11 +677,11 @@ def save_all( save_lora_weight( unet, save_path, target_replace_module=target_replace_module_unet ) - print("Unet saved to ", save_path) + print(f"Unet saved to {save_path}") save_lora_weight( text_encoder, _text_lora_path(save_path), target_replace_module=target_replace_module_text, ) - print("Text Encoder saved to ", _text_lora_path(save_path)) + print(f"Text Encoder saved to {_text_lora_path(save_path)}") diff --git a/lora_diffusion/to_ckpt_v2.py b/lora_diffusion/to_ckpt_v2.py index 15f3947..b99b9bb 100644 --- a/lora_diffusion/to_ckpt_v2.py +++ b/lora_diffusion/to_ckpt_v2.py @@ -3,12 +3,10 @@ # *Only* converts the UNet, VAE, and Text Encoder. # Does not convert optimizer state or any other thing. # Written by jachiam -import argparse import os.path as osp import torch - # =================# # UNet Conversion # # =================# @@ -47,13 +45,15 @@ # loop over resnets/attentions for downblocks hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + unet_conversion_map_layer.append( + (sd_down_res_prefix, hf_down_res_prefix)) if i < 3: # no attention layers in down_blocks.3 hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + unet_conversion_map_layer.append( + (sd_down_atn_prefix, hf_down_atn_prefix)) for j in range(3): # loop over resnets/attentions for upblocks @@ -65,18 +65,21 @@ # no attention layers in up_blocks.0 hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + unet_conversion_map_layer.append( + (sd_up_atn_prefix, hf_up_atn_prefix)) if i < 3: # no downsample in down_blocks.3 hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + unet_conversion_map_layer.append( + (sd_downsample_prefix, hf_downsample_prefix)) # no upsample in up_blocks.3 hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + unet_conversion_map_layer.append( + (sd_upsample_prefix, hf_upsample_prefix)) hf_mid_atn_prefix = "mid_block.attentions.0." sd_mid_atn_prefix = "middle_block.1." @@ -215,7 +218,9 @@ def convert_to_ckpt(model_path, checkpoint_path, as_half): # Convert the VAE model vae_state_dict = torch.load(vae_path, map_location="cpu") vae_state_dict = convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + vae_state_dict = { + "first_stage_model." + k: v for k, + v in vae_state_dict.items()} # Convert the text encoder model text_enc_dict = torch.load(text_enc_path, map_location="cpu") diff --git a/setup.py b/setup.py index be6859b..80c98ac 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,8 @@ name="lora_diffusion", py_modules=["lora_diffusion"], version="0.0.7", - description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.", + description="Low Rank Adaptation for Diffusion Models. \ + Works with Stable Diffusion out-of-the-box.", author="Simo Ryu", packages=find_packages(), entry_points={ diff --git a/train_lora_dreambooth.py b/train_lora_dreambooth.py index 362fcb3..64d005e 100644 --- a/train_lora_dreambooth.py +++ b/train_lora_dreambooth.py @@ -3,18 +3,15 @@ import argparse import hashlib +import inspect import itertools import math import os -import inspect from pathlib import Path -from typing import Optional import torch import torch.nn.functional as F import torch.utils.checkpoint - - from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed @@ -25,8 +22,9 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from huggingface_hub import HfFolder, Repository, whoami - +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer @@ -37,15 +35,6 @@ save_safeloras, ) -from torch.utils.data import Dataset -from PIL import Image -from torchvision import transforms - -from pathlib import Path - -import random -import re - class DreamBoothDataset(Dataset): """ @@ -182,7 +171,8 @@ def __getitem__(self, index): def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = argparse.ArgumentParser( + description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -462,7 +452,8 @@ def parse_args(input_args=None): if args.with_prior_preservation: if args.class_data_dir is None: - raise ValueError("You must specify a data directory for class images.") + raise ValueError( + "You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: @@ -490,7 +481,8 @@ def main(args): # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + # TODO (patil-suraj): Remove this check when gradient accumulation with + # two models is enabled in accelerate. if ( args.train_text_encoder and args.gradient_accumulation_steps > 1 @@ -609,7 +601,9 @@ def main(args): for _up, _down in extract_lora_ups_down( text_encoder, target_replace_module=["CLIPAttention"] ): - print("Before training: text encoder First Layer lora up", _up.weight.data) + print( + "Before training: text encoder First Layer lora up", + _up.weight.data) print( "Before training: text encoder First Layer lora down", _down.weight.data ) @@ -628,7 +622,8 @@ def main(args): * accelerator.num_processes ) - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB + # GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -649,7 +644,8 @@ def main(args): params_to_optimize = ( [ - {"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate}, + {"params": itertools.chain( + *unet_lora_params), "lr": args.learning_rate}, { "params": itertools.chain(*text_encoder_lora_params), "lr": text_lr, @@ -693,7 +689,8 @@ def collate_fn(examples): pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + pixel_values = pixel_values.to( + memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( {"input_ids": input_ids}, @@ -755,19 +752,23 @@ def collate_fn(examples): # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. + # as these models are only used for inference, keeping weights in full + # precision is not required. vae.to(accelerator.device, dtype=weight_dtype) if not args.train_text_encoder: text_encoder.to(accelerator.device, dtype=weight_dtype) - # We need to recalculate our total training steps as the size of the training dataloader may have changed. + # We need to recalculate our total training steps as the size of the + # training dataloader may have changed. num_update_steps_per_epoch = math.ceil( len(train_dataloader) / args.gradient_accumulation_steps ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + args.num_train_epochs = math.ceil( + args.max_train_steps / + num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. @@ -789,7 +790,8 @@ def collate_fn(examples): print( f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" ) - print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + print( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") print(f" Total optimization steps = {args.max_train_steps}") # Only show the progress bar once on each machine. progress_bar = tqdm( @@ -825,32 +827,42 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise( + latents, noise, timesteps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity( + latents, noise, timesteps) else: raise ValueError( f"Unknown prediction type {noise_scheduler.config.prediction_type}" ) if args.with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + # Chunk the noise and model_pred into two parts and compute the + # loss on each part separately. + model_pred, model_pred_prior = torch.chunk( + model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss loss = ( - F.mse_loss(model_pred.float(), target.float(), reduction="none") + F.mse_loss( + model_pred.float(), + target.float(), + reduction="none") .mean([1, 2, 3]) .mean() ) @@ -863,12 +875,17 @@ def collate_fn(examples): # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss( + model_pred.float(), + target.float(), + reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) + itertools.chain( + unet.parameters(), + text_encoder.parameters()) if args.train_text_encoder else unet.parameters() ) @@ -880,7 +897,8 @@ def collate_fn(examples): global_step += 1 - # Checks if the accelerator has performed an optimization step behind the scenes + # Checks if the accelerator has performed an optimization step + # behind the scenes if accelerator.sync_gradients: if args.save_steps and global_step - last_save >= args.save_steps: if accelerator.is_main_process: @@ -911,7 +929,8 @@ def collate_fn(examples): f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt" ) filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt" - print(f"save weights {filename_unet}, {filename_text_encoder}") + print( + f"save weights {filename_unet}, {filename_text_encoder}") save_lora_weight(pipeline.unet, filename_unet) if args.train_text_encoder: save_lora_weight( @@ -947,7 +966,9 @@ def collate_fn(examples): last_save = global_step - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + logs = { + "loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) diff --git a/train_lora_pt_caption.py b/train_lora_pt_caption.py index ae7aafc..9a5e778 100644 --- a/train_lora_pt_caption.py +++ b/train_lora_pt_caption.py @@ -3,19 +3,16 @@ import argparse import hashlib +import inspect import itertools import math import os -import inspect -from pathlib import Path -from typing import Optional import random +from pathlib import Path import torch import torch.nn.functional as F import torch.utils.checkpoint - - from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed @@ -26,26 +23,18 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from huggingface_hub import HfFolder, Repository, whoami - +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from lora_diffusion import ( + extract_lora_ups_down, inject_trainable_lora, save_lora_weight, - extract_lora_ups_down, ) -from torch.utils.data import Dataset -from PIL import Image -from torchvision import transforms - -from pathlib import Path - -import random -import re - def _randomset(lis): ret = [] @@ -181,7 +170,8 @@ def __getitem__(self, index): logger = get_logger(__name__) -def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path): +def save_progress(text_encoder, placeholder_token_id, + accelerator, args, save_path): logger.info("Saving embeddings") learned_embeds = ( accelerator.unwrap_model(text_encoder) @@ -190,12 +180,14 @@ def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_pa ) print("Current Learned Embeddings: ", learned_embeds[:4]) print("saved to ", save_path) - learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} + learned_embeds_dict = { + args.placeholder_token: learned_embeds.detach().cpu()} torch.save(learned_embeds_dict, save_path) def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = argparse.ArgumentParser( + description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -494,7 +486,8 @@ def parse_args(input_args=None): if args.with_prior_preservation: if args.class_data_dir is None: - raise ValueError("You must specify a data directory for class images.") + raise ValueError( + "You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: @@ -532,7 +525,8 @@ def main(args): # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + # TODO (patil-suraj): Remove this check when gradient accumulation with + # two models is enabled in accelerate. if ( args.train_text_encoder and args.gradient_accumulation_steps > 1 @@ -622,13 +616,16 @@ def main(args): ) # Convert the initializer_token, placeholder_token to ids - token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + token_ids = tokenizer.encode( + args.initializer_token, + add_special_tokens=False) # Check if initializer_token is a single token or a sequence of tokens if len(token_ids) > 1: raise ValueError("The initializer token must be a single token.") initializer_token_id = token_ids[0] - placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + placeholder_token_id = tokenizer.convert_tokens_to_ids( + args.placeholder_token) # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained( @@ -638,7 +635,8 @@ def main(args): ) text_encoder.resize_token_embeddings(len(tokenizer)) - # Initialise the newly added placeholder token with the embeddings of the initializer token + # Initialise the newly added placeholder token with the embeddings of the + # initializer token token_embeds = text_encoder.get_input_embeddings().weight.data token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] @@ -677,8 +675,12 @@ def main(args): for _up, _down in extract_lora_ups_down( text_encoder, target_replace_module=["CLIPAttention"] ): - print("Before training: text encoder First Layer lora up", _up.weight.data) - print("Before training: text encoder First Layer lora down", _down.weight.data) + print( + "Before training: text encoder First Layer lora up", + _up.weight.data) + print( + "Before training: text encoder First Layer lora down", + _down.weight.data) break if args.gradient_checkpointing: @@ -694,7 +696,8 @@ def main(args): * accelerator.num_processes ) - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB + # GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -708,7 +711,8 @@ def main(args): optimizer_class = torch.optim.AdamW params_to_optimize = [ - {"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate}, + {"params": itertools.chain(*unet_lora_params), + "lr": args.learning_rate}, { "params": itertools.chain(*text_encoder_lora_params), "lr": args.learning_rate_text, @@ -762,7 +766,8 @@ def collate_fn(examples): pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + pixel_values = pixel_values.to( + memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( {"input_ids": input_ids}, @@ -814,17 +819,21 @@ def collate_fn(examples): weight_dtype = torch.bfloat16 # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. + # as these models are only used for inference, keeping weights in full + # precision is not required. vae.to(accelerator.device, dtype=weight_dtype) - # We need to recalculate our total training steps as the size of the training dataloader may have changed. + # We need to recalculate our total training steps as the size of the + # training dataloader may have changed. num_update_steps_per_epoch = math.ceil( len(train_dataloader) / args.gradient_accumulation_steps ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + args.num_train_epochs = math.ceil( + args.max_train_steps / + num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. @@ -846,7 +855,8 @@ def collate_fn(examples): print( f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" ) - print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + print( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") print(f" Total optimization steps = {args.max_train_steps}") # Only show the progress bar once on each machine. progress_bar = tqdm( @@ -893,32 +903,42 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise( + latents, noise, timesteps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity( + latents, noise, timesteps) else: raise ValueError( f"Unknown prediction type {noise_scheduler.config.prediction_type}" ) if args.with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + # Chunk the noise and model_pred into two parts and compute the + # loss on each part separately. + model_pred, model_pred_prior = torch.chunk( + model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss loss = ( - F.mse_loss(model_pred.float(), target.float(), reduction="none") + F.mse_loss( + model_pred.float(), + target.float(), + reduction="none") .mean([1, 2, 3]) .mean() ) @@ -931,12 +951,17 @@ def collate_fn(examples): # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss( + model_pred.float(), + target.float(), + reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) + itertools.chain( + unet.parameters(), + text_encoder.parameters()) if args.train_text_encoder else unet.parameters() ) @@ -947,8 +972,10 @@ def collate_fn(examples): progress_bar.update(1) optimizer.zero_grad() - # Let's make sure we don't update any embedding weights besides the newly added token - index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id + # Let's make sure we don't update any embedding weights besides the + # newly added token + index_no_updates = torch.arange( + len(tokenizer)) != placeholder_token_id with torch.no_grad(): text_encoder.get_input_embeddings().weight[ index_no_updates @@ -956,7 +983,8 @@ def collate_fn(examples): global_step += 1 - # Checks if the accelerator has performed an optimization step behind the scenes + # Checks if the accelerator has performed an optimization step + # behind the scenes if accelerator.sync_gradients: if args.save_steps and global_step - last_save >= args.save_steps: if accelerator.is_main_process: @@ -987,7 +1015,8 @@ def collate_fn(examples): f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt" ) filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt" - print(f"save weights {filename_unet}, {filename_text_encoder}") + print( + f"save weights {filename_unet}, {filename_text_encoder}") save_lora_weight(pipeline.unet, filename_unet) save_lora_weight( diff --git a/train_lora_w_ti.py b/train_lora_w_ti.py index 868dcff..a3d0ed2 100644 --- a/train_lora_w_ti.py +++ b/train_lora_w_ti.py @@ -3,19 +3,16 @@ import argparse import hashlib +import inspect import itertools import math import os -import inspect -from pathlib import Path -from typing import Optional import random +from pathlib import Path import torch import torch.nn.functional as F import torch.utils.checkpoint - - from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed @@ -26,27 +23,18 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from huggingface_hub import HfFolder, Repository, whoami - +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from lora_diffusion import ( + extract_lora_ups_down, inject_trainable_lora, save_lora_weight, - extract_lora_ups_down, ) -from torch.utils.data import Dataset -from PIL import Image -from torchvision import transforms - -from pathlib import Path - -import random -import re - - imagenet_templates_small = [ "a photo of a {}", "a rendering of a {}", @@ -267,7 +255,8 @@ def __getitem__(self, index): logger = get_logger(__name__) -def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path): +def save_progress(text_encoder, placeholder_token_id, + accelerator, args, save_path): logger.info("Saving embeddings") learned_embeds = ( accelerator.unwrap_model(text_encoder) @@ -276,12 +265,14 @@ def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_pa ) print("Current Learned Embeddings: ", learned_embeds[:4]) print("saved to ", save_path) - learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} + learned_embeds_dict = { + args.placeholder_token: learned_embeds.detach().cpu()} torch.save(learned_embeds_dict, save_path) def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = argparse.ArgumentParser( + description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -587,7 +578,8 @@ def parse_args(input_args=None): if args.with_prior_preservation: if args.class_data_dir is None: - raise ValueError("You must specify a data directory for class images.") + raise ValueError( + "You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: @@ -625,7 +617,8 @@ def main(args): # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + # TODO (patil-suraj): Remove this check when gradient accumulation with + # two models is enabled in accelerate. if ( args.train_text_encoder and args.gradient_accumulation_steps > 1 @@ -715,13 +708,16 @@ def main(args): ) # Convert the initializer_token, placeholder_token to ids - token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + token_ids = tokenizer.encode( + args.initializer_token, + add_special_tokens=False) # Check if initializer_token is a single token or a sequence of tokens if len(token_ids) > 1: raise ValueError("The initializer token must be a single token.") initializer_token_id = token_ids[0] - placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + placeholder_token_id = tokenizer.convert_tokens_to_ids( + args.placeholder_token) # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained( @@ -731,7 +727,8 @@ def main(args): ) text_encoder.resize_token_embeddings(len(tokenizer)) - # Initialise the newly added placeholder token with the embeddings of the initializer token + # Initialise the newly added placeholder token with the embeddings of the + # initializer token token_embeds = text_encoder.get_input_embeddings().weight.data token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] @@ -770,8 +767,12 @@ def main(args): for _up, _down in extract_lora_ups_down( text_encoder, target_replace_module=["CLIPAttention"] ): - print("Before training: text encoder First Layer lora up", _up.weight.data) - print("Before training: text encoder First Layer lora down", _down.weight.data) + print( + "Before training: text encoder First Layer lora up", + _up.weight.data) + print( + "Before training: text encoder First Layer lora down", + _down.weight.data) break if args.gradient_checkpointing: @@ -787,7 +788,8 @@ def main(args): * accelerator.num_processes ) - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB + # GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -801,7 +803,8 @@ def main(args): optimizer_class = torch.optim.AdamW params_to_optimize = [ - {"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate}, + {"params": itertools.chain(*unet_lora_params), + "lr": args.learning_rate}, { "params": itertools.chain(*text_encoder_lora_params), "lr": args.learning_rate_text, @@ -856,7 +859,8 @@ def collate_fn(examples): pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + pixel_values = pixel_values.to( + memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( {"input_ids": input_ids}, @@ -908,17 +912,21 @@ def collate_fn(examples): weight_dtype = torch.bfloat16 # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. + # as these models are only used for inference, keeping weights in full + # precision is not required. vae.to(accelerator.device, dtype=weight_dtype) - # We need to recalculate our total training steps as the size of the training dataloader may have changed. + # We need to recalculate our total training steps as the size of the + # training dataloader may have changed. num_update_steps_per_epoch = math.ceil( len(train_dataloader) / args.gradient_accumulation_steps ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + args.num_train_epochs = math.ceil( + args.max_train_steps / + num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. @@ -940,7 +948,8 @@ def collate_fn(examples): print( f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" ) - print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + print( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") print(f" Total optimization steps = {args.max_train_steps}") # Only show the progress bar once on each machine. progress_bar = tqdm( @@ -986,32 +995,42 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise( + latents, noise, timesteps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity( + latents, noise, timesteps) else: raise ValueError( f"Unknown prediction type {noise_scheduler.config.prediction_type}" ) if args.with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + # Chunk the noise and model_pred into two parts and compute the + # loss on each part separately. + model_pred, model_pred_prior = torch.chunk( + model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss loss = ( - F.mse_loss(model_pred.float(), target.float(), reduction="none") + F.mse_loss( + model_pred.float(), + target.float(), + reduction="none") .mean([1, 2, 3]) .mean() ) @@ -1024,12 +1043,17 @@ def collate_fn(examples): # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss( + model_pred.float(), + target.float(), + reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) + itertools.chain( + unet.parameters(), + text_encoder.parameters()) if args.train_text_encoder else unet.parameters() ) @@ -1040,8 +1064,10 @@ def collate_fn(examples): progress_bar.update(1) optimizer.zero_grad() - # Let's make sure we don't update any embedding weights besides the newly added token - index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id + # Let's make sure we don't update any embedding weights besides the + # newly added token + index_no_updates = torch.arange( + len(tokenizer)) != placeholder_token_id with torch.no_grad(): text_encoder.get_input_embeddings().weight[ index_no_updates @@ -1049,7 +1075,8 @@ def collate_fn(examples): global_step += 1 - # Checks if the accelerator has performed an optimization step behind the scenes + # Checks if the accelerator has performed an optimization step + # behind the scenes if accelerator.sync_gradients: if args.save_steps and global_step - last_save >= args.save_steps: if accelerator.is_main_process: @@ -1080,7 +1107,8 @@ def collate_fn(examples): f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt" ) filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt" - print(f"save weights {filename_unet}, {filename_text_encoder}") + print( + f"save weights {filename_unet}, {filename_text_encoder}") save_lora_weight(pipeline.unet, filename_unet) save_lora_weight(