[Refactor] Refactor structure of cylinder3d #2711

Open · wants to merge 11 commits into base: dev-1.x
21 changes: 11 additions & 10 deletions configs/_base_/models/cylinder3d.py
@@ -1,4 +1,5 @@
 grid_shape = [480, 360, 32]
+point_cloud_range = [0, -3.14159265359, -4, 50, 3.14159265359, 2]
 model = dict(
     type='Cylinder3D',
     data_preprocessor=dict(
@@ -7,18 +8,17 @@
         voxel_type='cylindrical',
         voxel_layer=dict(
             grid_shape=grid_shape,
-            point_cloud_range=[0, -3.14159265359, -4, 50, 3.14159265359, 2],
+            point_cloud_range=point_cloud_range,
             max_num_points=-1,
-            max_voxels=-1,
-        ),
-    ),
+            max_voxels=-1)),
     voxel_encoder=dict(
         type='SegVFE',
-        feat_channels=[64, 128, 256, 256],
         in_channels=6,
+        feat_channels=[64, 128, 256, 256],
         with_voxel_center=True,
-        feat_compression=16,
-        return_point_feats=False),
+        grid_shape=grid_shape,
+        point_cloud_range=point_cloud_range,
+        feat_compression=16),
     backbone=dict(
         type='Asymm3DSpconv',
         grid_size=grid_shape,
@@ -29,13 +29,14 @@
         type='Cylinder3DHead',
         channels=128,
         num_classes=20,
+        dropout_ratio=0,
         loss_ce=dict(
             type='mmdet.CrossEntropyLoss',
             use_sigmoid=False,
             class_weight=None,
             loss_weight=1.0),
         loss_lovasz=dict(type='LovaszLoss', loss_weight=1.0, reduction='none'),
-    ),
+        conv_seg_kernel_size=3,
+        ignore_index=19),
     train_cfg=None,
-    test_cfg=dict(mode='whole'),
-)
+    test_cfg=dict(mode='whole'))

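For orientation, the cylindrical point_cloud_range is ordered (rho, phi, z) over (min, max), so the grid resolution implied by grid_shape can be worked out directly. A minimal sketch, assuming VoxelizationByGridShape derives the per-axis voxel size as range span divided by cell count:

    grid_shape = [480, 360, 32]  # cells along (rho, phi, z)
    point_cloud_range = [0, -3.14159265359, -4, 50, 3.14159265359, 2]

    # Per-axis voxel size: (upper bound - lower bound) / number of cells.
    voxel_size = [
        (point_cloud_range[i + 3] - point_cloud_range[i]) / grid_shape[i]
        for i in range(3)
    ]
    print(voxel_size)  # ~0.104 m in rho, ~0.0175 rad (1 degree) in phi, 0.1875 m in z
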
2 changes: 1 addition & 1 deletion configs/_base_/models/dgcnn.py
@@ -19,7 +19,7 @@
         conv_cfg=dict(type='Conv1d'),
         norm_cfg=dict(type='BN1d'),
         act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
-        loss_decode=dict(
+        loss_ce=dict(
             type='mmdet.CrossEntropyLoss',
             use_sigmoid=False,
             class_weight=None,  # modified with dataset

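The loss_decode -> loss_ce rename above recurs in the remaining config files; it is a config-key rename only, and the dict still builds mmdet's cross-entropy loss through the registry. A minimal sketch of what the key expands to, assuming mmdet 3.x is installed (the 'mmdet.' prefix in the configs scopes the type lookup to mmdet's registry):

    from mmdet.registry import MODELS

    # Build the same loss object the decode head constructs from `loss_ce`.
    loss_ce = MODELS.build(
        dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=None,
            loss_weight=1.0))
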
2 changes: 1 addition & 1 deletion configs/_base_/models/minkunet.py
@@ -27,7 +27,7 @@
         channels=96,
         num_classes=19,
         dropout_ratio=0,
-        loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
+        loss_ce=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
         ignore_index=19),
     train_cfg=dict(),
     test_cfg=dict())

2 changes: 1 addition & 1 deletion configs/_base_/models/paconv_ssg.py
@@ -37,7 +37,7 @@
         conv_cfg=dict(type='Conv1d'),
         norm_cfg=dict(type='BN1d'),
         act_cfg=dict(type='ReLU'),
-        loss_decode=dict(
+        loss_ce=dict(
             type='mmdet.CrossEntropyLoss',
             use_sigmoid=False,
             class_weight=None,  # should be modified with dataset

2 changes: 1 addition & 1 deletion configs/_base_/models/pointnet2_ssg.py
@@ -26,7 +26,7 @@
         conv_cfg=dict(type='Conv1d'),
         norm_cfg=dict(type='BN1d'),
         act_cfg=dict(type='ReLU'),
-        loss_decode=dict(
+        loss_ce=dict(
             type='mmdet.CrossEntropyLoss',
             use_sigmoid=False,
             class_weight=None,  # should be modified with dataset

2 changes: 1 addition & 1 deletion configs/_base_/models/spvcnn.py
@@ -28,7 +28,7 @@
         channels=96,
         num_classes=19,
         dropout_ratio=0,
-        loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
+        loss_ce=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
         ignore_index=19),
     train_cfg=dict(),
     test_cfg=dict())

4 changes: 2 additions & 2 deletions configs/cylinder3d/cylinder3d_4xb4-3x_semantickitti.py
@@ -27,12 +27,12 @@
         gamma=0.1)
 ]
 
-train_dataloader = dict(batch_size=4, )
+train_dataloader = dict(batch_size=4)
 
 # Default setting for scaling LR automatically
 # - `enable` means enable scaling LR automatically
 #   or not by default.
 # - `base_batch_size` = (8 GPUs) x (4 samples per GPU).
 # auto_scale_lr = dict(enable=False, base_batch_size=32)
 
-default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5))
+default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))

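As a worked example of the commented-out auto_scale_lr block: MMEngine's automatic LR scaling is linear in the effective batch size. The numbers below are illustrative only:

    base_batch_size = 32       # 8 GPUs x 4 samples per GPU, as in the comment
    actual_batch_size = 4 * 4  # e.g. this config run on 4 GPUs with batch_size=4
    scale = actual_batch_size / base_batch_size  # 0.5
    # With auto_scale_lr enabled, an optimizer LR of 1e-3 would become 5e-4.
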
[file path not shown]
@@ -8,7 +8,7 @@
     backbone=dict(in_channels=9),  # [xyz, rgb, normalized_xyz]
     decode_head=dict(
         num_classes=13, ignore_index=13,
-        loss_decode=dict(class_weight=None)),  # S3DIS doesn't use class_weight
+        loss_ce=dict(class_weight=None)),  # S3DIS doesn't use class_weight
     test_cfg=dict(
         num_points=4096,
         block_size=1.0,

[file path not shown]
@@ -7,7 +7,7 @@
 model = dict(
     decode_head=dict(
         num_classes=13, ignore_index=13,
-        loss_decode=dict(class_weight=None)),  # S3DIS doesn't use class_weight
+        loss_ce=dict(class_weight=None)),  # S3DIS doesn't use class_weight
     test_cfg=dict(
         num_points=4096,
         block_size=1.0,

2 changes: 1 addition & 1 deletion configs/paconv/paconv_ssg_8xb8-cosine-150e_s3dis-seg.py
@@ -7,7 +7,7 @@
 model = dict(
     decode_head=dict(
         num_classes=13, ignore_index=13,
-        loss_decode=dict(class_weight=None)),  # S3DIS doesn't use class_weight
+        loss_ce=dict(class_weight=None)),  # S3DIS doesn't use class_weight
     test_cfg=dict(
         num_points=4096,
         block_size=1.0,

[file path not shown]
@@ -13,7 +13,7 @@
         # `data/scannet/seg_info/train_label_weight.npy`
         # you can copy paste the values here, or input the file path as
         # `class_weight=data/scannet/seg_info/train_label_weight.npy`
-        loss_decode=dict(class_weight=[
+        loss_ce=dict(class_weight=[
             2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086, 4.907941,
             4.690836, 4.512031, 4.623311, 4.9242644, 5.358117, 5.360071,
             5.019636, 4.967126, 5.3502126, 5.4023647, 5.4027233, 5.4169416,

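A short sketch of the workflow the comment above describes: load the precomputed weight file (the path is the one named in the comment, and the file is assumed to have been generated by the ScanNet preprocessing step) and print values that can be pasted into loss_ce=dict(class_weight=[...]):

    import numpy as np

    # Load the per-class weights and print them in paste-ready form.
    weights = np.load('data/scannet/seg_info/train_label_weight.npy')
    print(np.round(weights, 7).tolist())
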
[file path not shown]
@@ -12,7 +12,7 @@
         # `data/scannet/seg_info/train_label_weight.npy`
         # you can copy paste the values here, or input the file path as
         # `class_weight=data/scannet/seg_info/train_label_weight.npy`
-        loss_decode=dict(class_weight=[
+        loss_ce=dict(class_weight=[
             2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086, 4.907941,
             4.690836, 4.512031, 4.623311, 4.9242644, 5.358117, 5.360071,
             5.019636, 4.967126, 5.3502126, 5.4023647, 5.4027233, 5.4169416,

[file path not shown]
@@ -8,7 +8,7 @@
     backbone=dict(in_channels=9),  # [xyz, rgb, normalized_xyz]
     decode_head=dict(
         num_classes=13, ignore_index=13,
-        loss_decode=dict(class_weight=None)),  # S3DIS doesn't use class_weight
+        loss_ce=dict(class_weight=None)),  # S3DIS doesn't use class_weight
     test_cfg=dict(
         num_points=4096,
         block_size=1.0,

[file path not shown]
@@ -13,7 +13,7 @@
         # `data/scannet/seg_info/train_label_weight.npy`
         # you can copy paste the values here, or input the file path as
         # `class_weight=data/scannet/seg_info/train_label_weight.npy`
-        loss_decode=dict(class_weight=[
+        loss_ce=dict(class_weight=[
             2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086, 4.907941,
             4.690836, 4.512031, 4.623311, 4.9242644, 5.358117, 5.360071,
             5.019636, 4.967126, 5.3502126, 5.4023647, 5.4027233, 5.4169416,

[file path not shown]
@@ -12,7 +12,7 @@
         # `data/scannet/seg_info/train_label_weight.npy`
         # you can copy paste the values here, or input the file path as
         # `class_weight=data/scannet/seg_info/train_label_weight.npy`
-        loss_decode=dict(class_weight=[
+        loss_ce=dict(class_weight=[
             2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086, 4.907941,
             4.690836, 4.512031, 4.623311, 4.9242644, 5.358117, 5.360071,
             5.019636, 4.967126, 5.3502126, 5.4023647, 5.4027233, 5.4169416,

[file path not shown]
@@ -8,7 +8,7 @@
     backbone=dict(in_channels=9),  # [xyz, rgb, normalized_xyz]
     decode_head=dict(
         num_classes=13, ignore_index=13,
-        loss_decode=dict(class_weight=None)),  # S3DIS doesn't use class_weight
+        loss_ce=dict(class_weight=None)),  # S3DIS doesn't use class_weight
     test_cfg=dict(
         num_points=4096,
         block_size=1.0,

21 changes: 11 additions & 10 deletions mmdet3d/configs/_base_/models/cylinder3d.py
@@ -7,6 +7,7 @@
 from mmdet3d.models.voxel_encoders import SegVFE
 
 grid_shape = [480, 360, 32]
+point_cloud_range = [0, -3.14159265359, -4, 50, 3.14159265359, 2]
 model = dict(
     type=Cylinder3D,
     data_preprocessor=dict(
@@ -15,18 +16,17 @@
         voxel_type='cylindrical',
         voxel_layer=dict(
             grid_shape=grid_shape,
-            point_cloud_range=[0, -3.14159265359, -4, 50, 3.14159265359, 2],
+            point_cloud_range=point_cloud_range,
             max_num_points=-1,
-            max_voxels=-1,
-        ),
-    ),
+            max_voxels=-1)),
     voxel_encoder=dict(
         type=SegVFE,
-        feat_channels=[64, 128, 256, 256],
         in_channels=6,
+        feat_channels=[64, 128, 256, 256],
         with_voxel_center=True,
-        feat_compression=16,
-        return_point_feats=False),
+        grid_shape=grid_shape,
+        point_cloud_range=point_cloud_range,
+        feat_compression=16),
     backbone=dict(
         type=Asymm3DSpconv,
         grid_size=grid_shape,
@@ -37,13 +37,14 @@
         type=Cylinder3DHead,
         channels=128,
         num_classes=20,
+        dropout_ratio=0,
         loss_ce=dict(
             type='mmdet.CrossEntropyLoss',
             use_sigmoid=False,
             class_weight=None,
             loss_weight=1.0),
         loss_lovasz=dict(type=LovaszLoss, loss_weight=1.0, reduction='none'),
-    ),
+        conv_seg_kernel_size=3,
+        ignore_index=19),
     train_cfg=None,
-    test_cfg=dict(mode='whole'),
-)
+    test_cfg=dict(mode='whole'))

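Because this variant of the base config imports classes directly, the diff also reads as a constructor-signature change for SegVFE: it now receives the grid and range instead of return_point_feats. A minimal, untested sketch with the argument set taken from the diff:

    from mmdet3d.models.voxel_encoders import SegVFE

    voxel_encoder = SegVFE(
        in_channels=6,
        feat_channels=[64, 128, 256, 256],
        with_voxel_center=True,
        grid_shape=[480, 360, 32],
        point_cloud_range=[0, -3.14159265359, -4, 50, 3.14159265359, 2],
        feat_compression=16)
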
2 changes: 1 addition & 1 deletion mmdet3d/configs/_base_/models/minkunet.py
@@ -34,7 +34,7 @@
         channels=96,
         num_classes=19,
         dropout_ratio=0,
-        loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
+        loss_ce=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
         ignore_index=19),
     train_cfg=dict(),
     test_cfg=dict())

[file path not shown]
@@ -32,7 +32,7 @@
         gamma=0.1)
 ]
 
-train_dataloader.update(dict(batch_size=4, ))
+train_dataloader.update(dict(batch_size=4))
 
 # Default setting for scaling LR automatically
 # - `enable` means enable scaling LR automatically

29 changes: 2 additions & 27 deletions mmdet3d/models/data_preprocessors/data_preprocessor.py
@@ -16,7 +16,7 @@
 from mmdet3d.structures.det3d_data_sample import SampleList
 from mmdet3d.utils import OptConfigType
 from .utils import multiview_img_stack_batch
-from .voxelize import VoxelizationByGridShape, dynamic_scatter_3d
+from .voxelize import VoxelizationByGridShape
 
 
 @MODELS.register_module()
@@ -393,7 +393,7 @@ def voxelize(self, points: List[Tensor],
             coors = torch.cat(coors, dim=0)
         elif self.voxel_type == 'cylindrical':
             voxels, coors = [], []
-            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
+            for i, res in enumerate(points):
                 rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2)
                 phi = torch.atan2(res[:, 1], res[:, 0])
                 polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1)
@@ -416,7 +416,6 @@ def voxelize(self, points: List[Tensor],
                 res_coors = torch.floor(
                     (polar_res_clamp - min_bound) / polar_res_clamp.new_tensor(
                         self.voxel_layer.voxel_size)).int()
-                self.get_voxel_seg(res_coors, data_sample)
                 res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                 res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]),
                                        dim=-1)
@@ -467,30 +466,6 @@ def voxelize(self, points: List[Tensor],
 
         return voxel_dict
 
-    def get_voxel_seg(self, res_coors: Tensor,
-                      data_sample: SampleList) -> None:
-        """Get voxel-wise segmentation label and point2voxel map.
-
-        Args:
-            res_coors (Tensor): The voxel coordinates of points, Nx3.
-            data_sample (:obj:`Det3DDataSample`): The annotation data of
-                every sample. Adds voxel-wise annotation for segmentation.
-        """
-        if self.training:
-            pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
-            voxel_semantic_mask, _, point2voxel_map = dynamic_scatter_3d(
-                F.one_hot(pts_semantic_mask.long()).float(), res_coors, 'mean',
-                True)
-            voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
-            data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
-            data_sample.point2voxel_map = point2voxel_map
-        else:
-            pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float()
-            _, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
-                                                       res_coors, 'mean', True)
-            data_sample.point2voxel_map = point2voxel_map
-
     def ravel_hash(self, x: np.ndarray) -> np.ndarray:
         """Get voxel coordinates hash for np.unique.

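For reference, a self-contained sketch of the Cartesian-to-cylindrical conversion the loop above performs, assuming points is an (N, C>=3) tensor with x, y, z in the first three columns:

    import torch


    def cylindrical_coords(points: torch.Tensor) -> torch.Tensor:
        """Map Cartesian points to (rho, phi, z), mirroring the voxelize loop."""
        rho = torch.sqrt(points[:, 0]**2 + points[:, 1]**2)  # radial distance
        phi = torch.atan2(points[:, 1], points[:, 0])        # azimuth in [-pi, pi]
        return torch.stack((rho, phi, points[:, 2]), dim=-1)


    # A point at x=3, y=4, z=1 maps to rho=5, phi=atan2(4, 3) ~= 0.927, z=1.
    print(cylindrical_coords(torch.tensor([[3.0, 4.0, 1.0]])))
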