From 2a0d7ec2b77de9f343b383fd0497c89e23fd3dc8 Mon Sep 17 00:00:00 2001 From: Thijs van der Burgt Date: Thu, 7 Dec 2023 08:24:08 +0100 Subject: [PATCH] Fix class agnostic NMS when #detections exceeds split_thr --- nanodet/model/module/nms.py | 45 +++++++++++----------- tests/test_models/test_modules/test_nms.py | 25 ++++++++++++ 2 files changed, 47 insertions(+), 23 deletions(-) diff --git a/nanodet/model/module/nms.py b/nanodet/model/module/nms.py index e5fa3e216..414a65c97 100644 --- a/nanodet/model/module/nms.py +++ b/nanodet/model/module/nms.py @@ -87,36 +87,35 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): If the number of boxes is greater than the threshold, it will perform NMS on each group of boxes separately and sequentially. Defaults to 10000. - class_agnostic (bool): if true, nms is class agnostic, - i.e. IoU thresholding happens over all boxes, - regardless of the predicted class. + class_agnostic (bool): if true, nms is class agnostic, i.e. IoU + thresholding happens over all boxes, regardless of the predicted + class. Defaults to False. Returns: tuple: kept dets and indice. """ nms_cfg_ = nms_cfg.copy() - class_agnostic = nms_cfg_.pop("class_agnostic", class_agnostic) - if class_agnostic: - boxes_for_nms = boxes - else: - max_coordinate = boxes.max() - offsets = idxs.to(boxes) * (max_coordinate + 1) - boxes_for_nms = boxes + offsets[:, None] nms_cfg_.pop("type", "nms") split_thr = nms_cfg_.pop("split_thr", 10000) - if len(boxes_for_nms) < split_thr: + + if class_agnostic: + boxes_for_nms = boxes keep = nms(boxes_for_nms, scores, **nms_cfg_) - boxes = boxes[keep] - scores = scores[keep] else: - total_mask = scores.new_zeros(scores.size(), dtype=torch.bool) - for id in torch.unique(idxs): - mask = (idxs == id).nonzero(as_tuple=False).view(-1) - keep = nms(boxes_for_nms[mask], scores[mask], **nms_cfg_) - total_mask[mask[keep]] = True - - keep = total_mask.nonzero(as_tuple=False).view(-1) - keep = keep[scores[keep].argsort(descending=True)] - boxes = boxes[keep] - scores = scores[keep] + if len(boxes) > split_thr: + total_mask = scores.new_zeros(scores.size(), dtype=torch.bool) + for i in torch.unique(idxs): + mask = (idxs == i).nonzero(as_tuple=False).view(-1) + keep = nms(boxes[mask], scores[mask], **nms_cfg_) + total_mask[mask[keep]] = True + keep = total_mask.nonzero(as_tuple=False).view(-1) + keep = keep[scores[keep].argsort(descending=True)] + else: + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + 1) + boxes_for_nms = boxes + offsets[:, None] + keep = nms(boxes_for_nms, scores, **nms_cfg_) + + boxes = boxes[keep] + scores = scores[keep] return torch.cat([boxes, scores[:, None]], -1), keep diff --git a/tests/test_models/test_modules/test_nms.py b/tests/test_models/test_modules/test_nms.py index 902f76cfa..dc68d25d9 100644 --- a/tests/test_models/test_modules/test_nms.py +++ b/tests/test_models/test_modules/test_nms.py @@ -57,3 +57,28 @@ def test_multiclass_nms(): ) assert boxes.shape[0] == 0 assert keep.shape[0] == 0 + +def test_class_agnostic_nms(): + file = open("./tests/assets/data/batched_nms_data.pkl", "rb") + results = pickle.load(file) + + nms_cfg = dict(iou_threshold=0.7) + boxes, keep = batched_nms( + torch.from_numpy(results["boxes"]), + torch.from_numpy(results["scores"]), + torch.from_numpy(results["idxs"]), + nms_cfg, + class_agnostic=True, + ) + + nms_cfg.update(split_thr=100) + seq_boxes, seq_keep = batched_nms( + torch.from_numpy(results["boxes"]), + torch.from_numpy(results["scores"]), + torch.from_numpy(results["idxs"]), + nms_cfg, + class_agnostic=True, + ) + + assert torch.equal(keep, seq_keep) + assert torch.equal(boxes, seq_boxes)