diff --git a/pcdet/models/detectors/bevfusion.py b/pcdet/models/detectors/bevfusion.py new file mode 100644 index 0000000..bb55fc0 --- /dev/null +++ b/pcdet/models/detectors/bevfusion.py @@ -0,0 +1,101 @@ +from .detector3d_template import Detector3DTemplate +from .. import backbones_image, view_transforms +from ..backbones_image import img_neck +from ..backbones_2d import fuser + +class BevFusion(Detector3DTemplate): + def __init__(self, model_cfg, num_class, dataset): + super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset) + self.module_topology = [ + 'vfe', 'backbone_3d', 'map_to_bev_module', 'pfe', + 'image_backbone','neck','vtransform','fuser', + 'backbone_2d', 'dense_head', 'point_head', 'roi_head' + ] + self.module_list = self.build_networks() + + def build_neck(self,model_info_dict): + if self.model_cfg.get('NECK', None) is None: + return None, model_info_dict + neck_module = img_neck.__all__[self.model_cfg.NECK.NAME]( + model_cfg=self.model_cfg.NECK + ) + model_info_dict['module_list'].append(neck_module) + + return neck_module, model_info_dict + + def build_vtransform(self,model_info_dict): + if self.model_cfg.get('VTRANSFORM', None) is None: + return None, model_info_dict + + vtransform_module = view_transforms.__all__[self.model_cfg.VTRANSFORM.NAME]( + model_cfg=self.model_cfg.VTRANSFORM + ) + model_info_dict['module_list'].append(vtransform_module) + + return vtransform_module, model_info_dict + + def build_image_backbone(self, model_info_dict): + if self.model_cfg.get('IMAGE_BACKBONE', None) is None: + return None, model_info_dict + image_backbone_module = backbones_image.__all__[self.model_cfg.IMAGE_BACKBONE.NAME]( + model_cfg=self.model_cfg.IMAGE_BACKBONE + ) + image_backbone_module.init_weights() + model_info_dict['module_list'].append(image_backbone_module) + + return image_backbone_module, model_info_dict + + def build_fuser(self, model_info_dict): + if self.model_cfg.get('FUSER', None) is None: + return None, model_info_dict + + fuser_module = fuser.__all__[self.model_cfg.FUSER.NAME]( + model_cfg=self.model_cfg.FUSER + ) + model_info_dict['module_list'].append(fuser_module) + model_info_dict['num_bev_features'] = self.model_cfg.FUSER.OUT_CHANNEL + return fuser_module, model_info_dict + + def forward(self, batch_dict): + + for i,cur_module in enumerate(self.module_list): + batch_dict = cur_module(batch_dict) + + if self.training: + loss, tb_dict, disp_dict = self.get_training_loss(batch_dict) + + ret_dict = { + 'loss': loss + } + return ret_dict, tb_dict, disp_dict + else: + pred_dicts, recall_dicts = self.post_processing(batch_dict) + return pred_dicts, recall_dicts + + def get_training_loss(self,batch_dict): + disp_dict = {} + + loss_trans, tb_dict = batch_dict['loss'],batch_dict['tb_dict'] + tb_dict = { + 'loss_trans': loss_trans.item(), + **tb_dict + } + + loss = loss_trans + return loss, tb_dict, disp_dict + + def post_processing(self, batch_dict): + post_process_cfg = self.model_cfg.POST_PROCESSING + batch_size = batch_dict['batch_size'] + final_pred_dict = batch_dict['final_box_dicts'] + recall_dict = {} + for index in range(batch_size): + pred_boxes = final_pred_dict[index]['pred_boxes'] + + recall_dict = self.generate_recall_record( + box_preds=pred_boxes, + recall_dict=recall_dict, batch_index=index, data_dict=batch_dict, + thresh_list=post_process_cfg.RECALL_THRESH_LIST + ) + + return final_pred_dict, recall_dict