
Commit

fix(lora): fix conv lora
caopulan committed Jun 24, 2023
1 parent 55c2cf3 commit 7495c9b
Showing 1 changed file with 5 additions and 5 deletions.
unidiffusion/peft/lora.py
@@ -75,7 +75,8 @@ def forward(self, hidden_states):

 class LoRAConvLayer(BaseLoRAModule):

-    def __init__(self, org_module: nn.Module, org_name: str, rank=4, network_alpha=None, multiplier=1.0, dropout=0., use_cp=False):
+    def __init__(self, org_module: nn.Module, org_name: str, rank=4, network_alpha=None, scale=1.0, dropout=0.,
+                 use_cp=False):
         assert isinstance(org_module, nn.Conv2d)
         super().__init__(org_module, org_name)
         in_dim = org_module.in_channels
@@ -101,15 +102,14 @@ def __init__(self, org_module: nn.Module, org_name: str, rank=4, network_alpha=None,
         if type(network_alpha) == torch.Tensor:
             network_alpha = network_alpha.detach().float().numpy()  # without casting, bf16 causes error
         network_alpha = rank if network_alpha is None or network_alpha == 0 else network_alpha
-        self.scale = network_alpha / rank
+        self.scale = scale
         self.register_buffer('alpha', torch.tensor(network_alpha))

         # same as microsoft's
         torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
         torch.nn.init.zeros_(self.lora_up.weight)
         if self.cp:
             torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
-        self.multiplier = multiplier

         self.apply_to()

@@ -121,11 +121,11 @@ def make_weight(self):
     def forward(self, x):
         if self.cp:
             return self.org_forward(x) + self.dropout(
-                self.lora_up(self.lora_mid(self.lora_down(x)))* self.multiplier * self.scale
+                self.lora_up(self.lora_mid(self.lora_down(x))) * self.scale
             )
         else:
             return self.org_forward(x) + self.dropout(
-                self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+                self.lora_up(self.lora_down(x)) * self.scale
             )
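
For context, a minimal, self-contained sketch of the behaviour after this commit, using standard PyTorch only. TinyConvLoRA is a hypothetical, simplified stand-in for LoRAConvLayer (the use_cp/lora_mid branch and apply_to wiring are omitted); it shows the fixed scaling: the low-rank conv update is multiplied by the scale argument passed to __init__, rather than by network_alpha / rank times a separate multiplier.

import math

import torch
import torch.nn as nn


class TinyConvLoRA(nn.Module):
    # Hypothetical, simplified stand-in for LoRAConvLayer (no use_cp branch).
    def __init__(self, conv: nn.Conv2d, rank=4, scale=1.0, dropout=0.):
        super().__init__()
        self.conv = conv  # the frozen original convolution
        self.lora_down = nn.Conv2d(conv.in_channels, rank, conv.kernel_size,
                                   conv.stride, conv.padding, bias=False)
        self.lora_up = nn.Conv2d(rank, conv.out_channels, 1, bias=False)
        # same init scheme as in the diff above
        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)
        self.scale = scale  # taken directly from the argument, as in the fix
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # original output plus the scaled low-rank update
        return self.conv(x) + self.dropout(self.lora_up(self.lora_down(x)) * self.scale)


x = torch.randn(1, 8, 16, 16)
layer = TinyConvLoRA(nn.Conv2d(8, 8, 3, padding=1), rank=4, scale=1.0)
assert layer(x).shape == (1, 8, 16, 16)  # zero-init lora_up: output == conv(x) at start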

