Commit

bugfix; properly freeze layernorm
rsomani95 committed Jun 18, 2024
1 parent 01c3cd5 commit 38bcad6
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions src/open_clip/transformer.py
@@ -811,25 +811,29 @@ def lock_text_transformer(
     for param in transformer.parameters():
         param.requires_grad = False
 
-    groups = [
-        [transformer.token_embedding, transformer.positional_embedding],
-        *transformer.transformer.resblocks[:-1],
-        [transformer.transformer.resblocks[-1], transformer.ln_final],
-        transformer.text_projection,
-    ]
+    if unlocked_groups != 0:
+        groups = [
+            [transformer.token_embedding, transformer.positional_embedding],
+            *transformer.transformer.resblocks[:-1],
+            [transformer.transformer.resblocks[-1], transformer.ln_final],
+            transformer.text_projection,
+        ]
 
-    def _unlock(x):
-        if isinstance(x, Sequence):
-            for g in x:
-                _unlock(g)
-        else:
-            if isinstance(x, torch.nn.Parameter):
-                x.requires_grad = True
-            else:
-                for p in x.parameters():
-                    p.requires_grad = True
+        def _unlock(x):
+            if isinstance(x, Sequence):
+                for g in x:
+                    _unlock(g)
+            else:
+                if isinstance(x, torch.nn.Parameter):
+                    x.requires_grad = True
+                else:
+                    for n,p in x.named_parameters():
+                        if n.startswith("ln_"):  # If LayerNorm layer
+                            p.requires_grad = False if freeze_layer_norm else True
+                        else:
+                            p.requires_grad = True
 
-    _unlock(groups[-unlocked_groups:])
+        _unlock(groups[-unlocked_groups:])
 
 
 class MultimodalTransformer(Transformer):
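For reference, a minimal usage sketch (not part of the commit). It assumes the fork exposes lock_text_transformer(transformer, unlocked_groups, freeze_layer_norm) at module level, as the hunk header implies, and that TextTransformer's default arguments build a standard text tower:

# Hypothetical usage: unlock the last two groups (final resblock + ln_final,
# then the text projection) while keeping LayerNorm parameters frozen.
from open_clip.transformer import TextTransformer, lock_text_transformer

text = TextTransformer()  # default config; illustrative only
lock_text_transformer(text, unlocked_groups=2, freeze_layer_norm=True)

# Inspect what remains trainable. With this fix, the unlocked resblock's
# ln_1 / ln_2 parameters should no longer appear in the output.
for name, param in text.named_parameters():
    if param.requires_grad:
        print(name)

Before the fix, every parameter of an unlocked module was set trainable via x.parameters(); iterating with named_parameters() lets the ln_-prefixed LayerNorm weights inside each unlocked resblock respect the freeze_layer_norm flag.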
