Merge pull request #880 from ExponentialML/features/lora-update-v0-1-2
Update LoRA with Dropout & Conv2d Support
d8ahazard authored Jan 31, 2023
2 parents 801f0cc + c8cf6e7 commit b95e67c
Showing 9 changed files with 740 additions and 296 deletions.
4 changes: 3 additions & 1 deletion dreambooth/dataclasses/db_config.py
@@ -46,7 +46,8 @@ class DreamboothConfig(BaseModel):
lifetime_revision: int = 0
lora_learning_rate: float = 1e-4
lora_model_name: str = ""
lora_rank: int = 4
lora_unet_rank: int = 4
lora_txt_rank: int = 4
lora_txt_learning_rate: float = 5e-5
lora_txt_weight: float = 1.0
lora_weight: float = 1.0
@@ -99,6 +100,7 @@ class DreamboothConfig(BaseModel):
use_concepts: bool = False
use_ema: bool = True
use_lora: bool = False
use_lora_extended: bool = False
use_subdir: bool = False
v2: bool = False

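Note: the single lora_rank field is replaced by separate lora_unet_rank and lora_txt_rank fields, and a use_lora_extended flag is added. A hypothetical migration helper (not part of this commit; name and placement are illustrative) showing how a previously saved config that still carries the old lora_rank key could be mapped onto the new fields:

def migrate_legacy_lora_rank(cfg: dict) -> dict:
    # If an older saved config still has the single "lora_rank" key,
    # apply that value to both of the new per-network rank fields.
    legacy_rank = cfg.pop("lora_rank", None)
    if legacy_rank is not None:
        cfg.setdefault("lora_unet_rank", legacy_rank)
        cfg.setdefault("lora_txt_rank", legacy_rank)
    return cfg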
36 changes: 24 additions & 12 deletions dreambooth/diff_to_sd.py
@@ -21,7 +21,7 @@
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, reload_system_models
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
from extensions.sd_dreambooth_extension.lora_diffusion.lora import weight_apply_lora
from extensions.sd_dreambooth_extension.lora_diffusion.lora import merge_loras_to_pipe, get_target_module

unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
@@ -393,21 +393,33 @@ def compile_checkpoint(model_name: str, lora_path: str=None, reload_models: bool
model_dir = os.path.dirname(cmd_lora_models_path) if cmd_lora_models_path else shared.models_path
lora_path = os.path.join(model_dir, "lora", lora_path)
printi(f"Loading lora from {lora_path}", log=log)

if os.path.exists(lora_path):
lora_txt = lora_path.replace(".pt", "_txt.pt")
checkpoint_path = checkpoint_path.replace(checkpoint_ext, f"_lora{checkpoint_ext}")
printi(f"Applying lora weight of alpha: {config.lora_weight} to unet...", log=log)
weight_apply_lora(loaded_pipeline.unet, torch.load(lora_path), alpha=config.lora_weight)
printi("Saving lora unet...", log=log)

printi(f"Saving UNET Lora and applying lora alpha of {config.lora_weight}", log=log)
if os.path.exists(lora_txt): printi(f"Saving Text Lora and applying lora alpha of {config.lora_txt_weight}", log=log)
merge_loras_to_pipe(
loaded_pipeline,
lora_path,
lora_alpha=config.lora_weight,
lora_txt_alpha=config.lora_txt_weight,
r=config.lora_unet_rank,
r_txt=config.lora_txt_rank,
unet_target_module=get_target_module("module", config.use_lora_extended)
)


loaded_pipeline.unet.save_pretrained(os.path.join(config.pretrained_model_name_or_path, "unet_lora"))
unet_path = osp.join(config.pretrained_model_name_or_path, "unet_lora", "diffusion_pytorch_model.bin")
lora_txt = lora_path.replace(".pt", "_txt.pt")
if os.path.exists(lora_txt):
printi(f"Applying lora weight of alpha: {config.lora_txt_weight} to text encoder...", log=log)
weight_apply_lora(loaded_pipeline.text_encoder, torch.load(lora_txt), target_replace_module=["CLIPAttention"], alpha=config.lora_weight)
printi("Saving lora text encoder...", log=log)
loaded_pipeline.text_encoder.save_pretrained(
os.path.join(config.pretrained_model_name_or_path, "text_encoder_lora"))
text_enc_path = osp.join(config.pretrained_model_name_or_path, "text_encoder_lora", "pytorch_model.bin")

if os.path.exists(lora_txt):
loaded_pipeline.text_encoder.save_pretrained(
os.path.join(config.pretrained_model_name_or_path, "text_encoder_lora")
)
text_enc_path = osp.join(config.pretrained_model_name_or_path, "text_encoder_lora", "pytorch_model.bin")

del loaded_pipeline

# Convert the UNet model
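Note: compile_checkpoint no longer calls weight_apply_lora separately for the UNet and text encoder; a single merge_loras_to_pipe call folds both LoRA files into the loaded pipeline using the per-network ranks and the target modules chosen by get_target_module. The body of merge_loras_to_pipe is not shown in this diff; as a rough illustration of the underlying idea only, folding a low-rank update into one frozen linear weight typically looks like the sketch below (names are hypothetical, and the real function presumably also walks the target modules and handles Conv2d layers when extended LoRA is enabled).

import torch

def merge_lora_into_linear(weight, lora_up, lora_down, alpha):
    # weight: (out_features, in_features) frozen base weight
    # lora_up: (out_features, r), lora_down: (r, in_features)
    # Fold the low-rank update into the base weight: W' = W + alpha * (up @ down)
    return weight + alpha * (lora_up @ lora_down).to(weight.dtype)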
26 changes: 17 additions & 9 deletions dreambooth/train_dreambooth.py
@@ -40,7 +40,7 @@
from extensions.sd_dreambooth_extension.dreambooth.dataset.sample_dataset import SampleDataset
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup, parse_logs, printm
from extensions.sd_dreambooth_extension.dreambooth.xattention import optim_to
from extensions.sd_dreambooth_extension.lora_diffusion.lora import save_lora_weight, inject_trainable_lora
from extensions.sd_dreambooth_extension.lora_diffusion.lora import save_lora_weight, TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module

logger = logging.getLogger(__name__)
# define a Handler which writes DEBUG messages or higher to the sys.stderr
@@ -234,18 +234,23 @@ def create_vae():
lora_path = None
lora_txt = None

unet_lora_params, _ = inject_trainable_lora(
injectable_lora = get_target_module("injection", args.use_lora_extended)
target_module = get_target_module("module", args.use_lora_extended)

unet_lora_params, _ = injectable_lora(
unet,
r=args.lora_rank,
loras=lora_path
r=args.lora_unet_rank,
loras=lora_path,
target_replace_module=target_module
)

if stop_text_percentage != 0:
text_encoder.requires_grad_(False)
text_encoder_lora_params, _ = inject_trainable_lora(
inject_trainable_txt_lora = get_target_module("injection", False)
text_encoder_lora_params, _ = inject_trainable_txt_lora(
text_encoder,
target_replace_module=["CLIPAttention"],
r=args.lora_rank,
target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
r=args.lora_txt_rank,
loras=lora_txt
)
printm("Lora loaded")
@@ -497,6 +502,7 @@ def collate_fn(examples):
print(f" UNET: {args.train_unet}")
print(f" Freeze CLIP Normalization Layers: {args.freeze_clip_normalization}")
print(f" LR: {args.learning_rate}")
if args.use_lora_extended: print(f" LoRA Extended: {args.use_lora_extended}")
if args.use_lora and stop_text_percentage > 0: print(f" LoRA Text Encoder LR: {args.lora_txt_learning_rate}")
print(f" V2: {args.v2}")

@@ -641,12 +647,14 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
out_file = os.path.join(model_dir, "lora")
os.makedirs(out_file, exist_ok=True)
out_file = os.path.join(out_file, f"{lora_model_name}_{args.revision}.pt")
save_lora_weight(s_pipeline.unet, out_file)

target_module = get_target_module("module", args.use_lora_extended)
save_lora_weight(s_pipeline.unet, out_file, target_module)
if stop_text_percentage != 0:
out_txt = out_file.replace(".pt", "_txt.pt")
save_lora_weight(s_pipeline.text_encoder,
out_txt,
target_replace_module=["CLIPAttention"],
target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
)
pbar.update()

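Note: injection during training now routes through get_target_module: the UNet uses either the standard attention targets or the extended set (which also covers resnet/Conv2d blocks) depending on use_lora_extended, while the text encoder keeps the non-extended injection path and TEXT_ENCODER_DEFAULT_TARGET_REPLACE. get_target_module itself is defined in lora_diffusion/lora.py and is not shown here; a hypothetical sketch of the dispatch pattern implied by these call sites (module names and return values are illustrative):

ATTENTION_TARGETS = {"CrossAttention", "Attention", "GEGLU"}   # standard LoRA targets
EXTENDED_TARGETS = ATTENTION_TARGETS | {"ResnetBlock2D"}       # adds resnet/Conv2d blocks

def get_target_module_sketch(kind, use_extended):
    if kind == "module":
        # which module classes receive LoRA adapters (used for injection and merging)
        return EXTENDED_TARGETS if use_extended else ATTENTION_TARGETS
    if kind == "injection":
        # which injection routine to use; the extended variant also wraps Conv2d layers
        return "inject_trainable_lora_extended" if use_extended else "inject_trainable_lora"
    raise ValueError(f"unknown kind: {kind}")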
2 changes: 1 addition & 1 deletion dreambooth/utils/model_utils.py
@@ -26,7 +26,7 @@ def get_db_models():

def get_lora_models():
model_dir = shared.lora_models_path
out_dir = os.path.join(model_dir, "lora")
out_dir = model_dir
output = [""]
if os.path.exists(out_dir):
dirs = os.listdir(out_dir)
17 changes: 14 additions & 3 deletions helpers/image_builder.py
@@ -16,13 +16,22 @@
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_checkpoint_match, reload_system_models, \
enable_safe_unpickle
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
from extensions.sd_dreambooth_extension.lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale
from extensions.sd_dreambooth_extension.lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, \
get_target_module
from modules import sd_models
from modules.processing import StableDiffusionProcessingTxt2Img


class ImageBuilder:
def __init__(self, config: DreamboothConfig, use_txt2img: bool, lora_model: str = None, batch_size: int = 1, accelerator: Accelerator = None):
def __init__(
self, config: DreamboothConfig,
use_txt2img: bool,
lora_model: str = None,
batch_size: int = 1,
accelerator: Accelerator = None,
lora_unet_rank: int = 4,
lora_txt_rank: int = 4
):
self.image_pipe = None
self.txt_pipe = None
self.resolution = config.resolution
@@ -89,8 +98,10 @@ def __init__(self, config: DreamboothConfig, use_txt2img: bool, lora_model: str
patch_pipe(
pipe=self.image_pipe,
unet_path=lora_model_path,
unet_target_replace_module=get_target_module("module", config.use_lora_extended),
token="None",
r=config.lora_rank
r=config.lora_unet_rank,
r_txt=config.lora_txt_rank
)

tune_lora_scale(self.image_pipe.unet, config.lora_weight)
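Note: ImageBuilder now accepts separate UNet and text-encoder ranks and forwards the extended/standard target modules into patch_pipe when loading a LoRA for sample generation. A hedged usage sketch (the constructor signature is taken from the diff above; the surrounding values are illustrative):

builder = ImageBuilder(
    config,                               # a DreamboothConfig with the new rank fields
    use_txt2img=False,
    lora_model="my_lora_10000.pt",        # hypothetical LoRA file name
    batch_size=1,
    accelerator=None,
    lora_unet_rank=config.lora_unet_rank,
    lora_txt_rank=config.lora_txt_rank,
)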
4 changes: 3 additions & 1 deletion javascript/dreambooth.js
@@ -202,7 +202,9 @@ let db_titles = {
"Load Settings": "Load last saved training parameters for the model.",
"Log Memory": "Log the current GPU memory usage.",
"Lora Model": "The Lora model to load for continued fine-tuning or checkpoint generation.",
"Lora Rank": "The rank of LoRA models. A low rank stores less information (~2MB file), and a higher rank stores more information (~60MB file). Default is 4 (~4MB file) with good results. Set based on your dataset and complexity before training [Advanced Usage].",
"Use Lora Extended": "Trains the Lora model with resnet layers. This will always improves quality and editability, but leads to bigger files.",
"Lora UNET Rank": "The rank for the Lora UNET (Default 4). Higher values = better quality with large file size. Lower values = sacrifice quality with lower file size. Learning rates work differently at different ranks. Saved loras at high precision (fp32) will lead to larger lora files.",
"Lora Text Encoder Rank": "The rank for the Lora Text Encoder (Default 4). Higher values = better quality with large file size. Lower values = sacrifice quality with lower file size. Learning rates work differently at different ranks. Saved loras at high precision (fp32) will lead to larger lora files.",
"Lora Text Learning Rate": "The learning rate at which to train lora text encoder. Regular learning rate is ignored.",
"Lora Text Weight": "What percentage of the lora weights should be applied to the text encoder when creating a checkpoint.",
"Lora UNET Learning Rate": "The learning rate at which to train lora unet. Regular learning rate is ignored.",
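Note: the new tooltips describe the rank/file-size trade-off in words; the scaling is roughly linear in r because each adapted linear layer of shape (out, in) stores two low-rank factors with r * (in + out) parameters in total. A small back-of-the-envelope sketch (layer sizes are illustrative):

def lora_params_per_linear(in_features, out_features, r):
    # one (out, r) "up" factor plus one (r, in) "down" factor
    return r * (in_features + out_features)

# e.g. a 320x320 projection: rank 4 stores 2560 params, rank 32 stores 20480,
# so an fp32 checkpoint grows roughly 8x for that layer when going from r=4 to r=32.
print(lora_params_per_linear(320, 320, 4))    # 2560
print(lora_params_per_linear(320, 320, 32))   # 20480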