Fix LoRA merging.
ExponentialML committed Jan 31, 2023
1 parent 3ceb581 commit c8cf6e7
Showing 3 changed files with 37 additions and 8 deletions.
dreambooth/diff_to_sd.py (12 changes: 10 additions & 2 deletions)
@@ -21,7 +21,7 @@
 from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, reload_system_models
 from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
 from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
-from extensions.sd_dreambooth_extension.lora_diffusion.lora import merge_loras_to_pipe
+from extensions.sd_dreambooth_extension.lora_diffusion.lora import merge_loras_to_pipe, get_target_module
 
 unet_conversion_map = [
     # (stable-diffusion, HF Diffusers)
@@ -400,7 +400,15 @@ def compile_checkpoint(model_name: str, lora_path: str=None, reload_models: bool
 
     printi(f"Saving UNET Lora and applying lora alpha of {config.lora_weight}", log=log)
     if os.path.exists(lora_txt): printi(f"Saving Text Lora and applying lora alpha of {config.lora_txt_weight}", log=log)
-    merge_loras_to_pipe(loaded_pipeline, lora_path, lora_alpha=config.lora_weight, lora_txt_alpha=config.lora_txt_weight)
+    merge_loras_to_pipe(
+        loaded_pipeline,
+        lora_path,
+        lora_alpha=config.lora_weight,
+        lora_txt_alpha=config.lora_txt_weight,
+        r=config.lora_unet_rank,
+        r_txt=config.lora_txt_rank,
+        unet_target_module=get_target_module("module", config.use_lora_extended)
+    )
 
 
     loaded_pipeline.unet.save_pretrained(os.path.join(config.pretrained_model_name_or_path, "unet_lora"))
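Note (not part of the diff): before this change, merge_loras_to_pipe received only the pipeline, the LoRA path, and the two alpha values, so patch_pipe fell back to its default rank and the default UNet target-module set; a LoRA trained with a different rank or with the extended module set could then fail to load or leave trained layers unpatched. A minimal sketch of the intended call after the fix, assuming a rank-32 LoRA trained on the extended target set (paths, values, and the pipeline variable are illustrative, not from the repository):

    from extensions.sd_dreambooth_extension.lora_diffusion.lora import (
        merge_loras_to_pipe,
        get_target_module,
    )

    merge_loras_to_pipe(
        pipeline,                        # hypothetical loaded diffusers pipeline
        "/path/to/lora.pt",              # hypothetical LoRA checkpoint
        lora_alpha=0.8,                  # illustrative UNet merge ratio
        lora_txt_alpha=0.8,              # illustrative text-encoder merge ratio
        r=32,                            # should match the rank used at training time
        r_txt=32,
        unet_target_module=get_target_module("module", True),  # extended module set
    )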
helpers/image_builder.py (2 changes: 1 addition & 1 deletion)
@@ -101,7 +101,7 @@ def __init__(
         unet_target_replace_module=get_target_module("module", config.use_lora_extended),
         token="None",
         r=config.lora_unet_rank,
-        txt_r=config.lora_txt_rank
+        r_txt=config.lora_txt_rank
     )
 
     tune_lora_scale(self.image_pipe.unet, config.lora_weight)
lora_diffusion/lora.py (31 changes: 26 additions & 5 deletions)
@@ -898,7 +898,7 @@ def patch_pipe(
     maybe_unet_path,
     token: Optional[str] = None,
     r: int = 4,
-    txt_r: int = 4,
+    r_txt: int = 4,
     patch_unet=True,
     patch_text=True,
     patch_ti=False,
@@ -922,7 +922,12 @@
     disable_safe_unpickle()
     if patch_unet:
         print("LoRA : Patching Unet")
-        monkeypatch_or_replace_lora(
+        lora_patch = get_target_module(
+            "patch",
+            bool(unet_target_replace_module == UNET_EXTENDED_TARGET_REPLACE)
+        )
+
+        lora_patch(
             pipe.unet,
             torch.load(unet_path),
             r=r,
@@ -935,7 +940,7 @@
             pipe.text_encoder,
             torch.load(text_path),
             target_replace_module=text_target_replace_module,
-            r=txt_r,
+            r=r_txt,
         )
     enable_safe_unpickle()
     if patch_ti:
@@ -1050,12 +1055,26 @@ def save_all(
 
     save_safeloras_with_embeds(loras, embeds, save_path)
 
-def merge_loras_to_pipe(pipline, lora_path=None, lora_alpha: float = 1, lora_txt_alpha: float = 1):
+def merge_loras_to_pipe(
+    pipline,
+    lora_path=None,
+    lora_alpha: float = 1,
+    lora_txt_alpha: float = 1,
+    r: int = 4,
+    r_txt: int = 4,
+    unet_target_module=UNET_DEFAULT_TARGET_REPLACE
+):
     print(
         f"Merging UNET/CLIP with LoRA from {lora_path}. Merging ratio : UNET: {lora_alpha}, CLIP: {lora_txt_alpha}."
     )
 
-    patch_pipe(pipline, lora_path)
+    patch_pipe(
+        pipline,
+        lora_path,
+        r=r,
+        r_txt=r_txt,
+        unet_target_replace_module=unet_target_module
+    )
 
     collapse_lora(pipline.unet, lora_alpha)
     collapse_lora(pipline.text_encoder, lora_txt_alpha)
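Note (not part of the diff): within merge_loras_to_pipe, patch_pipe injects the LoRA modules into the pipeline and collapse_lora then folds each low-rank update into the base weight, scaled by the corresponding alpha. A simplified sketch of that fold, not the library's exact code (tensor names are illustrative):

    import torch

    def collapse_one(weight: torch.Tensor, lora_up: torch.Tensor,
                     lora_down: torch.Tensor, alpha: float) -> torch.Tensor:
        # weight: (out_features, in_features); lora_up: (out_features, r); lora_down: (r, in_features)
        # After the fold, the layer behaves like the patched one but needs no LoRA modules at inference.
        return weight + alpha * (lora_up @ lora_down)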
@@ -1093,5 +1112,7 @@ def get_target_module(target_type: str = "injection", use_extended: bool = False
         return inject_trainable_lora if not use_extended else inject_trainable_lora_extended
     if target_type == "module":
         return UNET_DEFAULT_TARGET_REPLACE if not use_extended else UNET_EXTENDED_TARGET_REPLACE
+    if target_type == "patch":
+        return monkeypatch_or_replace_lora_extended if use_extended else monkeypatch_or_replace_lora
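Note (not part of the diff): with the new branch, get_target_module covers all three lookups the extension uses. The sketch below restates the dispatch and how patch_pipe picks the patch function; it paraphrases the diff above rather than adding repository code:

    # "injection": inject_trainable_lora / inject_trainable_lora_extended (training-time injection)
    # "module":    UNET_DEFAULT_TARGET_REPLACE / UNET_EXTENDED_TARGET_REPLACE (target-module sets)
    # "patch":     monkeypatch_or_replace_lora / monkeypatch_or_replace_lora_extended (load-time patching)

    # Inside patch_pipe, the extended variant is selected whenever the caller
    # passed the extended module set:
    use_extended = unet_target_replace_module == UNET_EXTENDED_TARGET_REPLACE
    lora_patch = get_target_module("patch", bool(use_extended))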

