Combining CLIPA-v2 and SigLIP (both big_vision based) models (#660)

* merge changes for clipa inference

* update get_tokenizer to pass CI test; replace gelu_approximate with act_kwargs

* Temporary workaround: cannot force a tf dependency

* Supporting SigLIP and CLIPA-v2 models (both sourced from big_vision jax based modelling code).

* Fix some test failures, remove old v1 CLIPA configs, add 336 H14 CLIPA

* Fix torchscript

* Fix CoCa expand typo, force final LN after attentional pool

* Fix wrong default clean fn in SimpleTokenizer, put lower-casing back

* Attempt to fix xlm roberta test w/ pretrained hf weight difference

* SigLIP weights working. More changes to support differing image preprocessing / text tokenization sensibly.

* Fix a typo and remove an unused import

* Fix two small issues, add hf_tokenizer_name to SigLIP models for non hf-hub use

* CLIPA checkpoints temporarily reference rwightman/ models for testing

* Rename profile->profiler to avoid python naming conflict

* More tokenizer rework, add context_len as class attr set in factory, default __call__() arg to None. Clean up reduction masking logic and fix #680

* fix ViT-SO400M-14-SigLIP name

* Fix CoCa pool LN, improve clarity of ViT pooling logic

* Exclude first/last tokens from the tokens output of text models; should match previous CoCa behaviour, but is at odds with argmax pooling, which leaves special tokens in (not consistent)

* Add eval results for CLIPA + SigLIP models

* Fixup bigG CLIPA config, 83.03 top-1 IN-1k

---------

Co-authored-by: zw <[email protected]>
Co-authored-by: Gabriel Ilharco <[email protected]>
3 people authored Oct 20, 2023
1 parent e7b39e4 commit a5f3ae9
Showing 50 changed files with 1,761 additions and 478 deletions.
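
For context, a minimal usage sketch of the combined CLIPA-v2 / SigLIP support described above. The model names and pretrained tags below are assumptions for illustration, not taken from this commit; per the notes above, image preprocessing and the tokenizer (including context_len) are resolved from the model config by the factory.

# hedged sketch, not part of this commit
import torch
from PIL import Image
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16-SigLIP', pretrained='webli')  # assumed names
tokenizer = open_clip.get_tokenizer('ViT-B-16-SigLIP')  # tokenizer now carries its own context_len

image = preprocess(Image.open('example.jpg')).unsqueeze(0)  # placeholder image path
text = tokenizer(['a photo of a cat', 'a photo of a dog'])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

# CLIPA-v2 models are created the same way, e.g. 'ViT-H-14-CLIPA' w/ pretrained='datacomp1b' (assumed tag)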
29 changes: 22 additions & 7 deletions docs/openclip_results.csv

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions scripts/clipav1_vit_l16_i37_t8.sh
@@ -0,0 +1,6 @@
# eval on a single gpu
CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \
    --model ViT-L-16-CL32-GAP \
    --pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \
    --seed 0 \
    --imagenet-val '/path/to/ImageNet/val'
10 changes: 10 additions & 0 deletions scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh
@@ -0,0 +1,10 @@
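# eval CLIPA-v2 ViT-H/14 on a single gpu at 336px (bilinear square resize, ImageNet mean/std)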
CUDA_VISIBLE_DEVICES=1 python3 -m training.main \
    --model ViT-H-14-CL32-GAP-BigVision \
    --pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \
    --force-image-size 336 \
    --square-resize-only \
    --interpolation 'bilinear' \
    --image-mean 0.485 0.456 0.406 \
    --image-std 0.229 0.224 0.225 \
    --seed 0 \
    --imagenet-val '/path/to/ImageNet/val'
3 changes: 2 additions & 1 deletion src/open_clip/__init__.py
@@ -4,7 +4,8 @@
 from .factory import list_models, add_model_config, get_model_config, load_checkpoint
 from .loss import ClipLoss, DistillClipLoss, CoCaLoss
 from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
-    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
+    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
+    get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
 from .openai import load_openai_model, list_openai_models
 from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
     get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
136 changes: 136 additions & 0 deletions src/open_clip/big_vision.py
@@ -0,0 +1,136 @@
import torch
import numpy as np

from .model import CustomTextCLIP
from .transformer import TextTransformer, Transformer


@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models
    Currently the SigLIP source models are supported and a CustomTextCLIP destination model
    w/ timm image encoder.
    """
    from timm.layers import resample_patch_embed, resample_abs_pos_embed

    def _n2p(w, t=True):
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    interpolation = 'bilinear'
    antialias = False

    def _convert_timm_img(module, prefix):
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)

        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            for r in range(2):
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        if module.attn_pool is not None:
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        for i, block in enumerate(module.resblocks.children()):
            block_prefix = f'{prefix}encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

    _convert_timm_img(model.visual.trunk, 'params/img/')
    _convert_openclip_txt(model.text, 'params/txt/')
    model.logit_bias.copy_(_n2p(w['params/b'])[0])
    model.logit_scale.copy_(_n2p(w['params/t'])[0])


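For reference, a hedged sketch of driving this loader directly; the model config name and the .npz path are placeholders/assumptions, not taken from this commit. The destination must be a CustomTextCLIP with a timm image tower, per the docstring above.

# hedged sketch, not part of this commit
import open_clip
from open_clip.big_vision import load_big_vision_weights

model = open_clip.create_model('ViT-B-16-SigLIP')  # assumed SigLIP config name (CustomTextCLIP w/ timm visual trunk)
load_big_vision_weights(model, '/path/to/siglip_b16_224.npz')  # placeholder path to an official big_vision checkpoint
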
25 changes: 18 additions & 7 deletions src/open_clip/coca_model.py
@@ -123,35 +123,46 @@ def __init__(
         self.pad_id = pad_id

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable=True):
+    def set_grad_checkpointing(self, enable: bool = True):
         self.visual.set_grad_checkpointing(enable)
         self.text.set_grad_checkpointing(enable)
         self.text_decoder.set_grad_checkpointing(enable)

-    def _encode_image(self, images, normalize=True):
+    def _encode_image(self, images, normalize: bool = True):
         image_latent, tokens_embs = self.visual(images)
         image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
         return image_latent, tokens_embs

-    def _encode_text(self, text, normalize=True, embed_cls=True):
+    def _encode_text(self, text, normalize: bool = True, embed_cls: bool = True):
         text = text[:, :-1] if embed_cls else text  # make space for CLS token
         text_latent, token_emb = self.text(text)
         text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
         return text_latent, token_emb

-    def encode_image(self, images, normalize=True):
+    def encode_image(self, images, normalize: bool = True):
         image_latent, _ = self._encode_image(images, normalize=normalize)
         return image_latent

-    def encode_text(self, text, normalize=True, embed_cls=True):
+    def encode_text(self, text, normalize: bool = True, embed_cls: bool = True):
         text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
         return text_latent

-    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
-        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
+    def forward(
+            self,
+            image,
+            text: Optional[torch.Tensor] = None,
+            embed_cls: bool = True,
+            image_latent: Optional[torch.Tensor] = None,
+            image_embs: Optional[torch.Tensor] = None,
+    ):
         if image_latent is None or image_embs is None:
             image_latent, image_embs = self._encode_image(image)

+        if text is None:
+            return {"image_features": image_latent, "image_embs": image_embs}
+
+        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)

         # TODO: add assertion to avoid bugs?
         labels = text[:, -token_embs.shape[1]:]
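
A small hedged sketch of how the reworked CoCa forward above can be used; model, image, and text below are placeholders, not defined in this diff.

# hedged sketch, not part of this commit
image_only = model(image, text=None)  # {"image_features": ..., "image_embs": ...}
out = model(
    image,
    text,
    image_latent=image_only["image_features"],  # reuse the cached image latent
    image_embs=image_only["image_embs"],        # and token embeddings, skipping a second visual pass
)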
4 changes: 4 additions & 0 deletions src/open_clip/constants.py
@@ -1,2 +1,6 @@
 OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
 OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+INCEPTION_MEAN = (0.5, 0.5, 0.5)
+INCEPTION_STD = (0.5, 0.5, 0.5)
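
As a cross-check (hedged, not part of this diff), the --image-mean / --image-std values in the CLIPA-v2 eval script above are exactly the ImageNet constants added here:

# hedged sketch, not part of this commit
from open_clip.constants import IMAGENET_MEAN, IMAGENET_STD

assert IMAGENET_MEAN == (0.485, 0.456, 0.406)  # matches --image-mean in the CLIPA-v2 script
assert IMAGENET_STD == (0.229, 0.224, 0.225)   # matches --image-std in the CLIPA-v2 script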