From 38bcad6e926c8a09f13b6dc92526b9102a7b32cc Mon Sep 17 00:00:00 2001
From: Rahul Somani
Date: Tue, 18 Jun 2024 01:13:52 -0400
Subject: [PATCH] bugfix; properly freeze layernorm

---
 src/open_clip/transformer.py | 38 ++++++++++++++++++++----------------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 8bf02bbb9..7b6df4431 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -811,25 +811,29 @@ def lock_text_transformer(
     for param in transformer.parameters():
         param.requires_grad = False
 
-    groups = [
-        [transformer.token_embedding, transformer.positional_embedding],
-        *transformer.transformer.resblocks[:-1],
-        [transformer.transformer.resblocks[-1], transformer.ln_final],
-        transformer.text_projection,
-    ]
-
-    def _unlock(x):
-        if isinstance(x, Sequence):
-            for g in x:
-                _unlock(g)
-        else:
-            if isinstance(x, torch.nn.Parameter):
-                x.requires_grad = True
+    if unlocked_groups != 0:
+        groups = [
+            [transformer.token_embedding, transformer.positional_embedding],
+            *transformer.transformer.resblocks[:-1],
+            [transformer.transformer.resblocks[-1], transformer.ln_final],
+            transformer.text_projection,
+        ]
+
+        def _unlock(x):
+            if isinstance(x, Sequence):
+                for g in x:
+                    _unlock(g)
             else:
-                for p in x.parameters():
-                    p.requires_grad = True
+                if isinstance(x, torch.nn.Parameter):
+                    x.requires_grad = True
+                else:
+                    for n,p in x.named_parameters():
+                        if n.startswith("ln_"): # If LayerNorm layer
+                            p.requires_grad = False if freeze_layer_norm else True
+                        else:
+                            p.requires_grad = True
 
-    _unlock(groups[-unlocked_groups:])
+        _unlock(groups[-unlocked_groups:])
 
 
 class MultimodalTransformer(Transformer):
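
For reviewers, a minimal usage sketch of the behavior this patch targets. It assumes lock_text_transformer is importable from open_clip.transformer and accepts the unlocked_groups and freeze_layer_norm arguments that its body uses; the text_tower object, keyword-argument call form, and summarize_locking helper are illustrative assumptions, not part of this patch.

# Sketch only: names outside the diff (text_tower, summarize_locking, keyword-argument call) are assumptions.
from open_clip.transformer import lock_text_transformer


def summarize_locking(text_tower, unlocked_groups: int = 1, freeze_layer_norm: bool = True):
    # Freeze the whole text tower, then unlock the last `unlocked_groups` groups.
    # With freeze_layer_norm=True, the patched code keeps parameters whose names
    # start with "ln_" frozen inside the unlocked groups.
    lock_text_transformer(
        text_tower,
        unlocked_groups=unlocked_groups,
        freeze_layer_norm=freeze_layer_norm,
    )

    # Report which parameters remain trainable after locking.
    for name, param in text_tower.named_parameters():
        print(f"{name}: requires_grad={param.requires_grad}")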