Add pickle handlers for loras
ExponentialML committed Jan 30, 2023
1 parent 49bcd84 commit 11facf4
Showing 1 changed file with 11 additions and 1 deletion.
lora_diffusion/lora.py: 11 additions, 1 deletion
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import disable_safe_unpickle, enable_safe_unpickle

 try:
     from safetensors.torch import safe_open
@@ -223,6 +224,7 @@ def inject_trainable_lora(
     inject lora into model, and returns lora parameter groups.
     """

+    disable_safe_unpickle()
     require_grad_params = []
     names = []

@@ -259,6 +261,7 @@ def inject_trainable_lora(
         _module._modules[name].lora_down.weight.requires_grad = True
         names.append(name)

+    enable_safe_unpickle()
     return require_grad_params, names


@@ -271,7 +274,7 @@ def inject_trainable_lora_extended(
     """
     inject lora into model, and returns lora parameter groups.
     """
-
+    disable_safe_unpickle()
     require_grad_params = []
     names = []

@@ -330,6 +333,7 @@ def inject_trainable_lora_extended(
         _module._modules[name].lora_down.weight.requires_grad = True
         names.append(name)

+    enable_safe_unpickle()
     return require_grad_params, names


@@ -441,7 +445,9 @@ def convert_loras_to_safeloras_with_embeds(
     for name, (path, target_replace_module, r) in modelmap.items():
         metadata[name] = json.dumps(list(target_replace_module))

+        disable_safe_unpickle()
         lora = torch.load(path)
+        enable_safe_unpickle()
         for i, weight in enumerate(lora):
             is_up = i % 2 == 0
             i = i // 2
@@ -879,7 +885,9 @@ def load_learned_embed_in_clip(
     token: Optional[Union[str, List[str]]] = None,
     idempotent=False,
 ):
+    disable_safe_unpickle()
     learned_embeds = torch.load(learned_embeds_path)
+    enable_safe_unpickle()
     apply_learned_embed_in_clip(
         learned_embeds, text_encoder, tokenizer, token, idempotent
     )
@@ -911,6 +919,7 @@ def patch_pipe(
         ti_path = _ti_lora_path(unet_path)
         text_path = _text_lora_path_ui(unet_path)

+        disable_safe_unpickle()
         if patch_unet:
             print("LoRA : Patching Unet")
             monkeypatch_or_replace_lora(
@@ -928,6 +937,7 @@
                 target_replace_module=text_target_replace_module,
                 r=txt_r,
             )
+        enable_safe_unpickle()
         if patch_ti:
             print("LoRA : Patching token input")
             token = load_learned_embed_in_clip(
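The pattern is the same at every call site: the safe-unpickle check is switched off immediately before a pickle-based load (torch.load on .pt LoRA weights or learned embeddings, and the injection paths that may consume them) and switched back on afterwards. A minimal sketch of that pattern follows; it assumes disable_safe_unpickle() swaps out the WebUI's restricted unpickler and enable_safe_unpickle() restores it, and the helper name load_pt_weights and the try/finally guard are illustrative additions rather than part of this commit.

import torch
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import (
    disable_safe_unpickle,
    enable_safe_unpickle,
)

def load_pt_weights(path):
    # Hypothetical helper: temporarily lift the safe-unpickle restriction so a
    # plain pickled .pt checkpoint can be deserialized, then restore the check.
    disable_safe_unpickle()
    try:
        return torch.load(path)
    finally:
        # Re-enable the check even if loading raises, unlike the paired calls
        # in the diff above, which assume torch.load succeeds.
        enable_safe_unpickle()

Wrapping the load in try/finally keeps the safe-unpickle check from staying disabled when a corrupt or missing file makes torch.load raise.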
