should gen id together in all class for multi-class dataset? #96

Open
upsx opened this issue Oct 14, 2021 · 4 comments

@upsx
upsx commented Oct 14, 2021

I have looked at the MotLoss function; the re-ID loss is implemented with a classifier. If IDs are generated separately for each class in a multi-class dataset, objects of different classes can end up with the same ID, and the classifier then treats two objects of different classes with the same ID as the same identity. I therefore think this is an error. @even

class MotLoss(torch.nn.Module):                                                         # loss module
    def __init__(self, opt):
        super(MotLoss, self).__init__()
        self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()                 # heatmap loss: FocalLoss
        self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
            RegLoss() if opt.reg_loss == 'sl1' else None                                # reg loss: RegL1Loss
        self.crit_wh = torch.nn.L1Loss(reduction='sum') if opt.dense_wh else \
            NormRegL1Loss() if opt.norm_wh else \
                RegWeightedL1Loss() if opt.cat_spec_wh else self.crit_reg               # wh loss: RegL1Loss
        self.opt = opt
        self.emb_dim = opt.reid_dim                                                     # 64 or 128
        self.nID = opt.nID                                                              # number of IDs in the dataset
        self.classifier = nn.Linear(self.emb_dim, self.nID)                             # re-ID classifier: emb_dim -> nID
        if opt.id_loss == 'focal':                                                      # ce: False
            torch.nn.init.normal_(self.classifier.weight, std=0.01)
            prior_prob = 0.01
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            torch.nn.init.constant_(self.classifier.bias, bias_value)
        self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)                              # re-ID loss: CrossEntropyLoss, ignoring samples whose id is -1
        self.emb_scale = math.sqrt(2) * math.log(self.nID - 1)                          # ??
        self.s_det = nn.Parameter(-1.85 * torch.ones(1))                                # detect loss weight
        self.s_id = nn.Parameter(-1.05 * torch.ones(1))                                 # re-ID loss weight

    def forward(self, outputs, batch):                                                  # outputs: outputs of the model heads; batch: {'input': imgs, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh, 'reg': reg, 'ids': ids, 'bbox': bbox_xys}
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0, 0, 0, 0
        for s in range(opt.num_stacks):                                                 # opt.num_stacks = 1
            output = outputs[s]                                                         # output: outputs of the model heads, {'hm':hm, 'wh':wh, 'reg':reg, 'id':id}
            if not opt.mse_loss:                                                        # True
                output['hm'] = _sigmoid(output['hm'])                                   # apply sigmoid to the heatmap

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks            # hm loss
            if opt.wh_weight > 0:                                                       # wh loss: 0.1
                wh_loss += self.crit_reg(
                    output['wh'], batch['reg_mask'],
                    batch['ind'], batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:                                   # reg loss: True and 1
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'], batch['reg']) / opt.num_stacks

            if opt.id_weight > 0:                                                       # re-id loss: 1
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[batch['reg_mask'] > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][batch['reg_mask'] > 0]

                id_output = self.classifier(id_head).contiguous()
                if self.opt.id_loss == 'focal':
                    id_target_one_hot = id_output.new_zeros((id_head.size(0), self.nID)).scatter_(1,
                                                                                                  id_target.long().view(
                                                                                                      -1, 1), 1)
                    id_loss += sigmoid_focal_loss_jit(id_output, id_target_one_hot,
                                                      alpha=0.25, gamma=2.0, reduction="sum"
                                                      ) / id_output.size(0)
                else:
                    id_loss += self.IDLoss(id_output, id_target)                        # True

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss    # detect loss = heatmap loss + wh loss + reg loss (why not make the 4 head weights learnable parameters as well?)
        if opt.multi_loss == 'uncertainty':                                                         # True: loss = det_loss + re-id_loss
            loss = torch.exp(-self.s_det) * det_loss + torch.exp(-self.s_id) * id_loss + (self.s_det + self.s_id)   # learnable weights
            loss *= 0.5
        else:
            loss = det_loss + 0.1 * id_loss                                                                         # fixed weights

        loss_stats = {'loss': loss, 'hm_loss': hm_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss, 'id_loss': id_loss}                                 # loss: final total loss, loss_stats: 4 heads loss
        return loss, loss_stats                                                                                     
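
For reference, the `uncertainty` branch of `forward` above can be written out as a formula. The symbols are notation chosen here and do not come from the repository: the w terms stand for `opt.hm_weight`, `opt.wh_weight` and `opt.off_weight`, and the s terms stand for the learnable `self.s_det` and `self.s_id`.

```latex
\begin{aligned}
L_{\mathrm{det}} &= w_{\mathrm{hm}} L_{\mathrm{hm}} + w_{\mathrm{wh}} L_{\mathrm{wh}} + w_{\mathrm{off}} L_{\mathrm{off}} \\
L &= \tfrac{1}{2}\left( e^{-s_{\mathrm{det}}} L_{\mathrm{det}} + e^{-s_{\mathrm{id}}} L_{\mathrm{id}} + s_{\mathrm{det}} + s_{\mathrm{id}} \right)
\end{aligned}
```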

@CaptainEven
Owner

@upsx You have misunderstood. We use multiple classifiers rather than a single classifier, so different classes use different IDs.
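
A minimal sketch of what the multiple-classifier idea could look like, assuming one linear classifier per object class. `PerClassIDHead`, `nid_per_class`, `cls_ids` and `track_ids` are hypothetical names used only for illustration; this is not the repository's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerClassIDHead(nn.Module):
    """Hypothetical sketch: one re-ID classifier per object class."""

    def __init__(self, emb_dim, nid_per_class):
        super().__init__()
        # nid_per_class maps class index -> number of track IDs in that class.
        self.classifiers = nn.ModuleDict({
            str(c): nn.Linear(emb_dim, n) for c, n in nid_per_class.items()
        })

    def forward(self, emb, cls_ids, track_ids):
        emb = F.normalize(emb)
        total = emb.new_zeros(())
        for c, clf in self.classifiers.items():
            mask = cls_ids == int(c)
            if mask.any():
                # Each class is scored only against its own ID set.
                total = total + F.cross_entropy(clf(emb[mask]), track_ids[mask])
        return total


head = PerClassIDHead(emb_dim=8, nid_per_class={0: 3, 1: 5})
emb = torch.randn(4, 8)
cls_ids = torch.tensor([0, 0, 1, 1])
# Track ID 2 appears in both classes; the two occurrences never share a classifier.
track_ids = torch.tensor([2, 0, 2, 4])
loss = head(emb, cls_ids, track_ids)
```

Because every class has its own output layer, an ID value that repeats across classes is never treated as the same identity.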

@upsx upsx closed this as completed Oct 14, 2021
@upsx upsx reopened this Oct 14, 2021
@upsx
Author

upsx commented Oct 14, 2021

> @upsx You have misunderstood. We use multiple classifiers rather than a single classifier, so different classes use different IDs.

Oh, thanks for your reply. I have an additional question: have you run any experiments to compare which of the following two options is better?

  1. Generate IDs jointly across all classes and use a single classifier
  2. Generate IDs separately for each class and use multiple classifiers
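
The difference between the two options can be sketched in terms of label spaces. The helper names and the example class counts below are hypothetical, and the sketch assumes local IDs are 0-based within each class.

```python
def build_global_id_offsets(ids_per_class):
    """Return (offsets, total) so that each (class, local_id) maps to a unique global ID."""
    offsets = {}
    running = 0
    for cls_name in sorted(ids_per_class):
        offsets[cls_name] = running
        running += ids_per_class[cls_name]
    return offsets, running


def to_global_id(cls_name, local_id, offsets):
    """Option 1: shift a per-class ID into the shared label space."""
    return offsets[cls_name] + local_id


# Assumed example: 3 car identities and 5 pedestrian identities.
ids_per_class = {"car": 3, "pedestrian": 5}
offsets, total_ids = build_global_id_offsets(ids_per_class)

# Option 1: a single classifier with total_ids outputs.
car_2 = to_global_id("car", 2, offsets)
ped_2 = to_global_id("pedestrian", 2, offsets)

# Option 2: one classifier per class, each with ids_per_class[cls] outputs;
# the local ID 2 is reused, but it is always paired with its class.
```

With option 1 the local ID 2 of each class lands on a different global label, so a single classifier does not confuse them; with option 2 the labels stay local and the class decides which classifier is used.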

@CaptainEven
Owner

@upsx I think it may not make much difference from the MOT point of view, but it may cause confusion if all classes share the IDs.

@upsx upsx changed the title should gen id togeger in all class for multi-class dataset? should gen id together in all class for multi-class dataset? Oct 14, 2021
@upsx
Author

upsx commented Oct 14, 2021

> @upsx I think it may not make much difference from the MOT point of view, but it may cause confusion if all classes share the IDs.

OK, thanks 💖. I plan to verify this myself.
