```python
if 'enc_outputs' in outputs:
    enc_outputs = outputs['enc_outputs']
    bin_targets = copy.deepcopy(targets)
    for bt in bin_targets:
        bt['labels'] = torch.zeros_like(bt['labels'])
    indices = self.matcher(enc_outputs, bin_targets)
    for loss in self.losses:
        if loss == 'masks':
            # Intermediate masks losses are too costly to compute, we ignore them.
            continue
        kwargs = {}
        if loss == 'labels':
            # Logging is enabled only for the last layer
            kwargs['log'] = False
        l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, indicators, **kwargs)
        l_dict = {k + f'_enc': v for k, v in l_dict.items()}
        losses.update(l_dict)
```
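This snippet works on a deep copy of the targets, so zeroing the labels for the encoder loss leaves the original `targets` untouched. A minimal plain-Python sketch of just that step, assuming each target is a dict with a `labels` sequence (Python lists stand in for the tensors, and `binarize_targets` is a hypothetical helper name, not a function in either codebase):

```python
import copy


def binarize_targets(targets):
    """Hypothetical helper: deep-copy the targets and map every label to class 0."""
    bin_targets = copy.deepcopy(targets)  # the caller's targets are not modified
    for bt in bin_targets:
        # Plain-list stand-in for torch.zeros_like(bt["labels"])
        bt["labels"] = [0] * len(bt["labels"])
    return bin_targets


targets = [{"labels": [3, 7], "boxes": [[0.5, 0.5, 0.2, 0.2], [0.1, 0.1, 0.1, 0.1]]}]
bin_targets = binarize_targets(targets)
print(bin_targets[0]["labels"])  # [0, 0]
print(targets[0]["labels"])      # [3, 7]
```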
```python
if "enc_outputs" in outputs:
    # {"pred_logits": [B, two_stage_num_proposals, C], "pred_boxes": [B, two_stage_num_proposals, 4]}
    enc_outputs = outputs["enc_outputs"]
    if self.two_stage_binary_cls:
        for bt in targets:
            bt["labels"] = torch.zeros_like(bt["labels"])
    indices = self.matcher(enc_outputs, targets)
    for loss in self.losses:
        l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes)
        l_dict = {k + "_enc": v for k, v in l_dict.items()}
        losses.update(l_dict)
```
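Unlike the deep-copy version, this snippet loops over `targets` directly: when `self.two_stage_binary_cls` is True the labels are overwritten in the caller's own target dicts, and when it is False the encoder proposals are matched and classified against the real class labels. A plain-Python sketch of only that branch (`encoder_targets` is a hypothetical name; lists stand in for tensors):

```python
def encoder_targets(targets, two_stage_binary_cls):
    """Hypothetical helper mirroring the branch above: optionally zero labels in place."""
    if two_stage_binary_cls:
        for bt in targets:
            # Same effect as bt["labels"] = torch.zeros_like(bt["labels"]),
            # applied to the caller's dicts because no copy is made
            bt["labels"] = [0] * len(bt["labels"])
    return targets


multi = [{"labels": [3, 7]}]
print(encoder_targets(multi, two_stage_binary_cls=False)[0]["labels"])  # [3, 7]

binary = [{"labels": [3, 7]}]
encoder_targets(binary, two_stage_binary_cls=True)
print(binary[0]["labels"])  # [0, 0]: the original list of dicts was modified
```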
Hi, I looked at the two-stage code of Deformable-DETR and DINO, and the two compute the proposal scores in different ways. Which one works better?
Deformable-DETR selects the proposals like this:
```python
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
```
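The loss on its encoder outputs is the first snippet above, which copies the targets into `bin_targets` and sets every label to 0.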
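Deformable-DETR's channel-0 `topk` score lines up with its binarized encoder loss. Assuming a sigmoid focal loss over per-class logits, as in Deformable-DETR, each matched proposal gets a one-hot target at its label index; with every label set to 0, only channel 0 ever receives a positive target, so it effectively becomes a class-agnostic foreground score. A small sketch of that target construction (`sigmoid_onehot_targets` is hypothetical, and using `-1` to mark an unmatched proposal is a simplification):

```python
def sigmoid_onehot_targets(matched_labels, num_classes):
    """Build per-proposal sigmoid targets; a label of -1 means unmatched (background)."""
    rows = []
    for label in matched_labels:
        row = [0.0] * num_classes
        if label >= 0:
            row[label] = 1.0
        rows.append(row)
    return rows


# Binarized labels (first snippet): positives only ever land in channel 0.
print(sigmoid_onehot_targets([0, -1, 0], num_classes=3))
# [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]

# Real labels (second snippet with two_stage_binary_cls False): positives spread over classes.
print(sigmoid_onehot_targets([2, -1, 1], num_classes=3))
# [[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```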
DINO selects the proposals like this:
```python
topk_proposals = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]
```
The loss on its encoder outputs is the second snippet above, where `self.two_stage_binary_cls` is False.
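To make the difference between the two selection rules concrete: Deformable-DETR ranks proposals by the logit in channel 0, while DINO ranks them by the largest logit over all classes. The sketch below mimics `torch.topk(..., dim=1)[1]` for a single image using plain lists; the helper names are hypothetical and the logits are made-up numbers for illustration, not outputs of either model.

```python
def topk_indices(scores, k):
    """Indices of the k largest scores (single-image stand-in for torch.topk(...)[1])."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def first_class_scores(logits):
    # Deformable-DETR rule: enc_outputs_class[..., 0]
    return [row[0] for row in logits]


def max_class_scores(logits):
    # DINO rule: enc_outputs_class.max(-1)[0]
    return [max(row) for row in logits]


# Made-up logits: 4 proposals x 3 classes.
logits = [
    [2.0, -1.0, -1.0],
    [-3.0, 4.0, -1.0],
    [0.5, 0.0, 0.0],
    [-2.0, -2.0, 1.0],
]
print(topk_indices(first_class_scores(logits), k=2))  # [0, 2]
print(topk_indices(max_class_scores(logits), k=2))    # [1, 0]
```

The channel-0 rule is only meaningful when channel 0 was trained as a foreground score; with multi-class encoder targets it would discard proposal 1 here, even though that proposal has the strongest evidence for class 1.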