CLIPA-v2 and SigLIP (big_vision based) model support #659

Closed · wants to merge 1 commit
6 changes: 6 additions & 0 deletions scripts/clipav1_vit_l16_i37_t8.sh
@@ -0,0 +1,6 @@
# eval on a single gpu
CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \
--model ViT-L-16-CL32-GAP \
--pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \
--seed 0 \
--imagenet-val '/path/to/ImageNet/val'
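
For reference, a minimal Python sketch of the same CLIPA-v1 evaluation setup (the checkpoint path is a placeholder; the model name comes from the new ViT-L-16-CL32-GAP config added below):

import open_clip

# Sketch only: load the CLIPA-v1 checkpoint that the script above evaluates.
model, _, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-L-16-CL32-GAP',
    pretrained='/path/to/clipa_vit_l16_i37_t8.pt',  # placeholder path, as in the script
)
model.eval()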
10 changes: 10 additions & 0 deletions scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh
@@ -0,0 +1,10 @@
CUDA_VISIBLE_DEVICES=1 python3 -m training.main \
--model ViT-H-14-CL32-GAP-BigVision \
--pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \
--force-image-size 336 \
--square-resize-only \
--interpolation 'bilinear' \
--image-mean 0.485 0.456 0.406 \
--image-std 0.229 0.224 0.225 \
--seed 0 \
--imagenet-val '/path/to/ImageNet/val'
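
Equivalently in Python, using the inference-only options this PR adds to create_model_and_transforms (interpolation, square_resize_only); paths are placeholders and the sketch assumes this PR's factory changes:

import open_clip

model, _, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-H-14-CL32-GAP-BigVision',
    pretrained='/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt',  # placeholder path
    force_image_size=336,
    square_resize_only=True,           # new flag: resize directly to a square target, no center crop
    interpolation='bilinear',          # new flag: override the default bicubic resize at inference
    image_mean=(0.485, 0.456, 0.406),  # ImageNet statistics, matching the script
    image_std=(0.229, 0.224, 0.225),
)
model.eval()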
20 changes: 19 additions & 1 deletion src/open_clip/factory.py
@@ -19,7 +19,7 @@
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
from .tokenizer import HFTokenizer, tokenize, syntax_mask_tokenize, random_mask_tokenize, block_mask_tokenize
from .tokenizer import HFTokenizer, tokenize, syntax_mask_tokenize, random_mask_tokenize, block_mask_tokenize, get_pp_bert_tokenize


HF_HUB_PREFIX = 'hf-hub:'
@@ -88,11 +88,18 @@ def get_tokenizer(model_name):
tokenizer = random_mask_tokenize
elif 'text_mask' in config['text_cfg'] and config['text_cfg']['text_mask'] == 'block':
tokenizer = block_mask_tokenize
elif 'bert_tokenizer' in config['text_cfg'] and config['text_cfg']['bert_tokenizer']:
tokenizer = get_pp_bert_tokenize(
vocab_path=config['text_cfg']['vocab_path'],
max_len=config['text_cfg']['context_length'],
)
else:
tokenizer = tokenize

if 'context_length' in config['text_cfg'].keys():
context_length = config['text_cfg']['context_length']
tokenizer = partial(tokenizer, context_length=context_length)

return tokenizer


@@ -141,6 +148,7 @@ def create_model(
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
pos_embed: Optional[str] = None,
**model_kwargs,
):
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
@@ -190,6 +198,10 @@
# override model config's image size
model_cfg["vision_cfg"]["image_size"] = force_image_size

if pos_embed is not None:
# override model config's positional embedding
model_cfg["vision_cfg"]["pos_embed"] = pos_embed

is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
if pretrained_image:
if is_timm_model:
@@ -332,6 +344,9 @@ def create_model_and_transforms(
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
pos_embed: Optional[str] = None,
interpolation: str = 'bicubic',  # only applied to the inference (val) transform
square_resize_only: bool = False,  # only applied to the inference (val) transform
**model_kwargs,
):
model = create_model(
@@ -348,6 +363,7 @@
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
pos_embed=pos_embed,
**model_kwargs,
)

@@ -365,6 +381,8 @@
is_train=False,
mean=image_mean,
std=image_std,
interpolation=interpolation,
square_resize_only=square_resize_only,
)

return model, preprocess_train, preprocess_val
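A usage sketch for the tokenizer dispatch added above. For the CLIPA configs, get_tokenizer() wraps the default BPE tokenizer with the 32-token context length from text_cfg; configs with "bert_tokenizer": true instead go through get_pp_bert_tokenize, which presumably needs the vocab file named in text_cfg (the tokenizer.py changes are not shown in this diff):

from open_clip import get_tokenizer

tokenizer = get_tokenizer('ViT-L-16-CL32-GAP')
tokens = tokenizer(['a diagram', 'a dog', 'a cat'])  # LongTensor of shape [3, 32]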
32 changes: 31 additions & 1 deletion src/open_clip/model.py
@@ -12,6 +12,7 @@
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from functools import partial

from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
@@ -45,6 +46,10 @@ class CLIPVisionCfg:
timm_proj_bias: bool = False # enable bias final projection
timm_drop: float = 0. # head dropout
timm_drop_path: Optional[float] = None # backbone stochastic depth
pos_embed: str = 'learnable'  # positional embedding type for the vision tower
act_kwargs: Optional[dict] = None  # extra kwargs for act_layer, e.g. {"approximate": "tanh"}
ln_pre: bool = True  # apply the pre-transformer LayerNorm (big_vision-style ViTs set this to False)
pool_style: str = 'open_clip'  # 'open_clip' or a big_vision-style pooling variant


@dataclass
@@ -63,6 +68,11 @@ class CLIPTextCfg:
embed_cls: bool = False
pad_id: int = 0
output_tokens: bool = False
act_kwargs: Optional[dict] = None  # extra kwargs for act_layer, e.g. {"approximate": "tanh"}
pool_style: str = 'open_clip'  # 'open_clip', 'big_vision_tok', or 'big_vision_last'
bert_tokenizer: bool = False  # use the big_vision BERT tokenizer instead of the default BPE tokenizer
vocab_path: Optional[str] = None  # vocab file for the BERT tokenizer
attention_mask: bool = True  # build the causal attention mask (set false for BERT-style text towers)
text_mask: str = 'first' # default first truncate in bpe_tokenizer


@@ -123,6 +133,8 @@ def _build_vision_tower(
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
if vision_cfg.act_kwargs is not None:
act_layer = partial(act_layer, **vision_cfg.act_kwargs)
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
@@ -141,6 +153,9 @@
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
pos_embed=vision_cfg.pos_embed,
ln_pre=vision_cfg.ln_pre,
pool_style=vision_cfg.pool_style,
)

return visual
@@ -168,6 +183,9 @@
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm

if text_cfg.act_kwargs is not None:
act_layer = partial(act_layer, **text_cfg.act_kwargs)

text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
@@ -181,6 +199,8 @@
pad_id=text_cfg.pad_id,
act_layer=act_layer,
norm_layer=norm_layer,
pool_style=text_cfg.pool_style,
attention_mask=text_cfg.attention_mask,
)
return text

@@ -211,6 +231,7 @@ def __init__(
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.pool_style = text.pool_style
self.register_buffer('attn_mask', text.attn_mask, persistent=False)

self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
@@ -243,7 +264,16 @@ def encode_text(self, text, normalize: bool = False):
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
if self.pool_style == 'open_clip':
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
elif self.pool_style == 'big_vision_tok':
pooled = x[:, 0]
x = pooled @ self.text_projection
elif self.pool_style == 'big_vision_last':
pooled = x[:, -1]
x = pooled @ self.text_projection
else:
raise ValueError(f'Unknown pool_style: {self.pool_style}')
return F.normalize(x, dim=-1) if normalize else x

def forward(
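A shapes-only illustration of the three text pooling styles handled in encode_text above:

import torch

# Toy tensors: [batch, n_ctx, width] features and the matching token ids.
x = torch.randn(2, 32, 768)
text = torch.randint(1, 49408, (2, 32))

eot_pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]  # 'open_clip': feature at the EOT token (highest id)
tok_pooled = x[:, 0]                                           # 'big_vision_tok': first ([CLS]-style) token
last_pooled = x[:, -1]                                         # 'big_vision_last': last token in the sequence
assert eot_pooled.shape == tok_pooled.shape == last_pooled.shape == (2, 768)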
30 changes: 30 additions & 0 deletions src/open_clip/model_configs/ViT-H-14-CL32-GAP-BigVision.json
@@ -0,0 +1,30 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14,
"act_kwargs":{
"approximate": "tanh"
},
"ln_pre": false,
"pool_style": "big_vision_gap",
"global_average_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"bert_tokenizer": true,
"vocab_path": "gs://vit_models/lit/bert/uncased_L-12_H-768_A-12/vocab.txt",
"width": 1024,
"heads": 16,
"layers": 24,
"act_kwargs":{
"approximate": "tanh"
},
"pool_style": "big_vision_last",
"attention_mask": false
}
}
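
For context, the act_kwargs entries above are folded into the activation via functools.partial in _build_vision_tower / _build_text_tower (see the model.py hunk), presumably to match the tanh-approximate GELU used on the big_vision side. A minimal sketch:

from functools import partial
import torch.nn as nn

act_layer = partial(nn.GELU, approximate='tanh')
act = act_layer()  # equivalent to nn.GELU(approximate='tanh')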
17 changes: 17 additions & 0 deletions src/open_clip/model_configs/ViT-L-16-CL32-GAP.json
@@ -0,0 +1,17 @@
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 16,
"global_average_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
96 changes: 96 additions & 0 deletions src/open_clip/pos_embed.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)

grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0

# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)

pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product

emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)

emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
if 'pos_embed' in checkpoint_model:
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
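
A usage sketch for the helper above, e.g. a fixed sin-cos table for a 224-pixel image with 14-pixel patches (a 16x16 grid):

import torch
from open_clip.pos_embed import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=1280, grid_size=16, cls_token=False)  # (256, 1280) numpy array
pos = torch.from_numpy(pos).float().unsqueeze(0)  # [1, 256, 1280], e.g. to copy into a ViT positional embedding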