SigLIP impl (mlfoundations#634)
* Initial SigLIP impl

* Add logit_bias to custom text clip

* Fix non-dict model output ordering w.r.t. logit_bias

* Disable dividing loss by world size; works better without it

* A bit of cleanup

* Add bidirectional exchange option, more cleanup

* Add reference in siglip docstring

* Remove some comments after further verification

* Bidir exchange by default

* Proper bidir default
rwightman authored and Interpause committed May 23, 2024
1 parent ba16fc6 commit a39f3a2
Showing 6 changed files with 272 additions and 15 deletions.
19 changes: 15 additions & 4 deletions src/open_clip/factory.py
@@ -13,7 +13,7 @@
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf
@@ -128,6 +128,7 @@ def create_model(
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
**model_kwargs,
):
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
@@ -193,11 +194,11 @@
if is_hf_model:
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
if "coca" in model_name:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
model = CoCa(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
model = CustomTextCLIP(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
model = CLIP(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)

if precision in ("fp16", "bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
@@ -285,6 +286,12 @@ def create_loss(args):
world_size=args.world_size,
use_horovod=args.horovod,
)
elif args.siglip:
assert not args.horovod, "Horovod not currently supported for SigLip"
return SigLipLoss(
rank=args.rank,
world_size=args.world_size,
)
return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
@@ -312,6 +319,7 @@ def create_model_and_transforms(
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
**model_kwargs,
):
model = create_model(
model_name,
@@ -327,6 +335,7 @@
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
**model_kwargs,
)

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
@@ -361,6 +370,7 @@ def create_model_from_pretrained(
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
**model_kwargs,
):
model = create_model(
model_name,
@@ -373,6 +383,7 @@
force_image_size=force_image_size,
cache_dir=cache_dir,
require_pretrained=True,
**model_kwargs,
)

if not return_transform:
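
A minimal sketch of how the factory changes above are intended to be used; the model name is a placeholder, and the init values mirror those set in training/main.py further down:

import numpy as np
from open_clip import create_model_and_transforms

# Extra keyword args are forwarded through **model_kwargs into the
# CLIP / CustomTextCLIP / CoCa constructors.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    'ViT-B-32',                    # placeholder model name
    pretrained=None,
    init_logit_scale=np.log(10),   # SigLIP-style initialisation
    init_logit_bias=-10,
    output_dict=True,
)

# On the loss side, create_loss(args) now returns SigLipLoss(rank, world_size)
# when args.siglip is set and Horovod is not in use (the --siglip flag itself
# is presumably defined in the training arg parser, not shown in this view);
# the other loss branches are unchanged.
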
198 changes: 198 additions & 0 deletions src/open_clip/loss.py
@@ -214,3 +214,201 @@ def forward(
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}

return contrastive_loss, distill_loss


def neighbour_exchange(from_rank, to_rank, tensor, group=None):
tensor_recv = torch.zeros_like(tensor)
send_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor,
to_rank,
group=group,
)
recv_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv,
from_rank,
group=group,
)
reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
for req in reqs:
req.wait()
return tensor_recv


def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
tensor_from_left = torch.zeros_like(tensor_to_right)
tensor_from_right = torch.zeros_like(tensor_to_left)
send_op_left = torch.distributed.P2POp(
torch.distributed.isend,
tensor_to_left,
left_rank,
group=group,
)
send_op_right = torch.distributed.P2POp(
torch.distributed.isend,
tensor_to_right,
right_rank,
group=group,
)
recv_op_left = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_from_left,
left_rank,
group=group,
)
recv_op_right = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_from_right,
right_rank,
group=group,
)
reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left])
for req in reqs:
req.wait()
return tensor_from_right, tensor_from_left


class NeighbourExchange(torch.autograd.Function):
@staticmethod
def forward(ctx, from_rank, to_rank, group, tensor):
ctx.group = group
ctx.from_rank = from_rank
ctx.to_rank = to_rank
return neighbour_exchange(from_rank, to_rank, tensor, group=group)

@staticmethod
def backward(ctx, grad_output):
return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)


def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
return NeighbourExchange.apply(from_rank, to_rank, group, tensor)


class NeighbourExchangeBidir(torch.autograd.Function):
@staticmethod
def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
ctx.group = group
ctx.left_rank = left_rank
ctx.right_rank = right_rank
return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)

@staticmethod
def backward(ctx, *grad_outputs):
return (None, None, None) + \
NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)


def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
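
# Note: the helpers above implement a ring exchange for the distributed
# SigLIP loss below. With N ranks, each rank first computes the sigmoid loss
# against its own text features (positives + negatives), then text features
# are passed around the ring N - 1 times so every rank also accumulates a
# negative-only loss against the text features of every other rank. The
# bidirectional variant sends to both neighbours at once, roughly halving
# the number of exchange steps.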


class SigLipLoss(nn.Module):
""" Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
@article{zhai2023sigmoid,
title={Sigmoid loss for language image pre-training},
author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
journal={arXiv preprint arXiv:2303.15343},
year={2023}
}
"""
def __init__(
self,
cache_labels=False,
rank=0,
world_size=1,
bidir=True,
use_horovod=False,
):
super().__init__()
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
assert not use_horovod # FIXME need to look at hvd ops for ring transfers
self.use_horovod = use_horovod
self.bidir = bidir

# cache state FIXME cache not currently used, worthwhile?
self.prev_num_logits = 0
self.labels = {}

def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
if not negative_only:
labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
return labels

def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
logits = logit_scale * image_features @ text_features.T
if logit_bias is not None:
logits += logit_bias
return logits

def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
labels = self.get_ground_truth(
image_features.device,
image_features.dtype,
image_features.shape[0],
negative_only=negative_only,
)
loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
return loss

def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
loss = self._loss(image_features, text_features, logit_scale, logit_bias)

if self.world_size > 1:
# exchange text features w/ neighbour world_size - 1 times
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
if self.bidir:
text_features_to_right = text_features_to_left = text_features
num_bidir, remainder = divmod(self.world_size - 1, 2)
for i in range(num_bidir):
text_features_recv = neighbour_exchange_bidir_with_grad(
left_rank,
right_rank,
text_features_to_left,
text_features_to_right,
)

for f in text_features_recv:
loss += self._loss(
image_features,
f,
logit_scale,
logit_bias,
negative_only=True,
)
text_features_to_left, text_features_to_right = text_features_recv

if remainder:
text_features_recv = neighbour_exchange_with_grad(
left_rank, right_rank, text_features_to_right)

loss += self._loss(
image_features,
text_features_recv,
logit_scale,
logit_bias,
negative_only=True,
)
else:
text_features_to_right = text_features
for i in range(self.world_size - 1):
text_features_from_left = neighbour_exchange_with_grad(
left_rank, right_rank, text_features_to_right)

loss += self._loss(
image_features,
text_features_from_left,
logit_scale,
logit_bias,
negative_only=True,
)
text_features_to_right = text_features_from_left

return {"contrastive_loss": loss} if output_dict else loss
35 changes: 30 additions & 5 deletions src/open_clip/model.py
@@ -193,6 +193,8 @@ def __init__(
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
@@ -210,7 +212,11 @@
self.text_projection = text.text_projection
self.register_buffer('attn_mask', text.attn_mask, persistent=False)

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
else:
self.logit_bias = None

def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
@@ -249,12 +255,19 @@ def forward(
):
image_features = self.encode_image(image, normalize=True) if image is not None else None
text_features = self.encode_text(text, normalize=True) if text is not None else None

if self.output_dict:
return {
out_dict = {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
if self.logit_bias is not None:
out_dict['logit_bias'] = self.logit_bias
return out_dict

if self.logit_bias is not None:
return image_features, text_features, self.logit_scale.exp(), self.logit_bias
return image_features, text_features, self.logit_scale.exp()


@@ -267,6 +280,8 @@ def __init__(
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
@@ -276,7 +291,11 @@
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.context_length = self.text.context_length
self.vocab_size = self.text.vocab_size
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
else:
self.logit_bias = None

def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
@@ -305,12 +324,19 @@ def forward(
):
image_features = self.encode_image(image, normalize=True) if image is not None else None
text_features = self.encode_text(text, normalize=True) if text is not None else None

if self.output_dict:
return {
out_dict = {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
if self.logit_bias is not None:
out_dict['logit_bias'] = self.logit_bias
return out_dict

if self.logit_bias is not None:
return image_features, text_features, self.logit_scale.exp(), self.logit_bias
return image_features, text_features, self.logit_scale.exp()


@@ -420,7 +446,6 @@ def build_model_from_openai_state_dict(

for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)

convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
model.load_state_dict(state_dict)
return model.eval()
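
A rough sketch of the resulting forward() behaviour when init_logit_bias is set (model name and inputs are placeholders): with output_dict=True the extra bias shows up as a 'logit_bias' entry, otherwise as a fourth element of the returned tuple.

import numpy as np
import torch
from open_clip import create_model, get_tokenizer

model = create_model(
    'ViT-B-32',                    # placeholder model name
    init_logit_scale=np.log(10),
    init_logit_bias=-10,
    output_dict=True,
)
tokenizer = get_tokenizer('ViT-B-32')

images = torch.randn(2, 3, 224, 224)
texts = tokenizer(["a photo of a cat", "a photo of a dog"])

with torch.no_grad():
    out = model(image=images, text=texts)

# out["logit_bias"] is present alongside image_features, text_features and
# logit_scale, and can be passed straight to SigLipLoss.
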
5 changes: 5 additions & 0 deletions src/training/main.py
@@ -215,6 +215,10 @@ def main(args):
# arg is nargs, single (square) image size list -> int
args.force_image_size = args.force_image_size[0]
random_seed(args.seed, 0)
model_kwargs = {}
if args.siglip:
model_kwargs['init_logit_scale'] = np.log(10) # different from CLIP
model_kwargs['init_logit_bias'] = -10
model, preprocess_train, preprocess_val = create_model_and_transforms(
args.model,
args.pretrained,
@@ -230,6 +234,7 @@
image_std=args.image_std,
aug_cfg=args.aug_cfg,
output_dict=True,
**model_kwargs,
)
if args.distill:
# FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms.
(Diffs for the remaining two changed files did not load in this view.)
