Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ [Update] BoxMatcher matching criteria #125

Merged
merged 4 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion yolo/tools/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, loss_cfg: LossConfig, vec2box: Vec2Box, class_num: int = 80,
self.dfl = DFLoss(vec2box, reg_max)
self.iou = BoxLoss()

self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box, reg_max)

def separate_anchor(self, anchors):
"""
Expand Down
104 changes: 71 additions & 33 deletions yolo/utils/bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, tensor
from torchmetrics.detection import MeanAveragePrecision
Expand Down Expand Up @@ -143,28 +142,35 @@ def generate_anchors(image_size: List[int], strides: List[int]):


class BoxMatcher:
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
def __init__(self, cfg: MatcherConfig, class_num: int, vec2box, reg_max: int) -> None:
self.class_num = class_num
self.anchors = anchors
self.vec2box = vec2box
self.reg_max = reg_max
for attr_name in cfg:
setattr(self, attr_name, cfg[attr_name])

def get_valid_matrix(self, target_bbox: Tensor):
"""
Get a boolean mask that indicates whether each target bounding box overlaps with each anchor.
Get a boolean mask that indicates whether each target bounding box overlaps with each anchor
and is able to correctly predict it with the available reg_max value.

Args:
target_bbox [batch x targets x 4]: The bounding box of each targets.
target_bbox [batch x targets x 4]: The bounding box of each target.
Returns:
[batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps with anchors.
[batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps
with the anchors, and the anchor is able to predict the target.
"""
Xmin, Ymin, Xmax, Ymax = target_bbox[:, :, None].unbind(3)
anchors = self.anchors[None, None] # add a axis at first, second dimension
x_min, y_min, x_max, y_max = target_bbox[:, :, None].unbind(3)
anchors = self.vec2box.anchor_grid[None, None] # add a axis at first, second dimension
anchors_x, anchors_y = anchors.unbind(dim=3)
target_in_x = (Xmin < anchors_x) & (anchors_x < Xmax)
target_in_y = (Ymin < anchors_y) & (anchors_y < Ymax)
target_on_anchor = target_in_x & target_in_y
return target_on_anchor
x_min_dist, x_max_dist = anchors_x - x_min, x_max - anchors_x
y_min_dist, y_max_dist = anchors_y - y_min, y_max - anchors_y
targets_dist = torch.stack((x_min_dist, y_min_dist, x_max_dist, y_max_dist), dim=-1)
targets_dist /= self.vec2box.scaler[None, None, :, None] # (1, 1, anchors, 1)
min_reg_dist, max_reg_dist = targets_dist.amin(dim=-1), targets_dist.amax(dim=-1)
target_on_anchor = min_reg_dist >= 0
target_in_reg_max = max_reg_dist <= self.reg_max - 1.01
return target_on_anchor & target_in_reg_max

def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
"""
Expand Down Expand Up @@ -194,40 +200,68 @@ def get_iou_matrix(self, predict_bbox, target_bbox) -> Tensor:
"""
return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)

def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
def filter_topk(self, target_matrix: Tensor, grid_mask: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
"""
Filter the top-k suitability of targets for each anchor.

Args:
target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
grid_mask [batch x targets x anchors]: The match validity for each target to anchors
topk (int, optional): Number of top scores to retain per anchor.

Returns:
topk_targets [batch x targets x anchors]: Only leave the topk targets for each anchor
topk_masks [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
"""
values, indices = target_matrix.topk(topk, dim=-1)
masked_target_matrix = grid_mask * target_matrix
values, indices = masked_target_matrix.topk(topk, dim=-1)
topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
topk_targets.scatter_(dim=-1, index=indices, src=values)
topk_masks = topk_targets > 0
return topk_targets, topk_masks
topk_mask = topk_targets > 0
return topk_targets, topk_mask

def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor):
def ensure_one_anchor(self, target_matrix: Tensor, topk_mask: tensor) -> Tensor:
"""
Filter the maximum suitability target index of each anchor.
Ensures each valid target gets at least one anchor matched based on the unmasked target matrix,
which enables an otherwise invalid match. This enables too small or too large targets to be
learned as well, even if they can't be predicted perfectly.

Args:
iou_mat [batch x targets x anchors]: The suitability for each targets-anchors
target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.

Returns:
topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
"""
values, indices = target_matrix.max(dim=-1)
best_anchor_mask = torch.zeros_like(target_matrix, dtype=torch.bool)
best_anchor_mask.scatter_(-1, index=indices[..., None], src=~best_anchor_mask)
matched_anchor_num = torch.sum(topk_mask, dim=-1)
target_without_anchor = (matched_anchor_num == 0) & (values > 0)
topk_mask = torch.where(target_without_anchor[..., None], best_anchor_mask, topk_mask)
return topk_mask

def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor):
"""
Filter the maximum suitability target index of each anchor based on IoU.

Args:
iou_mat [batch x targets x anchors]: The IoU for each targets-anchors
topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.

Returns:
unique_indices [batch x anchors x 1]: The index of the best target for each anchor
valid_mask [batch x anchors]: Mask indicating the validity of each anchor
topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
"""
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
topk_mask = torch.where(duplicates, max_idx, topk_mask)
topk_mask &= grid_mask
unique_indices = topk_mask.argmax(dim=1)
return unique_indices[..., None], topk_mask.sum(1), topk_mask
masked_iou_mat = topk_mask * iou_mat
best_indices = masked_iou_mat.argmax(1)[:, None, :]
best_target_mask = torch.zeros_like(duplicates, dtype=torch.bool)
best_target_mask.scatter_(1, index=best_indices, src=~best_target_mask)
topk_mask = torch.where(duplicates, best_target_mask, topk_mask)
unique_indices = topk_mask.to(torch.uint8).argmax(dim=1)
return unique_indices[..., None], topk_mask.any(dim=1), topk_mask

def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
"""Matches each target to the most suitable anchor.
Expand Down Expand Up @@ -273,17 +307,21 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens
# get cls matrix (cls prob with each gt class and each predict class)
cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)

target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
target_matrix = (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])

# choose topk
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
topk_targets, topk_mask = self.filter_topk(target_matrix, grid_mask, topk=self.topk)

# match best anchor to valid targets without valid anchors
topk_mask = self.ensure_one_anchor(target_matrix, topk_mask)

# delete one anchor pred assigned to multiple gts
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask)
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)

align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
align_cls = F.one_hot(align_cls, self.class_num)
align_cls_indices = torch.gather(target_cls, 1, unique_indices)
align_cls = torch.zeros_like(align_cls_indices, dtype=torch.bool).repeat(1, 1, self.class_num)
align_cls.scatter_(-1, index=align_cls_indices, src=~align_cls)

# normalize class distribution
iou_mat *= topk_mask
Expand All @@ -294,7 +332,7 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask.bool()
return anchor_matched_targets, valid_mask


class Vec2Box:
Expand All @@ -305,7 +343,7 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
self.strides = anchor_cfg.strides
else:
logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
self.strides = self.create_auto_anchor(model, image_size)

anchor_grid, scaler = generate_anchors(image_size, self.strides)
Expand Down Expand Up @@ -358,7 +396,7 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
self.strides = anchor_cfg.strides
else:
logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
self.strides = self.create_auto_anchor(model, image_size)

self.head_num = len(anchor_cfg.anchor)
Expand Down
Loading