From b428262db7549a59b26443b0a9575f5783536095 Mon Sep 17 00:00:00 2001 From: inter Date: Sun, 21 Sep 2025 20:18:48 +0800 Subject: [PATCH] Add File --- pcdet/models/detectors/pv_rcnn_plusplus.py | 53 ++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 pcdet/models/detectors/pv_rcnn_plusplus.py diff --git a/pcdet/models/detectors/pv_rcnn_plusplus.py b/pcdet/models/detectors/pv_rcnn_plusplus.py new file mode 100644 index 0000000..2c64e67 --- /dev/null +++ b/pcdet/models/detectors/pv_rcnn_plusplus.py @@ -0,0 +1,53 @@ +from .detector3d_template import Detector3DTemplate + + +class PVRCNNPlusPlus(Detector3DTemplate): + def __init__(self, model_cfg, num_class, dataset): + super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset) + self.module_list = self.build_networks() + + def forward(self, batch_dict): + batch_dict = self.vfe(batch_dict) + batch_dict = self.backbone_3d(batch_dict) + batch_dict = self.map_to_bev_module(batch_dict) + batch_dict = self.backbone_2d(batch_dict) + batch_dict = self.dense_head(batch_dict) + + batch_dict = self.roi_head.proposal_layer( + batch_dict, nms_config=self.roi_head.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST'] + ) + if self.training: + targets_dict = self.roi_head.assign_targets(batch_dict) + batch_dict['rois'] = targets_dict['rois'] + batch_dict['roi_labels'] = targets_dict['roi_labels'] + batch_dict['roi_targets_dict'] = targets_dict + num_rois_per_scene = targets_dict['rois'].shape[1] + if 'roi_valid_num' in batch_dict: + batch_dict['roi_valid_num'] = [num_rois_per_scene for _ in range(batch_dict['batch_size'])] + + batch_dict = self.pfe(batch_dict) + batch_dict = self.point_head(batch_dict) + batch_dict = self.roi_head(batch_dict) + + if self.training: + loss, tb_dict, disp_dict = self.get_training_loss() + + ret_dict = { + 'loss': loss + } + return ret_dict, tb_dict, disp_dict + else: + pred_dicts, recall_dicts = self.post_processing(batch_dict) + return pred_dicts, recall_dicts + + def get_training_loss(self): + disp_dict = {} + loss_rpn, tb_dict = self.dense_head.get_loss() + if self.point_head is not None: + loss_point, tb_dict = self.point_head.get_loss(tb_dict) + else: + loss_point = 0 + loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict) + + loss = loss_rpn + loss_point + loss_rcnn + return loss, tb_dict, disp_dict