Add File
This commit is contained in:
53
pcdet/models/detectors/pv_rcnn_plusplus.py
Normal file
53
pcdet/models/detectors/pv_rcnn_plusplus.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from .detector3d_template import Detector3DTemplate
|
||||
|
||||
|
||||
class PVRCNNPlusPlus(Detector3DTemplate):
|
||||
def __init__(self, model_cfg, num_class, dataset):
|
||||
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
|
||||
self.module_list = self.build_networks()
|
||||
|
||||
def forward(self, batch_dict):
|
||||
batch_dict = self.vfe(batch_dict)
|
||||
batch_dict = self.backbone_3d(batch_dict)
|
||||
batch_dict = self.map_to_bev_module(batch_dict)
|
||||
batch_dict = self.backbone_2d(batch_dict)
|
||||
batch_dict = self.dense_head(batch_dict)
|
||||
|
||||
batch_dict = self.roi_head.proposal_layer(
|
||||
batch_dict, nms_config=self.roi_head.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
|
||||
)
|
||||
if self.training:
|
||||
targets_dict = self.roi_head.assign_targets(batch_dict)
|
||||
batch_dict['rois'] = targets_dict['rois']
|
||||
batch_dict['roi_labels'] = targets_dict['roi_labels']
|
||||
batch_dict['roi_targets_dict'] = targets_dict
|
||||
num_rois_per_scene = targets_dict['rois'].shape[1]
|
||||
if 'roi_valid_num' in batch_dict:
|
||||
batch_dict['roi_valid_num'] = [num_rois_per_scene for _ in range(batch_dict['batch_size'])]
|
||||
|
||||
batch_dict = self.pfe(batch_dict)
|
||||
batch_dict = self.point_head(batch_dict)
|
||||
batch_dict = self.roi_head(batch_dict)
|
||||
|
||||
if self.training:
|
||||
loss, tb_dict, disp_dict = self.get_training_loss()
|
||||
|
||||
ret_dict = {
|
||||
'loss': loss
|
||||
}
|
||||
return ret_dict, tb_dict, disp_dict
|
||||
else:
|
||||
pred_dicts, recall_dicts = self.post_processing(batch_dict)
|
||||
return pred_dicts, recall_dicts
|
||||
|
||||
def get_training_loss(self):
|
||||
disp_dict = {}
|
||||
loss_rpn, tb_dict = self.dense_head.get_loss()
|
||||
if self.point_head is not None:
|
||||
loss_point, tb_dict = self.point_head.get_loss(tb_dict)
|
||||
else:
|
||||
loss_point = 0
|
||||
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)
|
||||
|
||||
loss = loss_rpn + loss_point + loss_rcnn
|
||||
return loss, tb_dict, disp_dict
|
||||
Reference in New Issue
Block a user