diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index fe3aa31c9..54a269d17 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -105,7 +105,8 @@ def _build_vision_tower(
         embed_dim: int,
         vision_cfg: CLIPVisionCfg,
         quick_gelu: bool = False,
-        cast_dtype: Optional[torch.dtype] = None
+        cast_dtype: Optional[torch.dtype] = None,
+        output_hidden_states: bool = False,
 ):
     if isinstance(vision_cfg, dict):
         vision_cfg = CLIPVisionCfg(**vision_cfg)
@@ -130,6 +131,10 @@ def _build_vision_tower(
         )
     elif isinstance(vision_cfg.layers, (tuple, list)):
         vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
+
+        if output_hidden_states:
+            raise ValueError("output_hidden_states not supported with ModifiedResNet")
+
         visual = ModifiedResNet(
             layers=vision_cfg.layers,
             output_dim=embed_dim,
@@ -165,6 +170,7 @@ def _build_vision_tower(
             output_dim=embed_dim,
             act_layer=act_layer,
             norm_layer=norm_layer,
+            output_hidden_states=output_hidden_states,
         )

     return visual
@@ -175,6 +181,7 @@ def _build_text_tower(
         text_cfg: CLIPTextCfg,
         quick_gelu: bool = False,
         cast_dtype: Optional[torch.dtype] = None,
+        output_hidden_states: bool = False,
 ):
     if isinstance(text_cfg, dict):
         text_cfg = CLIPTextCfg(**text_cfg)
@@ -262,18 +269,42 @@ def set_grad_checkpointing(self, enable=True):
         self.visual.set_grad_checkpointing(enable)
         self.transformer.grad_checkpointing = enable

-    def encode_image(self, image, normalize: bool = False):
-        features = self.visual(image)
-        return F.normalize(features, dim=-1) if normalize else features
+    def encode_image(self, image, normalize: bool = False, output_hidden_states: bool = False):
+        result = self.visual(image, output_hidden_states=output_hidden_states)
+        if output_hidden_states:
+            features, hidden_states = result
+        else:
+            features = result
+            hidden_states = None

-    def encode_text(self, text, normalize: bool = False):
+        if normalize:
+            features = F.normalize(features, dim=-1)
+
+        if output_hidden_states:
+            return features, hidden_states
+        return features
+
+    def encode_text(self, text, normalize: bool = False, output_hidden_states: bool = False):
         cast_dtype = self.transformer.get_cast_dtype()
+        encoder_states = [] if output_hidden_states else None
+
         x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

         x = x + self.positional_embedding.to(cast_dtype)
         x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.transformer(x, attn_mask=self.attn_mask)
+
+        x = self.transformer(x, attn_mask=self.attn_mask, output_hidden_states=output_hidden_states)
+
+        if output_hidden_states:
+            # the transformer returns (tokens, hidden_states); the list starts with the
+            # text embedding and has one entry per residual block
+            x, hidden_states = x
+            encoder_states.extend(h.permute(1, 0, 2) for h in hidden_states)  # LND -> NLD
+
         x = x.permute(1, 0, 2)  # LND -> NLD
         x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
         x, _ = text_global_pool(x, text, self.text_pool_type)
@@ -283,7 +314,12 @@ def encode_text(self, text, normalize: bool = False):
         else:
             x = x @ self.text_projection

-        return F.normalize(x, dim=-1) if normalize else x
+        if normalize:
+            x = F.normalize(x, dim=-1)
+
+        if output_hidden_states:
+            return x, encoder_states
+        return x

     def forward(
             self,
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 6d4e604d8..9c89bda69 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -315,13 +315,30 @@ def get_cast_dtype(self) -> torch.dtype:
             return self.resblocks[0].mlp.c_fc.int8_original_dtype
         return self.resblocks[0].mlp.c_fc.weight.dtype

-    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False):
+        """
+        :param x: shape = [seq_len, batch_size, width] (LND)
+        :param attn_mask: optional attention mask, shape = [seq_len, seq_len] or a broadcastable per-head mask
+        :param output_hidden_states: if True, also return the input and the output of every block
+        """
+        encoder_states = [] if output_hidden_states else None
+
+        # collect encoder states in a list, starting with the block input
+        if output_hidden_states:
+            encoder_states.append(x)
+
         for r in self.resblocks:
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                 x = checkpoint(r, x, None, None, attn_mask)
             else:
                 x = r(x, attn_mask=attn_mask)
+
+            if output_hidden_states:
+                encoder_states.append(x)
+
+        if output_hidden_states:
+            return x, encoder_states
         return x


@@ -349,10 +366,14 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             output_tokens: bool = False,
+            output_hidden_states: bool = False,
     ):
         super().__init__()
         assert pool_type in ('tok', 'avg', 'none')
         self.output_tokens = output_tokens
+        self.output_hidden_states = output_hidden_states
+        self.hidden_size = width
+
         image_height, image_width = self.image_size = to_2tuple(image_size)
         patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
         self.grid_size = (image_height // patch_height, image_width // patch_width)
@@ -499,7 +520,7 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

         return pooled, tokens

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, output_hidden_states: bool = False):
         x = self.conv1(x)  # shape = [*, width, grid, grid]
         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
@@ -513,7 +534,16 @@ def forward(self, x: torch.Tensor):
         x = self.ln_pre(x)

         x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.transformer(x)
+        transformer_out = self.transformer(x, output_hidden_states=self.output_hidden_states or output_hidden_states)
+        if self.output_hidden_states or output_hidden_states:
+            # transformer_out is a tuple of (tokens, hidden_states)
+            x, hidden_states = transformer_out
+            assert isinstance(hidden_states, list)
+            hidden_states = [h.permute(1, 0, 2) for h in hidden_states]  # LND -> NLD
+        else:
+            x = transformer_out
+            hidden_states = None
+
         x = x.permute(1, 0, 2)  # LND -> NLD

         if self.attn_pool is not None:
@@ -541,10 +571,16 @@ def forward(self, x: torch.Tensor):
         if self.proj is not None:
             pooled = pooled @ self.proj

-        if self.output_tokens:
-            return pooled, tokens
-
-        return pooled
+        if self.output_hidden_states or output_hidden_states:
+            if self.output_tokens:
+                return pooled, tokens, hidden_states
+            return pooled, hidden_states
+
+        if self.output_tokens:
+            return pooled, tokens
+        return pooled


 def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):
@@ -583,10 +619,13 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             output_tokens: bool = False,
+            output_hidden_states: bool = False,
     ):
         super().__init__()
         assert pool_type in ('first', 'last', 'argmax', 'none')
         self.output_tokens = output_tokens
+        self.output_hidden_states = output_hidden_states
+        self.hidden_size = width
         self.num_pos = self.context_length = context_length
         self.vocab_size = vocab_size
         self.width = width
@@ -669,7 +708,7 @@ def build_cls_mask(self, text, cast_dtype: torch.dtype):
         additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
         return additive_mask

-    def forward(self, text):
+    def forward(self, text, output_hidden_states: bool = False):
         cast_dtype = self.transformer.get_cast_dtype()
         seq_len = text.shape[1]

@@ -684,7 +723,17 @@ def forward(self, text):
         x = x + self.positional_embedding[:seq_len].to(cast_dtype)

         x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.transformer(x, attn_mask=attn_mask)
+
+        transformer_out = self.transformer(x, attn_mask=attn_mask, output_hidden_states=self.output_hidden_states or output_hidden_states)
+        if self.output_hidden_states or output_hidden_states:
+            # transformer_out is a tuple of (tokens, hidden_states)
+            x, hidden_states = transformer_out
+            assert isinstance(hidden_states, list)
+            hidden_states = [h.permute(1, 0, 2) for h in hidden_states]  # LND -> NLD
+        else:
+            x = transformer_out
+            hidden_states = None
+
         x = x.permute(1, 0, 2)  # LND -> NLD

         # x.shape = [batch_size, n_ctx, transformer.width]
@@ -702,10 +751,16 @@ def forward(self, text):
         else:
             pooled = pooled @ self.text_projection

-        if self.output_tokens:
-            return pooled, tokens
-
-        return pooled
+        if self.output_hidden_states or output_hidden_states:
+            if self.output_tokens:
+                return pooled, tokens, hidden_states
+            return pooled, hidden_states
+
+        if self.output_tokens:
+            return pooled, tokens
+        return pooled


 class MultimodalTransformer(Transformer):
diff --git a/tests/test_hidden_states.py b/tests/test_hidden_states.py
new file mode 100644
index 000000000..13366de8b
--- /dev/null
+++ b/tests/test_hidden_states.py
@@ -0,0 +1,60 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+import pytest
+import torch
+from PIL import Image
+
+import open_clip
+from open_clip.factory import get_tokenizer
+
+
+if hasattr(torch._C, '_jit_set_profiling_executor'):
+    # legacy executor is too slow to compile large models for unit tests
+    # no need for the fusion performance here
+    torch._C._jit_set_profiling_executor(True)
+    torch._C._jit_set_profiling_mode(False)
+
+
+test_simple_models = [
+    # model, pretrained, jit, force_custom_text
+    ("ViT-B-32", "laion2b_s34b_b79k", False, False),
+    # TODO: add a CoCa test
+]
+
+
+@pytest.mark.parametrize("model_type,pretrained,jit,force_custom_text", test_simple_models)
+def test_inference_simple(
+        model_type,
+        pretrained,
+        jit,
+        force_custom_text,
+):
+    model, _, preprocess = open_clip.create_model_and_transforms(
+        model_type,
+        pretrained=pretrained,
+        jit=jit,
+        force_custom_text=force_custom_text,
+    )
+    tokenizer = get_tokenizer(model_type)
+
+    current_dir = os.path.dirname(os.path.realpath(__file__))
+
+    image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0)
+    text = tokenizer(["a diagram", "a dog", "a cat"])
+
+    with torch.no_grad():
+        image_features, image_hidden_states = model.encode_image(image, output_hidden_states=True)
+        text_features, text_hidden_states = model.encode_text(text, output_hidden_states=True)
+
+        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+
+        # hidden states are a list containing the embedding output plus one entry per
+        # transformer block, each of shape [batch_size, seq_len, width]
+        print(f"Number of text hidden states: {len(text_hidden_states)}")
+        print(f"Number of image hidden states: {len(image_hidden_states)}")
+
+        # TODO: Write hidden state shapes assertions
+
+    assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0]
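
For reference, a minimal usage sketch of the `output_hidden_states` flag this patch introduces, using the same ViT-B-32 / laion2b_s34b_b79k pair as the test above. With the flag set, both `encode_image` and `encode_text` return a `(features, hidden_states)` tuple, where `hidden_states` is a list holding the embedding output followed by one tensor per transformer block (this reflects the patched behavior, not the upstream open_clip API):

```python
import torch
import open_clip
from open_clip.factory import get_tokenizer

# Same model/pretrained pair exercised by tests/test_hidden_states.py.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k")
tokenizer = get_tokenizer("ViT-B-32")
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    # encode_text returns (features, hidden_states) when the flag is set;
    # hidden_states[0] is the token embedding, the rest are per-block outputs.
    text_features, text_hidden_states = model.encode_text(
        text, output_hidden_states=True)

print(len(text_hidden_states))      # number of text blocks + 1
print(text_hidden_states[0].shape)  # [batch_size, context_length, width]
```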