From 0d10a850b5379c00cb21212e4f01eb0dd80613db Mon Sep 17 00:00:00 2001 From: inter Date: Sun, 21 Sep 2025 20:18:45 +0800 Subject: [PATCH] Add File --- pcdet/models/detectors/transfusion.py | 50 +++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 pcdet/models/detectors/transfusion.py diff --git a/pcdet/models/detectors/transfusion.py b/pcdet/models/detectors/transfusion.py new file mode 100644 index 0000000..16d81e8 --- /dev/null +++ b/pcdet/models/detectors/transfusion.py @@ -0,0 +1,50 @@ +from .detector3d_template import Detector3DTemplate + + +class TransFusion(Detector3DTemplate): + def __init__(self, model_cfg, num_class, dataset): + super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset) + self.module_list = self.build_networks() + + def forward(self, batch_dict): + for cur_module in self.module_list: + batch_dict = cur_module(batch_dict) + + if self.training: + loss, tb_dict, disp_dict = self.get_training_loss(batch_dict) + + ret_dict = { + 'loss': loss + } + return ret_dict, tb_dict, disp_dict + else: + pred_dicts, recall_dicts = self.post_processing(batch_dict) + return pred_dicts, recall_dicts + + def get_training_loss(self,batch_dict): + disp_dict = {} + + loss_trans, tb_dict = batch_dict['loss'],batch_dict['tb_dict'] + tb_dict = { + 'loss_trans': loss_trans.item(), + **tb_dict + } + + loss = loss_trans + return loss, tb_dict, disp_dict + + def post_processing(self, batch_dict): + post_process_cfg = self.model_cfg.POST_PROCESSING + batch_size = batch_dict['batch_size'] + final_pred_dict = batch_dict['final_box_dicts'] + recall_dict = {} + for index in range(batch_size): + pred_boxes = final_pred_dict[index]['pred_boxes'] + + recall_dict = self.generate_recall_record( + box_preds=pred_boxes, + recall_dict=recall_dict, batch_index=index, data_dict=batch_dict, + thresh_list=post_process_cfg.RECALL_THRESH_LIST + ) + + return final_pred_dict, recall_dict