From a39f3a27e1747a49ffadb8a07c1460b866bc2f9f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 22 Sep 2023 12:17:11 -0700
Subject: [PATCH] SigLIP impl (#634)

* Initial SigLIP impl

* Add logit_bias to custom text clip

* non-dict model output wrong way around wrt logit_bias

* Disable diving loss by world size, better without

* A bit of cleanup

* Add bidirectional exchange option, more cleanup

* Add reference in siglip docstring

* Remove some comments after further verification

* bidir exchange by default

* Proper bidir default
---
 src/open_clip/factory.py |  19 +++-
 src/open_clip/loss.py    | 198 +++++++++++++++++++++++++++++++++++++++
 src/open_clip/model.py   |  35 ++++++-
 src/training/main.py     |   5 +
 src/training/params.py   |   6 ++
 src/training/train.py    |  24 +++--
 6 files changed, 272 insertions(+), 15 deletions(-)

diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index e4f0b7632..ebea3a160 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -13,7 +13,7 @@
 from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
     resize_pos_embed, get_cast_dtype
 from .coca_model import CoCa
-from .loss import ClipLoss, DistillClipLoss, CoCaLoss
+from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
 from .openai import load_openai_model
 from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
     list_pretrained_tags_by_model, download_pretrained_from_hf
@@ -128,6 +128,7 @@ def create_model(
         cache_dir: Optional[str] = None,
         output_dict: Optional[bool] = None,
         require_pretrained: bool = False,
+        **model_kwargs,
 ):
     has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
     if has_hf_hub_prefix:
@@ -193,11 +194,11 @@ def create_model(
             if is_hf_model:
                 model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
             if "coca" in model_name:
-                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
+                model = CoCa(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
             else:
-                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
+                model = CustomTextCLIP(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
         else:
-            model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+            model = CLIP(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
 
         if precision in ("fp16", "bf16"):
             dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
@@ -285,6 +286,12 @@ def create_loss(args):
             world_size=args.world_size,
             use_horovod=args.horovod,
         )
+    elif args.siglip:
+        assert not args.horovod, "Horovod not currently supported for SigLip"
+        return SigLipLoss(
+            rank=args.rank,
+            world_size=args.world_size,
+        )
     return ClipLoss(
         local_loss=args.local_loss,
         gather_with_grad=args.gather_with_grad,
@@ -312,6 +319,7 @@ def create_model_and_transforms(
         aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
         cache_dir: Optional[str] = None,
         output_dict: Optional[bool] = None,
+        **model_kwargs,
 ):
     model = create_model(
         model_name,
@@ -327,6 +335,7 @@ def create_model_and_transforms(
         pretrained_hf=pretrained_hf,
         cache_dir=cache_dir,
         output_dict=output_dict,
+        **model_kwargs,
     )
 
     image_mean = image_mean or getattr(model.visual, 'image_mean', None)
@@ -361,6 +370,7 @@ def create_model_from_pretrained(
         image_mean: Optional[Tuple[float, ...]] = None,
         image_std: Optional[Tuple[float, ...]] = None,
         cache_dir: Optional[str] = None,
+        **model_kwargs,
 ):
     model = create_model(
         model_name,
@@ -373,6 +383,7 @@ def create_model_from_pretrained(
         force_image_size=force_image_size,
         cache_dir=cache_dir,
         require_pretrained=True,
+        **model_kwargs,
     )
 
     if not return_transform:
diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py
index 0dd048935..5beaab1c3 100644
--- a/src/open_clip/loss.py
+++ b/src/open_clip/loss.py
@@ -214,3 +214,201 @@ def forward(
             return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
 
         return contrastive_loss, distill_loss
+
+
+def neighbour_exchange(from_rank, to_rank, tensor, group=None):
+    tensor_recv = torch.zeros_like(tensor)
+    send_op = torch.distributed.P2POp(
+        torch.distributed.isend,
+        tensor,
+        to_rank,
+        group=group,
+    )
+    recv_op = torch.distributed.P2POp(
+        torch.distributed.irecv,
+        tensor_recv,
+        from_rank,
+        group=group,
+    )
+    reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
+    for req in reqs:
+        req.wait()
+    return tensor_recv
+
+
+def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
+    tensor_from_left = torch.zeros_like(tensor_to_right)
+    tensor_from_right = torch.zeros_like(tensor_to_left)
+    send_op_left = torch.distributed.P2POp(
+        torch.distributed.isend,
+        tensor_to_left,
+        left_rank,
+        group=group,
+    )
+    send_op_right = torch.distributed.P2POp(
+        torch.distributed.isend,
+        tensor_to_right,
+        right_rank,
+        group=group,
+    )
+    recv_op_left = torch.distributed.P2POp(
+        torch.distributed.irecv,
+        tensor_from_left,
+        left_rank,
+        group=group,
+    )
+    recv_op_right = torch.distributed.P2POp(
+        torch.distributed.irecv,
+        tensor_from_right,
+        right_rank,
+        group=group,
+    )
+    reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left])
+    for req in reqs:
+        req.wait()
+    return tensor_from_right, tensor_from_left
+
+
+class NeighbourExchange(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, from_rank, to_rank, group, tensor):
+        ctx.group = group
+        ctx.from_rank = from_rank
+        ctx.to_rank = to_rank
+        return neighbour_exchange(from_rank, to_rank, tensor, group=group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)
+
+
+def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
+    return NeighbourExchange.apply(from_rank, to_rank, group, tensor)
+
+
+class NeighbourExchangeBidir(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
+        ctx.group = group
+        ctx.left_rank = left_rank
+        ctx.right_rank = right_rank
+        return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)
+
+    @staticmethod
+    def backward(ctx, *grad_outputs):
+        return (None, None, None) + \
+            NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)
+
+
+def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
+    return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
+
+
+class SigLipLoss(nn.Module):
+    """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
+
+    @article{zhai2023sigmoid,
+      title={Sigmoid loss for language image pre-training},
+      author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
+      journal={arXiv preprint arXiv:2303.15343},
+      year={2023}
+    }
+    """
+    def __init__(
+            self,
+            cache_labels=False,
+            rank=0,
+            world_size=1,
+            bidir=True,
+            use_horovod=False,
+    ):
+        super().__init__()
+        self.cache_labels = cache_labels
+        self.rank = rank
+        self.world_size = world_size
+        assert not use_horovod  # FIXME need to look at hvd ops for ring transfers
+        self.use_horovod = use_horovod
+        self.bidir = bidir
+
+        # cache state FIXME cache not currently used, worthwhile?
+        self.prev_num_logits = 0
+        self.labels = {}
+
+    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
+        labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
+        if not negative_only:
+            labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
+        return labels
+
+    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
+        logits = logit_scale * image_features @ text_features.T
+        if logit_bias is not None:
+            logits += logit_bias
+        return logits
+
+    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
+        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
+        labels = self.get_ground_truth(
+            image_features.device,
+            image_features.dtype,
+            image_features.shape[0],
+            negative_only=negative_only,
+        )
+        loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
+        return loss
+
+    def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
+        loss = self._loss(image_features, text_features, logit_scale, logit_bias)
+
+        if self.world_size > 1:
+            # exchange text features w/ neighbour world_size - 1 times
+            right_rank = (self.rank + 1) % self.world_size
+            left_rank = (self.rank - 1 + self.world_size) % self.world_size
+            if self.bidir:
+                text_features_to_right = text_features_to_left = text_features
+                num_bidir, remainder = divmod(self.world_size - 1, 2)
+                for i in range(num_bidir):
+                    text_features_recv = neighbour_exchange_bidir_with_grad(
+                        left_rank,
+                        right_rank,
+                        text_features_to_left,
+                        text_features_to_right,
+                    )
+
+                    for f in text_features_recv:
+                        loss += self._loss(
+                            image_features,
+                            f,
+                            logit_scale,
+                            logit_bias,
+                            negative_only=True,
+                        )
+                    text_features_to_left, text_features_to_right = text_features_recv
+
+                if remainder:
+                    text_features_recv = neighbour_exchange_with_grad(
+                        left_rank, right_rank, text_features_to_right)
+
+                    loss += self._loss(
+                        image_features,
+                        text_features_recv,
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+            else:
+                text_features_to_right = text_features
+                for i in range(self.world_size - 1):
+                    text_features_from_left = neighbour_exchange_with_grad(
+                        left_rank, right_rank, text_features_to_right)
+
+                    loss += self._loss(
+                        image_features,
+                        text_features_from_left,
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+                    text_features_to_right = text_features_from_left
+
+        return {"contrastive_loss": loss} if output_dict else loss
diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index 17fefc669..b812f4ad8 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -193,6 +193,8 @@ def __init__(
             vision_cfg: CLIPVisionCfg,
             text_cfg: CLIPTextCfg,
             quick_gelu: bool = False,
+            init_logit_scale: float = np.log(1 / 0.07),
+            init_logit_bias: Optional[float] = None,
             cast_dtype: Optional[torch.dtype] = None,
             output_dict: bool = False,
     ):
@@ -210,7 +212,11 @@ def __init__(
         self.text_projection = text.text_projection
         self.register_buffer('attn_mask', text.attn_mask, persistent=False)
 
-        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+        if init_logit_bias is not None:
+            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
+        else:
+            self.logit_bias = None
 
     def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
         # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
@@ -249,12 +255,19 @@ def forward(
     ):
         image_features = self.encode_image(image, normalize=True) if image is not None else None
         text_features = self.encode_text(text, normalize=True) if text is not None else None
+
         if self.output_dict:
-            return {
+            out_dict = {
                 "image_features": image_features,
                 "text_features": text_features,
                 "logit_scale": self.logit_scale.exp()
             }
+            if self.logit_bias is not None:
+                out_dict['logit_bias'] = self.logit_bias
+            return out_dict
+
+        if self.logit_bias is not None:
+            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
         return image_features, text_features, self.logit_scale.exp()
 
 
@@ -267,6 +280,8 @@ def __init__(
             vision_cfg: CLIPVisionCfg,
             text_cfg: CLIPTextCfg,
             quick_gelu: bool = False,
+            init_logit_scale: float = np.log(1 / 0.07),
+            init_logit_bias: Optional[float] = None,
             cast_dtype: Optional[torch.dtype] = None,
             output_dict: bool = False,
     ):
@@ -276,7 +291,11 @@ def __init__(
         self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
         self.context_length = self.text.context_length
         self.vocab_size = self.text.vocab_size
-        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+        if init_logit_bias is not None:
+            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
+        else:
+            self.logit_bias = None
 
     def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
         # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
@@ -305,12 +324,19 @@ def forward(
     ):
         image_features = self.encode_image(image, normalize=True) if image is not None else None
         text_features = self.encode_text(text, normalize=True) if text is not None else None
+
         if self.output_dict:
-            return {
+            out_dict = {
                 "image_features": image_features,
                 "text_features": text_features,
                 "logit_scale": self.logit_scale.exp()
             }
+            if self.logit_bias is not None:
+                out_dict['logit_bias'] = self.logit_bias
+            return out_dict
+
+        if self.logit_bias is not None:
+            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
         return image_features, text_features, self.logit_scale.exp()
 
 
@@ -420,7 +446,6 @@ def build_model_from_openai_state_dict(
     for key in ["input_resolution", "context_length", "vocab_size"]:
         state_dict.pop(key, None)
 
-    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
     model.load_state_dict(state_dict)
     return model.eval()
 
diff --git a/src/training/main.py b/src/training/main.py
index d7e586ecf..08d2412e2 100644
--- a/src/training/main.py
+++ b/src/training/main.py
@@ -215,6 +215,10 @@ def main(args):
         # arg is nargs, single (square) image size list -> int
         args.force_image_size = args.force_image_size[0]
     random_seed(args.seed, 0)
+    model_kwargs = {}
+    if args.siglip:
+        model_kwargs['init_logit_scale'] = np.log(10)  # different from CLIP
+        model_kwargs['init_logit_bias'] = -10
     model, preprocess_train, preprocess_val = create_model_and_transforms(
         args.model,
         args.pretrained,
@@ -230,6 +234,7 @@ def main(args):
         image_std=args.image_std,
         aug_cfg=args.aug_cfg,
         output_dict=True,
+        **model_kwargs,
     )
     if args.distill:
         # FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms.
diff --git a/src/training/params.py b/src/training/params.py
index 18ed9364d..345382e57 100644
--- a/src/training/params.py
+++ b/src/training/params.py
@@ -436,6 +436,12 @@ def parse_args(args):
         help='Replace the network linear layers from the bitsandbytes library. '
              'Allows int8 training/inference, etc.'
     )
+    parser.add_argument(
+        "--siglip",
+        default=False,
+        action="store_true",
+        help='Use SigLip (sigmoid) loss.'
+    )
     args = parser.parse_args(args)
 
     # If some params are not passed, we use the default values based on model name.
diff --git a/src/training/train.py b/src/training/train.py
index e93d9d370..b31d424b4 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -38,6 +38,7 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count
 
+
 def postprocess_clip_output(model_out):
     return {
         "image_features": model_out[0],
@@ -45,6 +46,7 @@ def postprocess_clip_output(model_out):
         "logit_scale": model_out[2]
     }
 
+
 def unwrap_model(model):
     if hasattr(model, 'module'):
         return model.module
@@ -64,7 +66,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
     autocast = get_autocast(args.precision)
     input_dtype = get_input_dtype(args.precision)
 
-
     model.train()
     if args.distill:
         dist_model.eval()
@@ -102,7 +103,7 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
                 if args.distill:
                     with torch.no_grad():
                         dist_model_out = dist_model(images, texts)
-                    model_out.update({f'dist_{k}' : v for k, v in dist_model_out.items()})
+                    model_out.update({f'dist_{k}': v for k, v in dist_model_out.items()})
                 losses = loss(**model_out, output_dict=True)
 
                 total_loss = sum(losses.values())
@@ -114,7 +115,10 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
             with torch.no_grad():
                 with autocast():
                     model_out = model(images, texts)
-                    model_out.pop("logit_scale")
+
+                    for f in ("logit_scale", "logit_bias"):
+                        model_out.pop(f, None)
+
                     for key, val in model_out.items():
                         if key in accum_features:
                             accum_features[key].append(val)
@@ -138,15 +142,23 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
                 texts = accum_texts[j]
                 with autocast():
                     model_out = model(images, texts)
-                    logit_scale = model_out.pop("logit_scale")
+
+                    inputs_no_accum = {}
+                    inputs_no_accum["logit_scale"] = logit_scale = model_out.pop("logit_scale")
+                    if "logit_bias" in model_out:
+                        inputs_no_accum["logit_bias"] = model_out.pop("logit_bias")
+
                     inputs = {}
                     for key, val in accum_features.items():
                         accumulated = accum_features[key]
-                        inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:])
-                    losses = loss(**inputs, logit_scale=logit_scale, output_dict=True)
+                        inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:])
+
+                    losses = loss(**inputs, **inputs_no_accum, output_dict=True)
                     del inputs
+                    del inputs_no_accum
                     total_loss = sum(losses.values())
                     losses["loss"] = total_loss
+
                 backward(total_loss, scaler)
                 if scaler is not None:
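Usage sketch (not part of the patch): a minimal single-process example of exercising the added SigLipLoss, assuming the patch is applied and open_clip is importable. The random tensors are stand-ins for model-produced embeddings, and the logit_scale/logit_bias values mirror the SigLIP init used in src/training/main.py (logit_scale starts at exp(log(10)) = 10, logit_bias at -10).

import torch
import torch.nn.functional as F

from open_clip.loss import SigLipLoss

# Random unit-norm features standing in for encoder outputs.
batch_size, embed_dim = 8, 512
image_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
text_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)

# Values matching the --siglip init in src/training/main.py.
logit_scale = torch.tensor(10.0)
logit_bias = torch.tensor(-10.0)

# world_size > 1 would additionally trigger the neighbour-exchange path.
loss_fn = SigLipLoss(rank=0, world_size=1)
loss = loss_fn(image_features, text_features, logit_scale, logit_bias)
print(loss)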