Skip to content

Commit

Permalink
override predict method
Browse files Browse the repository at this point in the history
Signed-off-by: wep21 <[email protected]>
  • Loading branch information
wep21 committed Jun 6, 2023
1 parent 681958e commit 004e436
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions projects/TransFusion/transfusion/transfusion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import Dict, List, Optional

from torch import Tensor

from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample


@MODELS.register_module()
Expand All @@ -23,3 +28,42 @@ def init_weights(self):
if self.with_img_neck:
for param in self.img_neck.parameters():
param.requires_grad = False

def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""Forward of testing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input sample. Each Det3DDataSample usually contain
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bbox_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
contains a tensor with shape (num_instances, 7).
"""
batch_input_metas = [item.metainfo for item in batch_data_samples]
img_feats, pts_feats = self.extract_feat(batch_inputs_dict, batch_input_metas)

if pts_feats and self.with_pts_bbox:
outputs = self.pts_bbox_head.predict(pts_feats, batch_input_metas)
else:
outputs = None

res = self.add_pred_to_datasample(batch_data_samples, outputs)

return res

0 comments on commit 004e436

Please sign in to comment.