Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support DSVT training #2738

Merged
merged 28 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmdet3d/models/dense_heads/centerpoint_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def forward(self, x):
Returns:
dict[str: torch.Tensor]: contains the following keys:

-reg torch.Tensor): 2D regression value with the
sunjiahao1999 marked this conversation as resolved.
Show resolved Hide resolved
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
Expand Down Expand Up @@ -217,7 +217,7 @@ def forward(self, x):
Returns:
dict[str: torch.Tensor]: contains the following keys:

-reg torch.Tensor): 2D regression value with the
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
Expand Down
18 changes: 11 additions & 7 deletions mmdet3d/models/necks/second_fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class SECONDFPN(BaseModule):
upsample_cfg (dict): Config dict of upsample layers.
conv_cfg (dict): Config dict of conv layers.
use_conv_for_no_stride (bool): Whether to use conv when stride is 1.
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
optional): Initialization config dict. Defaults to
[dict(type='Kaiming', layer='ConvTranspose2d'),
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)].
"""

def __init__(self,
Expand All @@ -31,7 +35,13 @@ def __init__(self,
upsample_cfg=dict(type='deconv', bias=False),
conv_cfg=dict(type='Conv2d', bias=False),
use_conv_for_no_stride=False,
init_cfg=None):
init_cfg=[
dict(type='Kaiming', layer='ConvTranspose2d'),
dict(
type='Constant',
layer='NaiveSyncBatchNorm2d',
val=1.0)
]):
# if for GroupNorm,
# cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
super(SECONDFPN, self).__init__(init_cfg=init_cfg)
Expand Down Expand Up @@ -64,12 +74,6 @@ def __init__(self,
deblocks.append(deblock)
self.deblocks = nn.ModuleList(deblocks)

if init_cfg is None:
sunjiahao1999 marked this conversation as resolved.
Show resolved Hide resolved
self.init_cfg = [
dict(type='Kaiming', layer='ConvTranspose2d'),
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)
]

def forward(self, x):
"""Forward function.

