From 6f0622d2e385f77776ee00f6fe573987afea7bd1 Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 17 Mar 2023 07:51:07 +0000 Subject: [PATCH 1/2] CoCa: Condition captioning loss on the CLIP similarity --- src/open_clip/loss.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 4fbf61dac..27fede0ce 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -155,7 +155,7 @@ def __init__( self.clip_loss_weight = clip_loss_weight self.caption_loss_weight = caption_loss_weight - self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + self.caption_loss = nn.CrossEntropyLoss(reduction='none', ignore_index=pad_id) def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): clip_loss = super().forward(image_features, text_features, logit_scale) @@ -165,6 +165,27 @@ def forward(self, image_features, text_features, logits, labels, logit_scale, ou logits.permute(0, 2, 1), labels, ) + + # IDEAS: + # - do we let gradients prop backward from this? maybe a torch.no_grad is due here + # - normalization, + # - we want the distribution to be soft so maybe logit scale? !!! 
+ # - maybe just softmax is fine since we expect it to be uniform + # p(good_sample) >> p(bad_sample) + # - TODO: right now posterior = evidence, we need to figure out a smarter + # way of updating our prior which is p(gs) = 1 based on the evidence (CLIP dist) + with torch.no_grad(): + cap_weights = (logit_scale * image_features @ text_features.T).softmax(dim=1).diag().unsqueeze(1) + # adjustment = (cap_weights.shape[0] + 1 - cap_weights.sum()) # in the beginning sim ~ U(bs) + # cap_weights = cap_weights * adjustment + + # caption_loss = caption_loss * cap_weights.unsqueeze(1) + def custom_backward_hook(grad): + return cap_weights * grad + caption_loss.register_hook(custom_backward_hook) + + caption_loss = torch.mean(caption_loss[caption_loss != 0.0]) + caption_loss = caption_loss * self.caption_loss_weight if output_dict: From 45066364f41e02a54fe55e7af87c1b6f36a4dd02 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 26 Mar 2023 08:03:55 +0000 Subject: [PATCH 2/2] update --- src/open_clip/loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 27fede0ce..60beaabee 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -179,7 +179,6 @@ def forward(self, image_features, text_features, logits, labels, logit_scale, ou # adjustment = (cap_weights.shape[0] + 1 - cap_weights.sum()) # in the beginning sim ~ U(bs) # cap_weights = cap_weights * adjustment - # caption_loss = caption_loss * cap_weights.unsqueeze(1) def custom_backward_hook(grad): return cap_weights * grad caption_loss.register_hook(custom_backward_hook)