Skip to content

Commit

Permalink
Fix some error and type (#8384)
Browse files Browse the repository at this point in the history
* Fix some error and type

* fix yolox

* update

* rename test
  • Loading branch information
hhaAndroid authored Jul 19, 2022
1 parent 8d4b809 commit e6f4115
Show file tree
Hide file tree
Showing 14 changed files with 36 additions and 47 deletions.
6 changes: 3 additions & 3 deletions configs/autoassign/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ Determining positive/negative samples for object detection is known as label ass

## Results and Models

| Backbone | Style | Lr schd | Mem (GB) | box AP | Config | Download |
|:---------:|:-------:|:-------:|:--------:|:------:|:------:|:--------:|
| R-50 | caffe | 1x | 4.08 | 40.4 | [config](./autoassign_r50_caffe_fpn_8x2_1x_coco.py) |[model](https://download.openmmlab.com/mmdetection/v2.0/autoassign/auto_assign_r50_fpn_1x_coco/auto_assign_r50_fpn_1x_coco_20210413_115540-5e17991f.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/autoassign/auto_assign_r50_fpn_1x_coco/auto_assign_r50_fpn_1x_coco_20210413_115540-5e17991f.log.json) |
| Backbone | Style | Lr schd | Mem (GB) | box AP | Config | Download |
| :------: | :---: | :-----: | :------: | :----: | :-------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| R-50 | caffe | 1x | 4.08 | 40.4 | [config](./autoassign_r50_caffe_fpn_8x2_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/autoassign/auto_assign_r50_fpn_1x_coco/auto_assign_r50_fpn_1x_coco_20210413_115540-5e17991f.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/autoassign/auto_assign_r50_fpn_1x_coco/auto_assign_r50_fpn_1x_coco_20210413_115540-5e17991f.log.json) |

**Note**:

Expand Down
4 changes: 1 addition & 3 deletions mmdet/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
instance['bbox'] = bbox
instance['bbox_label'] = self.cat2label[ann['category_id']]

# TODO: Verify if there is a problem with offline evaluation?
# ignore segmentation if iscrowd is 1
if ann.get('segmentation', None) and instance['ignore_flag'] == 0:
if ann.get('segmentation', None):
instance['mask'] = ann['segmentation']

instances.append(instance)
Expand Down
9 changes: 7 additions & 2 deletions mmdet/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@ def transform(self, results: dict) -> dict:
if key not in results:
continue
if key == 'gt_masks':
# no need to consider gt_ignore_flags
instance_data[self.mapping_table[key]] = results[key]
if 'gt_ignore_flags' in results:
instance_data[
self.mapping_table[key]] = results[key][vaild_idx]
ignore_instance_data[
self.mapping_table[key]] = results[key][ignore_idx]
else:
instance_data[self.mapping_table[key]] = results[key]
else:
if 'gt_ignore_flags' in results:
instance_data[self.mapping_table[key]] = to_tensor(
Expand Down
44 changes: 11 additions & 33 deletions mmdet/datasets/transforms/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,6 @@ class FilterAnnotations(BaseTransform):
def __init__(self,
min_gt_bbox_wh: Tuple[int, int] = (1, 1),
min_gt_mask_area: int = 1,
keep_empty: bool = True,
by_box: bool = True,
by_mask: bool = False,
keep_empty: bool = True) -> None:
Expand All @@ -659,25 +658,9 @@ def transform(self, results: dict) -> Union[dict, None]:
Returns:
dict: Updated result dict.
"""
# gt_masks may not match with gt_bboxes, because gt_masks
# will not add into instances if ignore is True
if 'gt_ignore_flags' in results and 'gt_masks' in results:
vaild_idx = np.where(results['gt_ignore_flags'] == 0)[0]
keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_ignore_flags')
for key in keys:
if key in results:
results[key] = results[key][vaild_idx]

if self.by_box:
assert 'gt_bboxes' in results
gt_bboxes = results['gt_bboxes']
instance_num = gt_bboxes.shape[0]
if self.by_mask:
assert 'gt_masks' in results
gt_masks = results['gt_masks']
instance_num = len(gt_masks)

if instance_num == 0:
assert 'gt_bboxes' in results
gt_bboxes = results['gt_bboxes']
if gt_bboxes.shape[0] == 0:
return results

tests = []
Expand All @@ -687,29 +670,24 @@ def transform(self, results: dict) -> Union[dict, None]:
tests.append((w > self.min_gt_bbox_wh[0])
& (h > self.min_gt_bbox_wh[1]))
if self.by_mask:
assert 'gt_masks' in results
gt_masks = results['gt_masks']
tests.append(gt_masks.areas >= self.min_gt_mask_area)

keep = tests[0]
for t in tests[1:]:
keep = keep & t

keys = ('gt_bboxes', 'gt_labels', 'gt_masks')
for key in keys:
if key in results:
results[key] = results[key][keep]
if not keep.any():
if self.keep_empty:
return None
else:
return results
else:
keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks',
'gt_ignore_flags')
for key in keys:
if key in results:
results[key] = results[key][keep]
return results

keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags')
for key in keys:
if key in results:
results[key] = results[key][keep]

return results

def __repr__(self):
return self.__class__.__name__ + \
Expand Down
1 change: 0 additions & 1 deletion mmdet/models/dense_heads/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down
1 change: 0 additions & 1 deletion mmdet/models/dense_heads/yolox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import math
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down
10 changes: 9 additions & 1 deletion mmdet/models/task_modules/assigners/atss_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class ATSSAssigner(BaseAssigner):
Args:
topk (int): number of priors selected in each level
alpha (float, optional): param of cost rate for each proposal only
in DDOD. Defaults to None.
iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
calculator. Defaults to ``dict(type='BboxOverlaps2D')``
ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
Expand All @@ -60,14 +62,15 @@ class ATSSAssigner(BaseAssigner):

def __init__(self,
topk: int,
alpha=None,
alpha: Optional[float] = None,
iou_calculator: ConfigType = dict(type='BboxOverlaps2D'),
ignore_iof_thr: float = -1) -> None:
self.topk = topk
self.alpha = alpha
self.iou_calculator = TASK_UTILS.build(iou_calculator)
self.ignore_iof_thr = ignore_iof_thr

# https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
def assign(
self,
pred_instances: InstanceData,
Expand Down Expand Up @@ -124,6 +127,11 @@ def assign(
priors = priors[:, :4]
num_gt, num_priors = gt_bboxes.size(0), priors.size(0)

message = 'Invalid alpha parameter because cls_scores or ' \
'bbox_preds are None. If you want to use the ' \
'cost-based ATSSAssigner, please set cls_scores, ' \
'bbox_preds and self.alpha at the same time. '

# compute iou between all bbox and gt
if self.alpha is None:
# ATSSAssigner
Expand Down
1 change: 1 addition & 0 deletions mmdet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .memory import AvoidCUDAOOM, AvoidOOM
from .misc import find_latest_checkpoint, update_data_root
from .parallel import MMDataParallel, MMDistributedDataParallel
from .replace_cfg_vals import replace_cfg_vals
from .setup_env import register_all_modules, setup_multi_processes
from .split_batch import split_batch
from .typing import (ConfigType, InstanceList, MultiConfig, OptConfigType,
Expand Down
2 changes: 1 addition & 1 deletion requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ memory_profiler
-e git+https://github.com/open-mmlab/mmtracking#egg=mmtrack
onnx==1.7.0
onnxruntime>=1.8.0
protobuf<=3.20.1
parameterized
protobuf<=3.20.1
psutil
pytest
ubelt
Expand Down
2 changes: 1 addition & 1 deletion tests/test_datasets/test_transforms/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def setUp(self):
'img': rng.rand(300, 400),
'gt_seg_map': rng.rand(300, 400),
'gt_masks':
BitmapMasks(rng.rand(2, 300, 400), height=300, width=400),
BitmapMasks(rng.rand(3, 300, 400), height=300, width=400),
'gt_bboxes_labels': rng.rand(3, ),
'gt_ignore_flags': np.array([0, 0, 1], dtype=np.bool),
'proposals': rng.rand(2, 4)
Expand Down
1 change: 1 addition & 0 deletions tests/test_datasets/test_transforms/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def setUp(self):
}, {
'bbox': [50, 50, 60, 80],
'bbox_label': 2,
'mask': [[50, 50, 60, 50, 60, 80, 50, 80]],
'ignore_flag': 1
}]
}
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mmengine.runner import Runner

from mmdet.registry import RUNNERS
from mmdet.utils import register_all_modules, replace_cfg_vals
from mmdet.utils import register_all_modules


# TODO: support fuse_conv_bn and format_only
Expand Down

0 comments on commit e6f4115

Please sign in to comment.