Expand Down
13 changes: 7 additions & 6 deletions mmdet3d/structures/bbox_3d/base_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,13 @@ def in_range_3d(
Tensor: A binary vector indicating whether each point is inside the
reference range.
"""
in_range_flags = ((self.tensor[:, 0] > box_range[0])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bbox3d.in_range_3d should filter out bboxes where the gravity center of the bbox is not in the target range, not the bottom center. Btw, this function not use the training/testing pipeline of any other model, except for DSVT. Therefore, this modification does not cause BC-breaking.

& (self.tensor[:, 1] > box_range[1])
& (self.tensor[:, 2] > box_range[2])
& (self.tensor[:, 0] < box_range[3])
& (self.tensor[:, 1] < box_range[4])
& (self.tensor[:, 2] < box_range[5]))
gravity_center = self.gravity_center
in_range_flags = ((gravity_center[:, 0] > box_range[0])
& (gravity_center[:, 1] > box_range[1])
& (gravity_center[:, 2] > box_range[2])
& (gravity_center[:, 0] < box_range[3])
& (gravity_center[:, 1] < box_range[4])
& (gravity_center[:, 2] < box_range[5]))
return in_range_flags

@abstractmethod
Expand Down
18 changes: 13 additions & 5 deletions projects/DSVT/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,25 @@ python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-

### Training commands

The support of training DSVT is on the way.
In MMDetection3D's root directory, run the following command to test the model:

```bash
tools/dist_train.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py 8 --sync_bn torch
```

## Results and models

### Waymo

| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :----: | :-----: | :----: | :---------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | 75.5 | 72.4 | 69.2 | 66.3 | \[log\](\<https://download.openmmlab.com/mmdetection3d/v1.1.0_models/dsvt/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class_20230917_102130.log) |

**Note**:

- `ResSECOND` denotes the base block in SECOND has residual layers.

**Note** that `ResSECOND` denotes the base block in SECOND has residual layers.
- Regrettably, we are unable to provide the pre-trained model weights due to [Waymo Dataset License Agreement](https://waymo.com/open/terms/), so we only provide the training logs as shown above.

## Citation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,28 @@
loss_cls=dict(
type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
loss_bbox=dict(type='mmdet.L1Loss', reduction='mean', loss_weight=2.0),
loss_iou=dict(type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
loss_reg_iou=dict(
type='mmdet3d.DIoU3DLoss', reduction='mean', loss_weight=2.0),
norm_bbox=True),
# model training and testing settings
train_cfg=dict(
pts=dict(
grid_size=grid_size,
voxel_size=voxel_size,
out_size_factor=4,
dense_reg=1,
gaussian_overlap=0.1,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
grid_size=grid_size,
voxel_size=voxel_size,
point_cloud_range=point_cloud_range,
out_size_factor=1,
dense_reg=1,
gaussian_overlap=0.1,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
test_cfg=dict(
max_per_img=500,
max_pool_nms=False,
min_radius=[4, 12, 10, 1, 0.85, 0.175],
iou_rectifier=[[0.68, 0.71, 0.65]],
pc_range=[-80, -80],
out_size_factor=4,
out_size_factor=1,
voxel_size=voxel_size[:2],
nms_type='rotate',
multi_class_nms=True,
Expand All @@ -128,6 +131,8 @@
coord_type='LIDAR',
load_dim=6,
use_dim=[0, 1, 2, 3, 4],
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
backend_args=backend_args)

Expand All @@ -138,25 +143,22 @@
load_dim=6,
use_dim=5,
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
# Add this if using `MultiFrameDeformableDecoderRPN`
# dict(
# type='LoadPointsFromMultiSweeps',
# sweeps_num=9,
# load_dim=6,
# use_dim=[0, 1, 2, 3, 4],
# pad_empty_sweeps=True,
# remove_close=True),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.5, 0.5, 0]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
translation_std=[0.5, 0.5, 0.5]),
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter3D', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
Expand All @@ -172,25 +174,34 @@
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
]

dataset_type = 'WaymoDataset'
train_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='waymo_infos_train.pkl',
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
metainfo=metainfo,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR',
# load one frame every five frames
load_interval=5,
backend_args=backend_args))
val_dataloader = dict(
batch_size=4,
num_workers=4,
Expand All @@ -212,18 +223,59 @@

val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
backend_args=backend_args,
convert_kitti_format=False,
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
result_prefix='./dsvt_pred')
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# schedules
lr = 1e-5
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.05, betas=(0.9, 0.99)),
clip_grad=dict(max_norm=10, norm_type=2))
param_scheduler = [
dict(
type='CosineAnnealingLR',
T_max=1.2,
eta_min=lr * 100,
begin=0,
end=1.2,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=10.8,
eta_min=lr * 1e-4,
begin=1.2,
end=12,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
dict(
type='CosineAnnealingMomentum',
T_max=1.2,
eta_min=0.85,
begin=0,
end=1.2,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
T_max=10.8,
eta_min=0.95,
begin=1.2,
end=12,
by_epoch=True,
convert_to_iter_based=True)
]

# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=12, val_interval=1)

# runtime settings
val_cfg = dict()
test_cfg = dict()
Expand All @@ -236,4 +288,12 @@

default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
checkpoint=dict(type='CheckpointHook', interval=5))
checkpoint=dict(type='CheckpointHook', interval=1))
custom_hooks = [
dict(
type='DisableAugHook',
disable_after_epoch=11,
disable_aug_list=[
'GlobalRotScaleTrans', 'RandomFlip3D', 'ObjectSample'
])
]
5 changes: 4 additions & 1 deletion projects/DSVT/dsvt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from .disable_aug_hook import DisableAugHook
from .dsvt import DSVT
from .dsvt_head import DSVTCenterHead
from .dsvt_transformer import DSVTMiddleEncoder
from .dynamic_pillar_vfe import DynamicPillarVFE3D
from .map2bev import PointPillarsScatter3D
from .res_second import ResSECOND
from .transforms_3d import ObjectRangeFilter3D, PointsRangeFilter3D
from .utils import DSVTBBoxCoder

__all__ = [
'DSVTCenterHead', 'DSVT', 'DSVTMiddleEncoder', 'DynamicPillarVFE3D',
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder'
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder',
'ObjectRangeFilter3D', 'PointsRangeFilter3D', 'DisableAugHook'
]
Loading
Loading