From fa79c7328fa5170dbd15b3447928de82494a5d41 Mon Sep 17 00:00:00 2001 From: Nicholas Bardy Date: Wed, 4 Oct 2023 11:37:52 -0700 Subject: [PATCH 01/15] Update coca_model.py --- src/open_clip/coca_model.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index ad81fb665..eea172617 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -139,13 +139,21 @@ def _encode_text(self, text, normalize=True, embed_cls=True): text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb - def encode_image(self, images, normalize=True): - image_latent, _ = self._encode_image(images, normalize=normalize) - return image_latent - - def encode_text(self, text, normalize=True, embed_cls=True): - text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) - return text_latent + def encode_image(self, images, normalize=True, return_embedding=False): + image_latent, token_emb = self._encode_image(images, normalize=normalize) + + if return_embedding: + return text_latent, token_emb + else: + return text_latent + + def encode_text(self, text, normalize=True, embed_cls=True, return_embedding=False): + text_latent, token_emb = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + + if return_embedding: + return text_latent, token_emb + else: + return text_latent def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) From 6880041aa5d238ea3aef809d6fa68c3decb73d5b Mon Sep 17 00:00:00 2001 From: Nicholas Bardy Date: Thu, 5 Oct 2023 16:55:23 -0700 Subject: [PATCH 02/15] openclip --- src/open_clip/coca_model.py | 53 +++++++++++++++------------ src/open_clip/model.py | 55 +++++++++++++++++++++++----- src/open_clip/transformer.py | 70 ++++++++++++++++++++++++++++-------- tests/test_hidden_states.py | 58 ++++++++++++++++++++++++++++++ 4 files changed, 192 insertions(+), 44 deletions(-) create mode 100644 tests/test_hidden_states.py diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index eea172617..1fe41b4c6 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -11,6 +11,7 @@ LayerNorm, QuickGELU, MultimodalTransformer, + TransformerOutput ) from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower @@ -128,32 +129,40 @@ def set_grad_checkpointing(self, enable=True): self.text.set_grad_checkpointing(enable) self.text_decoder.set_grad_checkpointing(enable) - def _encode_image(self, images, normalize=True): - image_latent, tokens_embs = self.visual(images) - image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent - return image_latent, tokens_embs + def _encode_image(self, images, normalize=True, output_hidden_states=False): + result = self.visual(images, output_hidden_states=output_hidden_states) - def _encode_text(self, text, normalize=True, embed_cls=True): - text = text[:, :-1] if embed_cls else text # make space for CLS token - text_latent, token_emb = self.text(text) - text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent - return text_latent, token_emb - - def encode_image(self, images, normalize=True, return_embedding=False): - image_latent, token_emb = self._encode_image(images, normalize=normalize) + image_latent = result[0] + if normalize: + image_latent = F.normalize(image_latent, dim=-1) if 
normalize else image_latent - if return_embedding: - return text_latent, token_emb - else: - return text_latent + + return TransformerOutput( + pooled=image_latent, + tokens=result[1], + hidden_states=result[2] + ).value() - def encode_text(self, text, normalize=True, embed_cls=True, return_embedding=False): - text_latent, token_emb = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + def _encode_text(self, text, normalize=True, embed_cls=True, output_hidden_states=False): + text = text[:, :-1] if embed_cls else text # make space for CLS token + # text_latent, token_emb = self.text(text) + result = self.text(text, output_hidden_states=output_hidden_states) + text_latent = result[0] + if normalize: + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + + return TransformerOutput( + pooled=text_latent, + tokens=result[1], + hidden_states=result[2] + ).value() + + def encode_image(self, images, normalize=True, output_hidden_states=False): + return self._encode_image(images, normalize=normalize, output_hidden_states=output_hidden_states) + + def encode_text(self, text, normalize=True, embed_cls=True, output_hidden_states=False): + return self._encode_text(text, normalize=normalize, embed_cls=embed_cls, output_hidden_states=output_hidden_states) - if return_embedding: - return text_latent, token_emb - else: - return text_latent def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 2fcafa227..570d0bbf6 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -16,7 +16,7 @@ from .hf_model import HFTextEncoder from .modified_resnet import ModifiedResNet from .timm_model import TimmModel -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer, TransformerOutput from .utils import to_2tuple @@ -87,7 +87,8 @@ def _build_vision_tower( embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None + cast_dtype: Optional[torch.dtype] = None, + output_hidden_states: bool = False, ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) @@ -98,6 +99,9 @@ def _build_vision_tower( act_layer = QuickGELU if quick_gelu else nn.GELU if vision_cfg.timm_model_name: + if output_hidden_states: + raise ValueError("output_hidden_states not supported with timm models") + visual = TimmModel( vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, @@ -112,6 +116,10 @@ def _build_vision_tower( ) elif isinstance(vision_cfg.layers, (tuple, list)): vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + + if output_hidden_states: + raise ValueError("output_hidden_states not supported with ModifiedResNet") + visual = ModifiedResNet( layers=vision_cfg.layers, output_dim=embed_dim, @@ -140,6 +148,7 @@ def _build_vision_tower( output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, + output_hidden_states=output_hidden_states, ) return visual @@ -150,6 +159,7 @@ def _build_text_tower( text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, + output_hidden_states: bool = False, ): if isinstance(text_cfg, dict): text_cfg = CLIPTextCfg(**text_cfg) @@ -167,6 +177,7 @@ def _build_text_tower( act_layer = QuickGELU if 
quick_gelu else nn.GELU norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + text = TextTransformer( context_length=text_cfg.context_length, vocab_size=text_cfg.vocab_size, @@ -227,23 +238,51 @@ def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.transformer.grad_checkpointing = enable - def encode_image(self, image, normalize: bool = False): - features = self.visual(image) - return F.normalize(features, dim=-1) if normalize else features + def encode_image(self, image, normalize: bool = False, output_hidden_states: bool = False): + result = self.visual(image, output_hidden_states=output_hidden_states) + if output_hidden_states: + features = result[0] + hidden_states = result[1] + else: + features = result + hidden_states = None - def encode_text(self, text, normalize: bool = False): + if normalize: + features = F.normalize(features, dim=-1) + + return TransformerOutput(features, hidden_states) + + def encode_text(self, text, normalize: bool = False, output_hidden_states: bool = False): + # TODO: Why is this all here? We should just use the TextTransformer + # method it already does this. Why do we unwrap the transformer and + # then rewrap it? cast_dtype = self.transformer.get_cast_dtype() + encoder_states = [] if output_hidden_states else None + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=self.attn_mask) + + if output_hidden_states: + encoder_states.append(x) + + x = self.transformer(x, attn_mask=self.attn_mask, output_hidden_states=output_hidden_states) + + if output_hidden_states: + x = x[0] + encoder_states.extend(x[1]) + x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - return F.normalize(x, dim=-1) if normalize else x + + if normalize: + x = F.normalize(x, dim=-1) + + return TransformerOutput(x, dim=-1, hidden_states=encoder_states) def forward( self, diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 0a30e9466..372cbec6a 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -10,6 +10,28 @@ from .utils import to_2tuple + +# Standardize the output of the model so we can optionally return: +# hidden states +# tokens +class TransformerOutput(torch.nn.Module): + def __init__(self, pooled, tokens, hidden_states): + self.pooled = pooled + self.tokens = tokens + self.hidden_states = hidden_states + + # Get value + def value(self): + if self.output_tokens and self.output_hidden_states: + return self.pooled, self.tokens, self.hidden_states + + if self.output_tokens: + return self.pooled, self.tokens, None + + if self.output_hidden_states: + return self.pooled, None, self.hidden_states + + class LayerNormFp32(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" @@ -235,6 +257,8 @@ def forward( k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None @@ -312,14 
+336,23 @@ def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.int8_original_dtype return self.resblocks[0].mlp.c_fc.weight.dtype - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False): + encoder_states = [] if output_hidden_states else None + + if output_hidden_states: + encoder_states.append(x) + for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) - return x + + if output_hidden_states: + encoder_states.append(x) + + return TransformerOutput(x, encoder_states).value() class VisionTransformer(nn.Module): @@ -457,8 +490,7 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: else: return x[:, 0], x[:, 1:] - def forward(self, x: torch.Tensor): - + def forward(self, x: torch.Tensor, output_hidden_states: bool = False): # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 if self.input_patchnorm: # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') @@ -483,7 +515,11 @@ def forward(self, x: torch.Tensor): x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) + x = self.transformer(x, output_hidden_states=output_hidden_states) + if output_hidden_states: + x = x[0] + hidden_states = x[1] + x = x.permute(1, 0, 2) # LND -> NLD if self.attn_pool is not None: @@ -497,10 +533,9 @@ def forward(self, x: torch.Tensor): if self.proj is not None: pooled = pooled @ self.proj - if self.output_tokens: - return pooled, tokens - return pooled + return TransformerOutput(pooled, tokens, hidden_states).value() + class TextTransformer(nn.Module): @@ -596,11 +631,12 @@ def build_cls_mask(self, text, cast_dtype: torch.dtype): def _repeat(self, t, N: int): return t.reshape(1, 1, -1).repeat(N, 1, 1) - def forward(self, text): + def forward(self, text, output_hidden_states: bool = False): cast_dtype = self.transformer.get_cast_dtype() seq_len = text.shape[1] x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask if self.cls_emb is not None: seq_len += 1 @@ -610,7 +646,12 @@ def forward(self, text): x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=attn_mask) + x = self.transformer(x, attn_mask=attn_mask, output_hidden_states=output_hidden_states) + + if output_hidden_states: + x = x[0] + hidden_states = x[1] + x = x.permute(1, 0, 2) # LND -> NLD # x.shape = [batch_size, n_ctx, transformer.width] @@ -625,10 +666,11 @@ def forward(self, text): if self.text_projection is not None: pooled = pooled @ self.text_projection - if self.output_tokens: - return pooled, tokens - - return pooled + return TransformerOutput( + pooled=pooled, + tokens=tokens, + encoder_states=hidden_states, + ).value() class MultimodalTransformer(Transformer): diff --git a/tests/test_hidden_states.py b/tests/test_hidden_states.py new file mode 100644 index 000000000..8ab4c6499 --- /dev/null +++ b/tests/test_hidden_states.py @@ -0,0 +1,58 @@ +import torch +from PIL import Image +from open_clip.factory import get_tokenizer +import pytest +import open_clip +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +if hasattr(torch._C, '_jit_set_profiling_executor'): + # 
legacy executor is too slow to compile large models for unit tests + # no need for the fusion performance here + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(False) + + +test_simple_models = [ + # model, pretrained, jit, force_custom_text + ("ViT-B-32", "laion2b_s34b_b79k", False, False), + # TODO: Add test coca +] + + +@pytest.mark.parametrize("model_type,pretrained,jit,force_custom_text", test_simple_models) +def test_inference_simple( + model_type, + pretrained, + jit, + force_custom_text, +): + model, _, preprocess = open_clip.create_model_and_transforms( + model_type, + pretrained=pretrained, + jit=jit, + force_custom_text=force_custom_text, + ) + tokenizer = get_tokenizer(model_type) + + current_dir = os.path.dirname(os.path.realpath(__file__)) + + image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0) + text = tokenizer(["a diagram", "a dog", "a cat"]) + + with torch.no_grad(): + image_result = model.encode_image(image, output_hidden_states=True) + text_result = model.encode_text(text, output_hidden_states=True) + + image_features = image_result.features + text_features = text_result.features + + text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) + + print(text_result.hidden_states.shape) + print(image_result.hidden_states.shape) + + # TODO: Write hidden state shapes assertions + + assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] From 0adfd7833166ff46338f06ad97589a510f5565be Mon Sep 17 00:00:00 2001 From: ggrigorev Date: Wed, 1 Nov 2023 14:24:11 +0000 Subject: [PATCH 03/15] fix bugs after the merge --- src/open_clip/model.py | 3 --- src/open_clip/transformer.py | 6 ++++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 460de22f3..709bfe895 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -283,9 +283,6 @@ def encode_image(self, image, normalize: bool = False, output_hidden_states: boo return TransformerOutput(features, hidden_states) - - return TransformerOutput(features, hidden_states) - def encode_text(self, text, normalize: bool = False, output_hidden_states: bool = False): # TODO: Why is this all here? We should just use the TextTransformer # method it already does this. 
Why do we unwrap the transformer and diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 57d54ed8e..8ef7ab55f 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -532,7 +532,7 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return pooled, tokens - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, output_hidden_states=False): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] @@ -727,6 +727,8 @@ def forward(self, text, output_hidden_states: bool = False): if output_hidden_states: x = x[0] hidden_states = x[1] + else: + hidden_states = None x = x.permute(1, 0, 2) # LND -> NLD @@ -751,7 +753,7 @@ def forward(self, text, output_hidden_states: bool = False): return TransformerOutput( pooled=pooled, tokens=tokens, - encoder_states=hidden_states, + hidden_states=hidden_states, ).value() From 21cc2b6ee2af6aaec172a704b47a97710a9d22bb Mon Sep 17 00:00:00 2001 From: ggrigorev Date: Wed, 1 Nov 2023 14:28:04 +0000 Subject: [PATCH 04/15] add output hidden states flag into init --- src/open_clip/transformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 8ef7ab55f..6d7c68704 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -382,10 +382,12 @@ def __init__( act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_tokens: bool = False, + output_hidden_states: bool = False ): super().__init__() assert pool_type in ('tok', 'avg', 'none') self.output_tokens = output_tokens + self.output_hidden_states = output_hidden_states image_height, image_width = self.image_size = to_2tuple(image_size) patch_height, patch_width = self.patch_size = to_2tuple(patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) From 5584e3d51d54531e5e3cb354542007b8676b5140 Mon Sep 17 00:00:00 2001 From: ggrigorev Date: Wed, 1 Nov 2023 14:32:00 +0000 Subject: [PATCH 05/15] fix transformer call --- src/open_clip/transformer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 6d7c68704..edee6bb6f 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -355,7 +355,9 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, out if output_hidden_states: encoder_states.append(x) - return TransformerOutput(x, encoder_states).value() + if output_hidden_states: + return x, encoder_states + return x class VisionTransformer(nn.Module): @@ -552,6 +554,8 @@ def forward(self, x: torch.Tensor, output_hidden_states=False): if output_hidden_states: x = x[0] hidden_states = x[1] + else: + hidden_states = None x = x.permute(1, 0, 2) # LND -> NLD From 6ccc07dcd980a4fd16c480a1446300f3f50ebe09 Mon Sep 17 00:00:00 2001 From: ggrigorev Date: Wed, 1 Nov 2023 14:33:26 +0000 Subject: [PATCH 06/15] reformat value method --- src/open_clip/transformer.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index edee6bb6f..7e2f251de 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -12,10 +12,6 @@ from .pos_embed import get_2d_sincos_pos_embed - -# Standardize the output of the model so we can optionally return: -# hidden states -# tokens 
class TransformerOutput(torch.nn.Module): def __init__(self, pooled, tokens, hidden_states): self.pooled = pooled @@ -24,14 +20,7 @@ def __init__(self, pooled, tokens, hidden_states): # Get value def value(self): - if self.output_tokens and self.output_hidden_states: - return self.pooled, self.tokens, self.hidden_states - - if self.output_tokens: - return self.pooled, self.tokens, None - - if self.output_hidden_states: - return self.pooled, None, self.hidden_states + return self.pooled, self.tokens, self.hidden_states class LayerNormFp32(nn.LayerNorm): From e8d6dcebe819248f3aa9077bc7486fcc994f973c Mon Sep 17 00:00:00 2001 From: ggrigorev Date: Wed, 1 Nov 2023 14:39:32 +0000 Subject: [PATCH 07/15] add logging --- src/open_clip/transformer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 7e2f251de..c3e57a45b 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -332,6 +332,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, out encoder_states = [] if output_hidden_states else None if output_hidden_states: + print("encoder states len", len(encoder_states)) encoder_states.append(x) for r in self.resblocks: @@ -342,9 +343,11 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, out x = r(x, attn_mask=attn_mask) if output_hidden_states: + print("encoder states len", len(encoder_states)) encoder_states.append(x) if output_hidden_states: + print("encoder states len", len(encoder_states)) return x, encoder_states return x From 6abb4994489ab1d92f078954cee024dfd077d96a Mon Sep 17 00:00:00 2001 From: ggrigorev Date: Wed, 1 Nov 2023 14:45:03 +0000 Subject: [PATCH 08/15] simplify --- src/open_clip/transformer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index c3e57a45b..5aade0016 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -332,7 +332,6 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, out encoder_states = [] if output_hidden_states else None if output_hidden_states: - print("encoder states len", len(encoder_states)) encoder_states.append(x) for r in self.resblocks: @@ -343,11 +342,9 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, out x = r(x, attn_mask=attn_mask) if output_hidden_states: - print("encoder states len", len(encoder_states)) encoder_states.append(x) if output_hidden_states: - print("encoder states len", len(encoder_states)) return x, encoder_states return x @@ -542,11 +539,12 @@ def forward(self, x: torch.Tensor, output_hidden_states=False): x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, output_hidden_states=output_hidden_states) + transformer_out = self.transformer(x, output_hidden_states=output_hidden_states) if output_hidden_states: - x = x[0] - hidden_states = x[1] + x, hidden_states = transformer_out + assert isinstance(hidden_states, list) else: + x = transformer_out hidden_states = None x = x.permute(1, 0, 2) # LND -> NLD @@ -579,7 +577,7 @@ def forward(self, x: torch.Tensor, output_hidden_states=False): # if self.output_tokens: # return pooled, tokens - return TransformerOutput(pooled, tokens, hidden_states).value() + return pooled, tokens, hidden_states From 847a70284f1ff39ae580ee8891c58ce43188da94 Mon Sep 17 00:00:00 2001 From: ggrigorev Date: Wed, 1 Nov 2023 14:49:44 +0000 Subject: [PATCH 09/15] apply 
permute to the transformer out and pass output_hidden_states to the init --- src/open_clip/transformer.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 5aade0016..66b46ef23 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -543,6 +543,7 @@ def forward(self, x: torch.Tensor, output_hidden_states=False): if output_hidden_states: x, hidden_states = transformer_out assert isinstance(hidden_states, list) + hidden_states = [h.permute(1, 0, 2) for h in hidden_states] else: x = transformer_out hidden_states = None @@ -574,10 +575,12 @@ def forward(self, x: torch.Tensor, output_hidden_states=False): if self.proj is not None: pooled = pooled @ self.proj - # if self.output_tokens: - # return pooled, tokens - - return pooled, tokens, hidden_states + if self.output_tokens: + return pooled, tokens + elif self.output_hidden_states: + return pooled, tokens, hidden_states + else: + return pooled @@ -718,12 +721,14 @@ def forward(self, text, output_hidden_states: bool = False): x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=attn_mask, output_hidden_states=output_hidden_states) + transformer_out = self.transformer(x, attn_mask=attn_mask, output_hidden_states=output_hidden_states) if output_hidden_states: - x = x[0] - hidden_states = x[1] + x, hidden_states = transformer_out + assert isinstance(hidden_states, list) + hidden_states = [h.permute(1, 0, 2) for h in hidden_states] else: + x = transformer_out hidden_states = None x = x.permute(1, 0, 2) # LND -> NLD From 858e8c35c2d127763dc3c501c77b9a1ddea7516e Mon Sep 17 00:00:00 2001 From: ggrigorev Date: Wed, 1 Nov 2023 14:54:46 +0000 Subject: [PATCH 10/15] do not pass output_hidden_states to method and instead use attribute --- src/open_clip/transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 66b46ef23..7dcb34ada 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -525,7 +525,7 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return pooled, tokens - def forward(self, x: torch.Tensor, output_hidden_states=False): + def forward(self, x: torch.Tensor): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] @@ -539,8 +539,8 @@ def forward(self, x: torch.Tensor, output_hidden_states=False): x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND - transformer_out = self.transformer(x, output_hidden_states=output_hidden_states) - if output_hidden_states: + transformer_out = self.transformer(x, output_hidden_states=self.output_hidden_states) + if self.output_hidden_states: x, hidden_states = transformer_out assert isinstance(hidden_states, list) hidden_states = [h.permute(1, 0, 2) for h in hidden_states] From 4eb5a11cb2a8b397897c73e1d62a7782dc9e1450 Mon Sep 17 00:00:00 2001 From: ggrigorev Date: Wed, 1 Nov 2023 15:04:22 +0000 Subject: [PATCH 11/15] pass hidden size to VisualTransformer --- src/open_clip/transformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 7dcb34ada..a38ef2b1c 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -379,6 +379,8 @@ def __init__( assert 
pool_type in ('tok', 'avg', 'none') self.output_tokens = output_tokens self.output_hidden_states = output_hidden_states + self.hidden_size = width + image_height, image_width = self.image_size = to_2tuple(image_size) patch_height, patch_width = self.patch_size = to_2tuple(patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) From 7988d5a0de081335f75dfb84ad5dcfc4cf20cec7 Mon Sep 17 00:00:00 2001 From: thepowerfuldeez Date: Thu, 2 Nov 2023 12:56:55 +0000 Subject: [PATCH 12/15] refactor the code and remove TransformerOutput class --- src/open_clip/model.py | 14 +++++---- src/open_clip/transformer.py | 61 +++++++++++++++++++----------------- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 709bfe895..54a269d17 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -19,7 +19,7 @@ from .modified_resnet import ModifiedResNet from .timm_model import TimmModel from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\ - text_global_pool, TransformerOutput + text_global_pool from .utils import to_2tuple @@ -281,12 +281,11 @@ def encode_image(self, image, normalize: bool = False, output_hidden_states: boo if normalize: features = F.normalize(features, dim=-1) - return TransformerOutput(features, hidden_states) + if output_hidden_states: + return features, hidden_states + return features def encode_text(self, text, normalize: bool = False, output_hidden_states: bool = False): - # TODO: Why is this all here? We should just use the TextTransformer - # method it already does this. Why do we unwrap the transformer and - # then rewrap it? cast_dtype = self.transformer.get_cast_dtype() encoder_states = [] if output_hidden_states else None @@ -296,6 +295,7 @@ def encode_text(self, text, normalize: bool = False, output_hidden_states: bool x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND + # we use text embedding as first hidden state in a list if output_hidden_states: encoder_states.append(x) @@ -317,7 +317,9 @@ def encode_text(self, text, normalize: bool = False, output_hidden_states: bool if normalize: x = F.normalize(x, dim=-1) - return TransformerOutput(x, dim=-1, hidden_states=encoder_states) + if output_hidden_states: + return x, encoder_states + return x def forward( self, diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index a38ef2b1c..a19963b6c 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -12,17 +12,6 @@ from .pos_embed import get_2d_sincos_pos_embed -class TransformerOutput(torch.nn.Module): - def __init__(self, pooled, tokens, hidden_states): - self.pooled = pooled - self.tokens = tokens - self.hidden_states = hidden_states - - # Get value - def value(self): - return self.pooled, self.tokens, self.hidden_states - - class LayerNormFp32(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" @@ -329,8 +318,14 @@ def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False): + """ + :param x: shape = [bs, seq_len, width] + :param attn_mask: shape = [bs, seq_len, seq_len] + :param output_hidden_states: bool + """ encoder_states = [] if output_hidden_states else None + # collect encoder states in a list if output_hidden_states: encoder_states.append(x) @@ -543,6 +538,7 
@@ def forward(self, x: torch.Tensor): x = x.permute(1, 0, 2) # NLD -> LND transformer_out = self.transformer(x, output_hidden_states=self.output_hidden_states) if self.output_hidden_states: + # transformer_out is a tuple of (tokens, hidden_states) x, hidden_states = transformer_out assert isinstance(hidden_states, list) hidden_states = [h.permute(1, 0, 2) for h in hidden_states] @@ -577,13 +573,16 @@ def forward(self, x: torch.Tensor): if self.proj is not None: pooled = pooled @ self.proj - if self.output_tokens: - return pooled, tokens - elif self.output_hidden_states: - return pooled, tokens, hidden_states + if self.output_hidden_states: + if self.output_tokens: + return pooled, tokens, hidden_states + else: + return pooled, hidden_states else: - return pooled - + if self.output_tokens: + return pooled, tokens + else: + return pooled def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): @@ -622,10 +621,13 @@ def __init__( act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_tokens: bool = False, + output_hidden_states: bool = False, ): super().__init__() assert pool_type in ('first', 'last', 'argmax', 'none') self.output_tokens = output_tokens + self.output_hidden_states = output_hidden_states + self.hidden_size = width self.num_pos = self.context_length = context_length self.vocab_size = vocab_size self.width = width @@ -708,7 +710,7 @@ def build_cls_mask(self, text, cast_dtype: torch.dtype): additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) return additive_mask - def forward(self, text, output_hidden_states: bool = False): + def forward(self, text): cast_dtype = self.transformer.get_cast_dtype() seq_len = text.shape[1] @@ -723,9 +725,10 @@ def forward(self, text, output_hidden_states: bool = False): x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND - transformer_out = self.transformer(x, attn_mask=attn_mask, output_hidden_states=output_hidden_states) - if output_hidden_states: + transformer_out = self.transformer(x, output_hidden_states=self.output_hidden_states) + if self.output_hidden_states: + # transformer_out is a tuple of (tokens, hidden_states) x, hidden_states = transformer_out assert isinstance(hidden_states, list) hidden_states = [h.permute(1, 0, 2) for h in hidden_states] @@ -750,14 +753,16 @@ def forward(self, text, output_hidden_states: bool = False): else: pooled = pooled @ self.text_projection - # if self.output_tokens: - # return pooled, tokens - - return TransformerOutput( - pooled=pooled, - tokens=tokens, - hidden_states=hidden_states, - ).value() + if self.output_hidden_states: + if self.output_tokens: + return pooled, tokens, hidden_states + else: + return pooled, hidden_states + else: + if self.output_tokens: + return pooled, tokens + else: + return pooled class MultimodalTransformer(Transformer): From 6e971cca0402a56a989f5220418cc1ee5be3ac5f Mon Sep 17 00:00:00 2001 From: thepowerfuldeez Date: Thu, 2 Nov 2023 13:08:00 +0000 Subject: [PATCH 13/15] allow override output_hidden_states to forward call --- src/open_clip/transformer.py | 14 ++++++-------- tests/test_hidden_states.py | 8 ++++---- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index a19963b6c..c28d1b6af 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -234,8 +234,6 @@ def forward( k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: 
Optional[torch.Tensor] = None, - output_attentions: bool = False, - output_hidden_states: bool = False, ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None @@ -522,7 +520,7 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return pooled, tokens - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, output_hidden_states: bool = False): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] @@ -536,8 +534,8 @@ def forward(self, x: torch.Tensor): x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND - transformer_out = self.transformer(x, output_hidden_states=self.output_hidden_states) - if self.output_hidden_states: + transformer_out = self.transformer(x, output_hidden_states=self.output_hidden_states or output_hidden_states) + if self.output_hidden_states or output_hidden_states: # transformer_out is a tuple of (tokens, hidden_states) x, hidden_states = transformer_out assert isinstance(hidden_states, list) @@ -710,7 +708,7 @@ def build_cls_mask(self, text, cast_dtype: torch.dtype): additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) return additive_mask - def forward(self, text): + def forward(self, text, output_hidden_states: bool = False): cast_dtype = self.transformer.get_cast_dtype() seq_len = text.shape[1] @@ -726,8 +724,8 @@ def forward(self, text): x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND - transformer_out = self.transformer(x, output_hidden_states=self.output_hidden_states) - if self.output_hidden_states: + transformer_out = self.transformer(x, output_hidden_states=self.output_hidden_states or output_hidden_states) + if self.output_hidden_states or output_hidden_states: # transformer_out is a tuple of (tokens, hidden_states) x, hidden_states = transformer_out assert isinstance(hidden_states, list) diff --git a/tests/test_hidden_states.py b/tests/test_hidden_states.py index 8ab4c6499..ae4c8f2a7 100644 --- a/tests/test_hidden_states.py +++ b/tests/test_hidden_states.py @@ -45,13 +45,13 @@ def test_inference_simple( image_result = model.encode_image(image, output_hidden_states=True) text_result = model.encode_text(text, output_hidden_states=True) - image_features = image_result.features - text_features = text_result.features + image_features, image_hidden_states = image_result + text_features, text_hidden_states = text_result text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) - print(text_result.hidden_states.shape) - print(image_result.hidden_states.shape) + print(text_hidden_states.shape) + print(image_hidden_states.shape) # TODO: Write hidden state shapes assertions From 19fd3bd7a6dda309ccdb06020fbd1dc799d3071e Mon Sep 17 00:00:00 2001 From: thepowerfuldeez Date: Thu, 2 Nov 2023 13:12:11 +0000 Subject: [PATCH 14/15] fix bug with return --- src/open_clip/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index c28d1b6af..9c89bda69 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -571,7 +571,7 @@ def forward(self, x: torch.Tensor, output_hidden_states: bool = False): if self.proj is not None: pooled = pooled @ self.proj - if self.output_hidden_states: + if self.output_hidden_states 
or output_hidden_states: if self.output_tokens: return pooled, tokens, hidden_states else: @@ -751,7 +751,7 @@ def forward(self, text, output_hidden_states: bool = False): else: pooled = pooled @ self.text_projection - if self.output_hidden_states: + if self.output_hidden_states or output_hidden_states: if self.output_tokens: return pooled, tokens, hidden_states else: From 879718ea6919058f5720aac21af380cd27ddda60 Mon Sep 17 00:00:00 2001 From: thepowerfuldeez Date: Thu, 2 Nov 2023 13:18:31 +0000 Subject: [PATCH 15/15] update test case --- tests/test_hidden_states.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_hidden_states.py b/tests/test_hidden_states.py index ae4c8f2a7..13366de8b 100644 --- a/tests/test_hidden_states.py +++ b/tests/test_hidden_states.py @@ -50,8 +50,10 @@ def test_inference_simple( text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) - print(text_hidden_states.shape) - print(image_hidden_states.shape) + # shape of hidden states: [bs, n_hidden_states, 1, seq_len, hidden_size] + # take first elem + print(f"Length of text hidden states: {len(text_hidden_states[0])}") + print(f"Length of image hidden states: {len(image_hidden_states)}") # TODO: Write hidden state shapes assertions
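
A minimal usage sketch of the interface this series converges on, modeled on tests/test_hidden_states.py and the final encode_image / encode_text signatures. The checkpoint, image path, and prompts are the ones the test uses; the shape comments are assumptions based on how Transformer.forward collects states, not asserted behavior, and the CoCa path is not covered here.

import torch
from PIL import Image

import open_clip
from open_clip.factory import get_tokenizer

# Same checkpoint and inputs as tests/test_hidden_states.py
# (image path is relative to the repository root).
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = get_tokenizer("ViT-B-32")

image = preprocess(Image.open("docs/CLIP.png")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    # With output_hidden_states=True the encoders return (features, hidden_states);
    # hidden_states is a list of per-layer activations (input embedding plus each block).
    image_features, image_hidden_states = model.encode_image(image, output_hidden_states=True)
    text_features, text_hidden_states = model.encode_text(text, output_hidden_states=True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print(len(image_hidden_states))      # expected: number of vision transformer blocks + 1
print(image_hidden_states[0].shape)  # each entry should be roughly [batch, tokens, width]
print(text_probs)

Note that enabling output_hidden_states changes the return type of both encoders from a tensor to a tuple, so callers that opt in need to unpack the result as the test above does.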