Supporting SigLIP and CLIPA-v2 models (both sourced from big_vision JAX-based modelling code).
rwightman committed Oct 11, 2023
1 parent e3c2ea2 commit 0316911
Showing 31 changed files with 1,314 additions and 503 deletions.
133 changes: 133 additions & 0 deletions src/open_clip/big_vision.py
@@ -0,0 +1,133 @@
import torch
import numpy as np

from .model import CustomTextCLIP
from .transformer import TextTransformer, Transformer


@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models
    """
    from timm.layers import resample_patch_embed, resample_abs_pos_embed

    def _n2p(w, t=True):
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    interpolation = 'bilinear'
    antialias = False

    def _convert_timm_img(module, prefix):
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)

        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            for r in range(2):
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        if module.attn_pool is not None:
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        for i, block in enumerate(module.resblocks.children()):
            block_prefix = f'{prefix}encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

    _convert_timm_img(model.visual.trunk, 'params/img/')
    _convert_openclip_txt(model.text, 'params/txt/')
    model.logit_bias.copy_(_n2p(w['params/b'])[0])
    model.logit_scale.copy_(_n2p(w['params/t'])[0])
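
A minimal usage sketch for the new loader, assuming a registered SigLIP-style model config and a locally downloaded big_vision .npz checkpoint (both names below are placeholders, not guaranteed registry entries). The target must be a CustomTextCLIP whose visual tower is a timm ViT trunk, since the converter walks patch_embed, blocks, norm and attn_pool directly:

from open_clip.factory import create_model
from open_clip.big_vision import load_big_vision_weights

# Hypothetical model name and checkpoint filename, for illustration only.
model = create_model('ViT-B-16-SigLIP')  # CustomTextCLIP with timm visual trunk and TextTransformer text tower
load_big_vision_weights(model, 'webli_en_b16_224.npz')  # maps params/img/*, params/txt/*, params/t and params/b

The load_checkpoint change in factory.py below also routes any checkpoint path ending in .npz or .npy through this function, so the same conversion runs when such a file is passed as a pretrained checkpoint.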


4 changes: 4 additions & 0 deletions src/open_clip/constants.py
@@ -1,2 +1,6 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
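
A short sketch of applying the new Inception-style constants for normalization (assuming torchvision is available); a mean and std of 0.5 maps [0, 1] image tensors to roughly [-1, 1]:

from torchvision import transforms
from open_clip.constants import INCEPTION_MEAN, INCEPTION_STD

# Scales [0, 1] float tensors to approximately [-1, 1].
normalize = transforms.Normalize(mean=INCEPTION_MEAN, std=INCEPTION_STD)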
107 changes: 73 additions & 34 deletions src/open_clip/factory.py
@@ -1,3 +1,4 @@
import copy
import json
import logging
import os
@@ -19,7 +20,7 @@
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
from .tokenizer import HFTokenizer, tokenize, syntax_mask_tokenize, random_mask_tokenize, block_mask_tokenize, get_pp_bert_tokenize
from .tokenizer import HFTokenizer, SimpleTokenizer


HF_HUB_PREFIX = 'hf-hub:'
@@ -75,30 +76,53 @@ def get_model_config(model_name):
return None


def get_tokenizer(model_name):
def _get_hf_config(model_id, cache_dir=None):
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
return config


def get_tokenizer(
model_name: str = '',
text_mask: str = '',
**kwargs,
):
if 'siglip' in model_name.lower():
# FIXME temporary hack
from open_clip.tokenizer import SigLipTokenizer
config = get_model_config(model_name)
tokenizer = SigLipTokenizer('c4-en', **kwargs)
if 'context_length' in config['text_cfg'].keys():
tokenizer = partial(tokenizer, context_length=config['text_cfg']['context_length'])

return tokenizer

if model_name.startswith(HF_HUB_PREFIX):
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
model_name = model_name[len(HF_HUB_PREFIX):]
try:
config = _get_hf_config(model_name)
except Exception:
tokenizer = HFTokenizer(model_name)
return tokenizer
else:
config = get_model_config(model_name)
if 'hf_tokenizer_name' in config['text_cfg']:
tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name'])
elif 'text_mask' in config['text_cfg'] and config['text_cfg']['text_mask'] == 'syntax':
tokenizer = syntax_mask_tokenize
elif 'text_mask' in config['text_cfg'] and config['text_cfg']['text_mask'] == 'random':
tokenizer = random_mask_tokenize
elif 'text_mask' in config['text_cfg'] and config['text_cfg']['text_mask'] == 'block':
tokenizer = block_mask_tokenize
elif 'bert_tokenizer' in config['text_cfg'] and config['text_cfg']['bert_tokenizer']:
tokenizer = get_pp_bert_tokenize(
vocab_path=config['text_cfg']['vocab_path'],
max_len=config['text_cfg']['context_length'],
)
else:
tokenizer = tokenize

if 'context_length' in config['text_cfg'].keys():
context_length = config['text_cfg']['context_length']
tokenizer = partial(tokenizer, context_length=context_length)
if 'tokenizer_kwargs' in config['text_cfg']:
tokenizer_kwargs = dict(config['text_cfg']['tokenizer_kwargs'], **kwargs)
else:
tokenizer_kwargs = kwargs

if 'hf_tokenizer_name' in config['text_cfg']:
tokenizer = HFTokenizer(
config['text_cfg']['hf_tokenizer_name'],
**tokenizer_kwargs,
)
else:
tokenizer = SimpleTokenizer.create(text_mask=text_mask, **tokenizer_kwargs)

if 'context_length' in config['text_cfg'].keys():
tokenizer = partial(tokenizer, context_length=config['text_cfg']['context_length'])

return tokenizer

@@ -119,6 +143,11 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):


def load_checkpoint(model, checkpoint_path, strict=True):
if Path(checkpoint_path).suffix in ('.npz', '.npy'):
from .big_vision import load_big_vision_weights
load_big_vision_weights(model, checkpoint_path)
return {}

state_dict = load_state_dict(checkpoint_path)
# detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
@@ -155,10 +184,8 @@ def create_model(
if has_hf_hub_prefix:
model_id = model_name[len(HF_HUB_PREFIX):]
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)

with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
pretrained_hf = False # override, no need to load original HF text weights
config = _get_hf_config(model_id, cache_dir)
pretrained_cfg = config['preprocess_cfg']
model_cfg = config['model_cfg']
else:
@@ -213,11 +240,12 @@
# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
cast_dtype = get_cast_dtype(precision)
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
if is_hf_model and not pretrained and pretrained_hf:
# load pretrained weights for HF text model IFF no CLIP weights being loaded
model_cfg['text_cfg']['hf_model_pretrained'] = True
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

if custom_text:
if is_hf_model:
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
if "coca" in model_name:
model = CoCa(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
else:
@@ -234,6 +262,7 @@
# Why? The convert_weights_to_lp fn only works with native models.
model.to(device=device, dtype=dtype)
from .transformer import LayerNormFp32

def _convert_ln(m):
if isinstance(m, LayerNormFp32):
m.weight.data = m.weight.data.to(torch.float32)
@@ -268,7 +297,7 @@ def _convert_ln(m):
raise RuntimeError(error_str)
pretrained_loaded = True
elif has_hf_hub_prefix:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
load_checkpoint(model, checkpoint_path)
pretrained_loaded = True

@@ -280,6 +309,7 @@ def _convert_ln(m):
# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
model.pretrained_cfg = copy.deepcopy(pretrained_cfg)

if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
@@ -341,12 +371,11 @@ def create_model_and_transforms(
pretrained_hf: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
image_interpolation: Optional[str] = None,
image_resize_mode: Optional[str] = None, # only effective for inference
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
pos_embed: str = None,
interpolation: str = 'bicubic', # only effective for inference
square_resize_only: bool = False, # only effective for inference
**model_kwargs,
):
model = create_model(
@@ -363,26 +392,29 @@
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
pos_embed=pos_embed,
**model_kwargs,
)

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
image_interpolation = image_interpolation or getattr(model, 'pretrained_cfg', {}).get('interpolation', None)
image_resize_mode = image_resize_mode or getattr(model, 'pretrained_cfg', {}).get('resize_mode', None)

preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std,
interpolation=image_interpolation,
aug_cfg=aug_cfg,
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std,
interpolation=interpolation,
square_resize_only=square_resize_only,
interpolation=image_interpolation,
resize_mode=image_resize_mode,
)

return model, preprocess_train, preprocess_val
@@ -400,6 +432,8 @@ def create_model_from_pretrained(
return_transform: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
image_interpolation: Optional[str] = None,
image_resize_mode: Optional[str] = None, # only effective for inference
cache_dir: Optional[str] = None,
**model_kwargs,
):
@@ -422,11 +456,16 @@

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
image_interpolation = image_interpolation or getattr(model, 'pretrained_cfg', {}).get('interpolation', None)
image_resize_mode = image_resize_mode or getattr(model, 'pretrained_cfg', {}).get('resize_mode', None)

preprocess = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std,
interpolation=image_interpolation,
resize_mode=image_resize_mode,
)

return model, preprocess
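
A hedged end-to-end sketch of the reworked factory API. The model tag 'ViT-B-16-SigLIP', the pretrained tag 'webli', and the 'squash' resize mode are illustrative assumptions; the interpolation and resize-mode overrides only change the returned preprocessing transform, with resize mode applying to inference preprocessing per the inline comments above:

import torch
import open_clip

# Tags below are placeholders; actual names depend on the registered configs and pretrained entries.
model, preprocess = open_clip.create_model_from_pretrained(
    'ViT-B-16-SigLIP',
    pretrained='webli',
    image_interpolation='bicubic',  # overrides the value stored in model.pretrained_cfg
    image_resize_mode='squash',     # inference-only resize behaviour
)
tokenizer = open_clip.get_tokenizer('ViT-B-16-SigLIP')

with torch.no_grad():
    text = tokenizer(['a photo of a dog'])
    text_features = model.encode_text(text)
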
8 changes: 4 additions & 4 deletions src/open_clip/hf_model.py
@@ -103,7 +103,7 @@ def __init__(
output_dim: int,
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
proj_type: str = None,
pretrained: bool = True,
output_tokens: bool = False,
):
@@ -139,11 +139,11 @@ def __init__(
self.pooler = _POOLERS[pooler_type]()

d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
self.proj = nn.Identity()
elif proj == 'linear':
elif proj_type == 'linear':
self.proj = nn.Linear(d_model, output_dim, bias=False)
elif proj == 'mlp':
elif proj_type == 'mlp':
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=False),
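
The hf_model.py change is a keyword rename (proj becomes proj_type) for the text-tower projection head. A hedged construction sketch of the new spelling, with the HF model name and dimensions chosen purely for illustration:

from open_clip.hf_model import HFTextEncoder

# 'roberta-base' and output_dim are illustrative; requires the transformers package.
text_tower = HFTextEncoder(
    'roberta-base',
    output_dim=512,
    proj_type='mlp',            # previously passed as proj='mlp'
    pooler_type='mean_pooler',
    pretrained=False,           # build from config only, skip downloading pretrained weights
)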