Support MobileCLIP S1 & S2 models via timm integration
rwightman committed Jun 7, 2024
1 parent 2e8de83 commit 1d7b953
Showing 5 changed files with 130 additions and 5 deletions.
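
For context, a minimal usage sketch of what this commit enables; the model name and the 'datacompdr' pretrained tag come from the configs added below, while the image path is a placeholder:

import torch
from PIL import Image
import open_clip

# Build model + preprocessing from the new mobileclip_s1 config and the
# 'datacompdr' pretrained entry added in this commit.
model, _, preprocess = open_clip.create_model_and_transforms('mobileclip_s1', pretrained='datacompdr')
tokenizer = open_clip.get_tokenizer('mobileclip_s1')

image = preprocess(Image.open('example.jpg')).unsqueeze(0)  # placeholder image path
text = tokenizer(['a diagram', 'a dog', 'a cat'])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)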
51 changes: 50 additions & 1 deletion src/open_clip/big_vision.py → src/open_clip/convert.py
@@ -1,7 +1,11 @@
""" Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
"""
from typing import Union

import torch
import numpy as np

from .model import CustomTextCLIP
from .model import CLIP, CustomTextCLIP
from .transformer import TextTransformer, Transformer


@@ -134,3 +138,48 @@ def _convert_openclip_txt(module: TextTransformer, prefix):
model.logit_scale.copy_(_n2p(w['params/t'])[0])


@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict):
from timm.models.fastvit import _checkpoint_filter_fn

def _convert_timm_img(state_dict, prefix='image_encoder.'):
timm_state_dict = _checkpoint_filter_fn(state_dict, model.visual.trunk)
timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
return timm_state_dict

def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
text_dict = {}
for k, v in state_dict.items():
if not k.startswith(prefix):
continue
k = k.replace(prefix, '')
k = k.replace('projection_layer', 'text_projection')
k = k.replace('embedding_layer', 'token_embedding')
if k.startswith('positional_embedding.pos_embed.pos_embed'):
k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
v = v.squeeze()
k = k.replace('final_layer_norm', 'ln_final')
k = k.replace('pre_norm_mha.0', 'ln_1')
k = k.replace('pre_norm_mha.1', 'attn')
k = k.replace('pre_norm_ffn.0', 'ln_2')
k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
k = k.replace('qkv_proj.weight', 'in_proj_weight')
k = k.replace('qkv_proj.bias', 'in_proj_bias')
k = k.replace('transformer.', 'transformer.resblocks.')
text_dict['text.' + k] = v
return text_dict

image_dict = _convert_timm_img(state_dict)
text_dict = _convert_openclip_txt(state_dict)
out_dict = {**image_dict, **text_dict}
out_dict['logit_scale'] = state_dict['logit_scale']
return out_dict


def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
# Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
state_dict = convert_mobile_clip_state_dict(model, state_dict)

return state_dict
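
A hedged sketch of invoking the new converter directly on a downloaded Apple MobileCLIP checkpoint; the local file name is a placeholder, and the flat 'image_encoder.*' / 'text_encoder.*' key layout is assumed from the remapping above:

from open_clip import create_model
from open_clip.convert import convert_state_dict
from open_clip.factory import load_state_dict

model = create_model('mobileclip_s1', pretrained=None)
raw_sd = load_state_dict('mobileclip_s1.pt')    # placeholder checkpoint path
converted = convert_state_dict(model, raw_sd)   # detects MobileCLIP keys and remaps them
missing, unexpected = model.load_state_dict(converted, strict=False)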
21 changes: 18 additions & 3 deletions src/open_clip/factory.py
@@ -10,6 +10,7 @@
import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .convert import convert_state_dict
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
from .coca_model import CoCa
@@ -139,25 +140,39 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):
return state_dict


def load_checkpoint(model, checkpoint_path, strict=True):
def load_checkpoint(
model: Union[CLIP, CustomTextCLIP],
checkpoint_path: str,
strict: bool = True,
):
if Path(checkpoint_path).suffix in ('.npz', '.npy'):
from .big_vision import load_big_vision_weights
# Separate path loading numpy big_vision (SigLIP) weights
from open_clip.convert import load_big_vision_weights
load_big_vision_weights(model, checkpoint_path)
return {}

state_dict = load_state_dict(checkpoint_path)
# detect old format and make compatible with new format

# Detect & convert 3rd party state_dicts -> open_clip
state_dict = convert_state_dict(model, state_dict)

# Detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)

# If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
if 'logit_bias' not in state_dict and model.logit_bias is not None:
state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

# Certain text transformers no longer expect position_ids after transformers==4.31
position_id_key = 'text.transformer.embeddings.position_ids'
if position_id_key in state_dict and not hasattr(model, position_id_key):
del state_dict[position_id_key]

resize_pos_embed(state_dict, model)
resize_text_pos_embed(state_dict, model)

# Finally, load the massaged state_dict into model
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys

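In practice the conversion is transparent: pointing `pretrained` at a local checkpoint file routes through the load_checkpoint() path above, so a raw MobileCLIP state_dict is detected and remapped before loading. A sketch, with a placeholder file path:

import open_clip

model, _, preprocess = open_clip.create_model_and_transforms(
    'mobileclip_s2',
    pretrained='/path/to/mobileclip_s2.pt',  # placeholder local checkpoint
)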
21 changes: 21 additions & 0 deletions src/open_clip/model_configs/mobileclip_s1.json
@@ -0,0 +1,21 @@
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "fastvit_mci1",
"timm_model_pretrained": false,
"timm_pool": "avg",
"timm_proj": null,
"timm_drop": 0.0,
"timm_drop_path": 0.0,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"no_causal_mask": true
},
"custom_text": true
}
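
For reference, a minimal sketch of the image tower this vision_cfg resolves to through open_clip's timm adapter, assuming a timm release that includes the MobileCLIP FastViT variants:

import timm

# 'fastvit_mci1' is the trunk named by timm_model_name; num_classes=0 and
# global_pool='avg' roughly mirror timm_proj=null / timm_pool='avg' above.
trunk = timm.create_model('fastvit_mci1', pretrained=False, num_classes=0, global_pool='avg')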
21 changes: 21 additions & 0 deletions src/open_clip/model_configs/mobileclip_s2.json
@@ -0,0 +1,21 @@
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "fastvit_mci2",
"timm_model_pretrained": false,
"timm_pool": "avg",
"timm_proj": null,
"timm_drop": 0.0,
"timm_drop_path": 0.0,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"no_causal_mask": true
},
"custom_text": true
}
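
A quick sanity-check sketch that both new configs are picked up by the model registry once these JSON files land in open_clip/model_configs/:

import open_clip

assert 'mobileclip_s1' in open_clip.list_models()
cfg = open_clip.get_model_config('mobileclip_s2')
print(cfg['vision_cfg']['timm_model_name'])  # expected: 'fastvit_mci2'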
21 changes: 20 additions & 1 deletion src/open_clip/pretrained.py
@@ -65,6 +65,20 @@ def _apcfg(url='', hf_hub='', **kwargs):
}


def _mccfg(url='', hf_hub='', **kwargs):
# MobileCLIP
return {
'url': url,
'hf_hub': hf_hub,
'mean': (0., 0., 0.),
'std': (1., 1., 1.),
'interpolation': 'bilinear',
'resize_mode': 'shortest',
**kwargs,
}



_RN50 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
@@ -438,7 +452,12 @@ def _apcfg(url='', hf_hub='', **kwargs):
"nllb-clip-large-siglip": dict(
v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),
mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),
)
),

"mobileclip_s1": dict(
datacompdr=_mccfg(url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt')),
"mobileclip_s2": dict(
datacompdr=_mccfg(url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt'))
}


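Finally, a sketch of confirming the new pretrained entries are discoverable; list_pretrained() should now include ('mobileclip_s1', 'datacompdr') and ('mobileclip_s2', 'datacompdr') pairs:

import open_clip

print([pair for pair in open_clip.list_pretrained() if pair[0].startswith('mobileclip')])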
