From c8cf6e7412ba9a77f760d8730c633c64ea08e4a7 Mon Sep 17 00:00:00 2001 From: ExponentialML Date: Tue, 31 Jan 2023 11:29:52 -0800 Subject: [PATCH] Fix LoRA merging. --- dreambooth/diff_to_sd.py | 12 ++++++++++-- helpers/image_builder.py | 2 +- lora_diffusion/lora.py | 31 ++++++++++++++++++++++++++----- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/dreambooth/diff_to_sd.py b/dreambooth/diff_to_sd.py index 17661585..d74f25e6 100644 --- a/dreambooth/diff_to_sd.py +++ b/dreambooth/diff_to_sd.py @@ -21,7 +21,7 @@ from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, reload_system_models from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi -from extensions.sd_dreambooth_extension.lora_diffusion.lora import merge_loras_to_pipe +from extensions.sd_dreambooth_extension.lora_diffusion.lora import merge_loras_to_pipe, get_target_module unet_conversion_map = [ # (stable-diffusion, HF Diffusers) @@ -400,7 +400,15 @@ def compile_checkpoint(model_name: str, lora_path: str=None, reload_models: bool printi(f"Saving UNET Lora and applying lora alpha of {config.lora_weight}", log=log) if os.path.exists(lora_txt): printi(f"Saving Text Lora and applying lora alpha of {config.lora_txt_weight}", log=log) - merge_loras_to_pipe(loaded_pipeline, lora_path, lora_alpha=config.lora_weight, lora_txt_alpha=config.lora_txt_weight) + merge_loras_to_pipe( + loaded_pipeline, + lora_path, + lora_alpha=config.lora_weight, + lora_txt_alpha=config.lora_txt_weight, + r=config.lora_unet_rank, + r_txt=config.lora_txt_rank, + unet_target_module=get_target_module("module", config.use_lora_extended) + ) loaded_pipeline.unet.save_pretrained(os.path.join(config.pretrained_model_name_or_path, "unet_lora")) diff --git a/helpers/image_builder.py b/helpers/image_builder.py index a22a51cb..797caaaf 100644 --- a/helpers/image_builder.py +++ b/helpers/image_builder.py @@ -101,7 +101,7 @@ def __init__( unet_target_replace_module=get_target_module("module", config.use_lora_extended), token="None", r=config.lora_unet_rank, - txt_r=config.lora_txt_rank + r_txt=config.lora_txt_rank ) tune_lora_scale(self.image_pipe.unet, config.lora_weight) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 8d6b17e2..e080332d 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -898,7 +898,7 @@ def patch_pipe( maybe_unet_path, token: Optional[str] = None, r: int = 4, - txt_r: int = 4, + r_txt: int = 4, patch_unet=True, patch_text=True, patch_ti=False, @@ -922,7 +922,12 @@ def patch_pipe( disable_safe_unpickle() if patch_unet: print("LoRA : Patching Unet") - monkeypatch_or_replace_lora( + lora_patch = get_target_module( + "patch", + bool(unet_target_replace_module == UNET_EXTENDED_TARGET_REPLACE) + ) + + lora_patch( pipe.unet, torch.load(unet_path), r=r, @@ -935,7 +940,7 @@ def patch_pipe( pipe.text_encoder, torch.load(text_path), target_replace_module=text_target_replace_module, - r=txt_r, + r=r_txt, ) enable_safe_unpickle() if patch_ti: @@ -1050,12 +1055,26 @@ def save_all( save_safeloras_with_embeds(loras, embeds, save_path) -def merge_loras_to_pipe(pipline, lora_path=None, lora_alpha: float = 1, lora_txt_alpha: float = 1): +def merge_loras_to_pipe( + pipline, + lora_path=None, + lora_alpha: float = 1, + lora_txt_alpha: float = 1, + r: int = 4, + r_txt: int = 4, + unet_target_module=UNET_DEFAULT_TARGET_REPLACE + ): print( f"Merging UNET/CLIP with LoRA from {lora_path}. Merging ratio : UNET: {lora_alpha}, CLIP: {lora_txt_alpha}." ) - patch_pipe(pipline, lora_path) + patch_pipe( + pipline, + lora_path, + r=r, + r_txt=r_txt, + unet_target_replace_module=unet_target_module + ) collapse_lora(pipline.unet, lora_alpha) collapse_lora(pipline.text_encoder, lora_txt_alpha) @@ -1093,5 +1112,7 @@ def get_target_module(target_type: str = "injection", use_extended: bool = False return inject_trainable_lora if not use_extended else inject_trainable_lora_extended if target_type == "module": return UNET_DEFAULT_TARGET_REPLACE if not use_extended else UNET_EXTENDED_TARGET_REPLACE + if target_type == "patch": + return monkeypatch_or_replace_lora_extended if use_extended else monkeypatch_or_replace_lora