From a4f0f73aa4df5a2181688801a3b0132a99950b78 Mon Sep 17 00:00:00 2001 From: rotorliu Date: Thu, 8 Aug 2024 10:52:32 +0800 Subject: [PATCH] Update nanodet_plus_head.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复多GPU情况下,len(pos_inds) == 0时(当前Batch中的所有图像中均没有检测框真值),reduce_mean不执行导致的dist.all_reduce空等待,卡死的问题。 --- nanodet/model/head/nanodet_plus_head.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nanodet/model/head/nanodet_plus_head.py b/nanodet/model/head/nanodet_plus_head.py index 09fc2e597..7bad82a2f 100644 --- a/nanodet/model/head/nanodet_plus_head.py +++ b/nanodet/model/head/nanodet_plus_head.py @@ -266,10 +266,9 @@ def _get_loss_from_assign(self, cls_preds, reg_preds, decoded_bboxes, assign): (labels >= 0) & (labels < self.num_classes), as_tuple=False ).squeeze(1) + weight_targets = cls_preds[pos_inds].detach().sigmoid().max(dim=1)[0] + bbox_avg_factor = max(reduce_mean(weight_targets.sum()).item(), 1.0) if len(pos_inds) > 0: - weight_targets = cls_preds[pos_inds].detach().sigmoid().max(dim=1)[0] - bbox_avg_factor = max(reduce_mean(weight_targets.sum()).item(), 1.0) - loss_bbox = self.loss_bbox( decoded_bboxes[pos_inds], bbox_targets[pos_inds],