Enable passing output_hidden_states #731

Open · wants to merge 17 commits into base: main
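The intended usage after this change looks roughly like the snippet below. This is an editorial sketch mirroring the new test, not code from the diff: the `output_hidden_states` keyword and the tuple return are what this PR adds, while `create_model_and_transforms` and `get_tokenizer` are the existing open_clip API.

import torch
from PIL import Image
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k")
tokenizer = open_clip.get_tokenizer("ViT-B-32")

image = preprocess(Image.open("docs/CLIP.png")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    # with output_hidden_states=True both encoders return (features, hidden_states)
    image_features, image_hidden_states = model.encode_image(image, output_hidden_states=True)
    text_features, text_hidden_states = model.encode_text(text, output_hidden_states=True)

print(f"image hidden states: {len(image_hidden_states)}, first shape: {tuple(image_hidden_states[0].shape)}")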
50 changes: 43 additions & 7 deletions src/open_clip/model.py
@@ -105,7 +105,8 @@ def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None
cast_dtype: Optional[torch.dtype] = None,
output_hidden_states: bool = False,
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
@@ -130,6 +131,10 @@ def _build_vision_tower(
)
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width

if output_hidden_states:
raise ValueError("output_hidden_states not supported with ModifiedResNet")

visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
@@ -165,6 +170,7 @@
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
output_hidden_states=output_hidden_states,
)

return visual
@@ -175,6 +181,7 @@ def _build_text_tower(
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_hidden_states: bool = False,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
@@ -262,18 +269,42 @@ def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable

def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_image(self, image, normalize: bool = False, output_hidden_states: bool = False):
result = self.visual(image, output_hidden_states=output_hidden_states)
if output_hidden_states:
features = result[0]
hidden_states = result[1]
else:
features = result
hidden_states = None

def encode_text(self, text, normalize: bool = False):
if normalize:
features = F.normalize(features, dim=-1)

if output_hidden_states:
return features, hidden_states
return features

def encode_text(self, text, normalize: bool = False, output_hidden_states: bool = False):
cast_dtype = self.transformer.get_cast_dtype()

encoder_states = [] if output_hidden_states else None

x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]

x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)

# we use the text embedding as the first hidden state in the list
if output_hidden_states:
    encoder_states.append(x)

x = self.transformer(x, attn_mask=self.attn_mask, output_hidden_states=output_hidden_states)

if output_hidden_states:
    # the transformer returns (output, hidden_states); its first hidden state is
    # the input embedding we already appended above, so skip it to avoid a duplicate
    x, transformer_states = x
    encoder_states.extend(transformer_states[1:])

x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
x, _ = text_global_pool(x, text, self.text_pool_type)
@@ -283,7 +314,12 @@ def encode_text(self, text, normalize: bool = False):
else:
x = x @ self.text_projection

return F.normalize(x, dim=-1) if normalize else x
if normalize:
x = F.normalize(x, dim=-1)

if output_hidden_states:
return x, encoder_states
return x

def forward(
self,
81 changes: 68 additions & 13 deletions src/open_clip/transformer.py
@@ -315,13 +315,30 @@ def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.int8_original_dtype
return self.resblocks[0].mlp.c_fc.weight.dtype

def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False):
"""
:param x: shape = [bs, seq_len, width]
:param attn_mask: shape = [bs, seq_len, seq_len]
:param output_hidden_states: bool
"""
encoder_states = [] if output_hidden_states else None

# collect encoder states in a list
if output_hidden_states:
encoder_states.append(x)

for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask)
else:
x = r(x, attn_mask=attn_mask)

if output_hidden_states:
encoder_states.append(x)

if output_hidden_states:
return x, encoder_states
return x
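
As an editorial aside, the contract this forward method now follows can be pinned down with a tiny standalone sketch (illustrative names, plain linear blocks instead of residual attention blocks): when output_hidden_states is true, the method returns (output, encoder_states), where encoder_states holds the block input followed by each block's output, so its length is the number of blocks plus one.

import torch
from torch import nn

class TinyTransformer(nn.Module):
    """Minimal stand-in for open_clip's Transformer, used only to illustrate the contract."""
    def __init__(self, width: int = 8, layers: int = 3):
        super().__init__()
        self.resblocks = nn.ModuleList(nn.Linear(width, width) for _ in range(layers))

    def forward(self, x: torch.Tensor, output_hidden_states: bool = False):
        encoder_states = [x] if output_hidden_states else None
        for block in self.resblocks:
            x = block(x)
            if output_hidden_states:
                encoder_states.append(x)
        if output_hidden_states:
            return x, encoder_states
        return x

out, states = TinyTransformer()(torch.randn(5, 2, 8), output_hidden_states=True)
assert len(states) == 4  # input + one state per block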


@@ -349,10 +366,14 @@
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
output_hidden_states: bool = False
):
super().__init__()
assert pool_type in ('tok', 'avg', 'none')
self.output_tokens = output_tokens
self.output_hidden_states = output_hidden_states
self.hidden_size = width

image_height, image_width = self.image_size = to_2tuple(image_size)
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
@@ -499,7 +520,7 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

return pooled, tokens

def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor, output_hidden_states: bool = False):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
@@ -513,7 +534,16 @@ def forward(self, x: torch.Tensor):
x = self.ln_pre(x)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
transformer_out = self.transformer(x, output_hidden_states=self.output_hidden_states or output_hidden_states)
if self.output_hidden_states or output_hidden_states:
# transformer_out is a tuple of (tokens, hidden_states)
x, hidden_states = transformer_out
assert isinstance(hidden_states, list)
hidden_states = [h.permute(1, 0, 2) for h in hidden_states]
else:
x = transformer_out
hidden_states = None

x = x.permute(1, 0, 2) # LND -> NLD

if self.attn_pool is not None:
@@ -541,10 +571,16 @@ def forward(self, x: torch.Tensor):
if self.proj is not None:
pooled = pooled @ self.proj

if self.output_tokens:
return pooled, tokens

return pooled
if self.output_hidden_states or output_hidden_states:
if self.output_tokens:
return pooled, tokens, hidden_states
else:
return pooled, hidden_states
else:
if self.output_tokens:
return pooled, tokens
else:
return pooled


def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):
@@ -583,10 +619,13 @@ def __init__(
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
output_hidden_states: bool = False,
):
super().__init__()
assert pool_type in ('first', 'last', 'argmax', 'none')
self.output_tokens = output_tokens
self.output_hidden_states = output_hidden_states
self.hidden_size = width
self.num_pos = self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
@@ -669,7 +708,7 @@ def build_cls_mask(self, text, cast_dtype: torch.dtype):
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask

def forward(self, text):
def forward(self, text, output_hidden_states: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]

@@ -684,7 +723,17 @@

x = x + self.positional_embedding[:seq_len].to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask)

transformer_out = self.transformer(x, attn_mask=attn_mask, output_hidden_states=self.output_hidden_states or output_hidden_states)
if self.output_hidden_states or output_hidden_states:
# transformer_out is a tuple of (tokens, hidden_states)
x, hidden_states = transformer_out
assert isinstance(hidden_states, list)
hidden_states = [h.permute(1, 0, 2) for h in hidden_states]
else:
x = transformer_out
hidden_states = None

x = x.permute(1, 0, 2) # LND -> NLD

# x.shape = [batch_size, n_ctx, transformer.width]
@@ -702,10 +751,16 @@ def forward(self, text):
else:
pooled = pooled @ self.text_projection

if self.output_tokens:
return pooled, tokens

return pooled
if self.output_hidden_states or output_hidden_states:
if self.output_tokens:
return pooled, tokens, hidden_states
else:
return pooled, hidden_states
else:
if self.output_tokens:
return pooled, tokens
else:
return pooled


class MultimodalTransformer(Transformer):
60 changes: 60 additions & 0 deletions tests/test_hidden_states.py
@@ -0,0 +1,60 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide GPUs before torch initializes CUDA so tests stay on CPU

import torch
import pytest
from PIL import Image

import open_clip
from open_clip.factory import get_tokenizer


if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
# no need for the fusion performance here
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)


test_simple_models = [
# model, pretrained, jit, force_custom_text
("ViT-B-32", "laion2b_s34b_b79k", False, False),
# TODO: Add test coca
]


@pytest.mark.parametrize("model_type,pretrained,jit,force_custom_text", test_simple_models)
def test_inference_simple(
model_type,
pretrained,
jit,
force_custom_text,
):
model, _, preprocess = open_clip.create_model_and_transforms(
model_type,
pretrained=pretrained,
jit=jit,
force_custom_text=force_custom_text,
)
tokenizer = get_tokenizer(model_type)

current_dir = os.path.dirname(os.path.realpath(__file__))

image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
image_result = model.encode_image(image, output_hidden_states=True)
text_result = model.encode_text(text, output_hidden_states=True)

image_features, image_hidden_states = image_result
text_features, text_hidden_states = text_result

text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# hidden states are returned as a list of tensors: the input embedding plus one entry per transformer block
print(f"Number of text hidden states: {len(text_hidden_states)}")
print(f"Number of image hidden states: {len(image_hidden_states)}")

# TODO: Write hidden state shapes assertions

assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0]
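
One way the shape-assertion TODO above could be filled in, as a hedged sketch: the counts assume the hidden-state list produced by this diff (the input embedding plus one entry per transformer block) and rely on the `layers` attribute that open_clip's `Transformer` already exposes.

# hypothetical assertions for the TODO above
assert len(image_hidden_states) == model.visual.transformer.layers + 1
assert len(text_hidden_states) == model.transformer.layers + 1
assert image_hidden_states[0].shape[0] == image.shape[0]  # batch-first after the LND -> NLD permute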