Combining CLIPA-v2 and SigLIP (both big_vision based) models (#660)
* Merge changes for CLIPA inference
* Update get_tokenizer to pass CI test; replace gelu_approximate with act_kwargs
* Temporary: cannot have a forced tf dependency
* Support SigLIP and CLIPA-v2 models (both sourced from big_vision JAX-based modelling code)
* Fix some test failures, remove old v1 CLIPA configs, add 336 H14 CLIPA
* Fix torchscript
* Fix CoCa expand typo, force final LN after attentional pool
* Used wrong default clean fn in SimpleTokenizer, put lower case back
* Attempt to fix xlm-roberta test w/ pretrained hf weight difference
* SigLIP weights working; more changes to support differing image preprocessing / text tokenization sensibly
* Fix a typo and an unused import
* Fix two small issues, add hf_tokenizer_name to SigLIP models for non hf-hub use
* CLIPA configs temporarily reference rwightman/ models for testing
* Rename profile -> profiler to avoid python naming conflict
* More tokenizer rework: add context_len as a class attr set in the factory, default __call__() arg to None; clean up reduction masking logic and fix #680
* Fix ViT-SO400M-14-SigLIP name
* Fix CoCa pool LN, improve clarity of ViT pooling logic
* Exclude first/last tokens from tokens output of text models; should match prev CoCa behaviour, but at odds with argmax which leaves special tokens in (not consistent)
* Add eval results for CLIPA + SigLIP models
* Fixup bigG CLIPA config, 83.03% top-1 IN-1k

---------

Co-authored-by: zw <[email protected]>
Co-authored-by: Gabriel Ilharco <[email protected]>
1 parent e7b39e4 · commit a5f3ae9 · 50 changed files with 1,761 additions and 478 deletions.
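For reference, a minimal usage sketch of the model/tokenizer factory path this commit extends. The ViT-SO400M-14-SigLIP name is taken from the commit notes; the 'webli' pretrained tag is an assumption and a local checkpoint path can be substituted.

import torch
import open_clip

# build a SigLIP model + matching preprocessing and tokenizer
model, _, preprocess = open_clip.create_model_and_transforms('ViT-SO400M-14-SigLIP', pretrained='webli')
tokenizer = open_clip.get_tokenizer('ViT-SO400M-14-SigLIP')

text = tokenizer(['a diagram', 'a dog', 'a cat'])
with torch.no_grad():
    text_features = model.encode_text(text)
    print(text_features.shape)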
@@ -0,0 +1,6 @@
# eval on a single gpu
CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \
    --model ViT-L-16-CL32-GAP \
    --pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \
    --seed 0 \
    --imagenet-val '/path/to/ImageNet/val'
scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh (10 additions, 0 deletions)
@@ -0,0 +1,10 @@
CUDA_VISIBLE_DEVICES=1 python3 -m training.main \
    --model ViT-H-14-CL32-GAP-BigVision \
    --pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \
    --force-image-size 336 \
    --square-resize-only \
    --interpolation 'bilinear' \
    --image-mean 0.485 0.456 0.406 \
    --image-std 0.229 0.224 0.225 \
    --seed 0 \
    --imagenet-val '/path/to/ImageNet/val'
@@ -0,0 +1,136 @@
import torch
import numpy as np

from .model import CustomTextCLIP
from .transformer import TextTransformer, Transformer


@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models.

    Currently the SigLIP source models are supported and a CustomTextCLIP destination model
    w/ timm image encoder.
    """
    from timm.layers import resample_patch_embed, resample_abs_pos_embed

    def _n2p(w, t=True):
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    interpolation = 'bilinear'
    antialias = False

    def _convert_timm_img(module, prefix):
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)

        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            for r in range(2):
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        if module.attn_pool is not None:
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        for i, block in enumerate(module.resblocks.children()):
            block_prefix = f'{prefix}encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

    _convert_timm_img(model.visual.trunk, 'params/img/')
    _convert_openclip_txt(model.text, 'params/txt/')
    model.logit_bias.copy_(_n2p(w['params/b'])[0])
    model.logit_scale.copy_(_n2p(w['params/t'])[0])
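For context, a hedged sketch of how this loader would be driven end to end: build the destination CustomTextCLIP from its open_clip config without downloading weights, then copy the big_vision .npz tensors into it. The import path open_clip.big_vision and the checkpoint filename below are assumptions; the diff does not show where this module lives or how checkpoints are named.

import open_clip
from open_clip.big_vision import load_big_vision_weights  # assumed module path

# build the destination model from its config only (no pretrained download)
model = open_clip.create_model('ViT-SO400M-14-SigLIP', pretrained=None)

# copy params/img/*, params/txt/*, params/t and params/b from the jax .npz checkpoint
load_big_vision_weights(model, '/path/to/big_vision_siglip.npz')
model.eval()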
@@ -1,2 +1,6 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
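These new constants back the preprocessing overrides used in the eval scripts above (the CLIPA-v2 H14 script passes the ImageNet mean/std with bilinear, square-only resizing). A minimal torchvision sketch of an equivalent transform follows; the exact wiring inside the open_clip transform factory is not shown in this diff.

from torchvision import transforms

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# square resize (no crop) + ImageNet normalization, mirroring the
# --square-resize-only / --interpolation 'bilinear' flags in the eval script
preprocess = transforms.Compose([
    transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])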