From fe291c213680d18135289b84b4633f153b637f54 Mon Sep 17 00:00:00 2001 From: inter Date: Sun, 21 Sep 2025 20:18:43 +0800 Subject: [PATCH] Add File --- pcdet/models/detectors/caddn.py | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 pcdet/models/detectors/caddn.py diff --git a/pcdet/models/detectors/caddn.py b/pcdet/models/detectors/caddn.py new file mode 100644 index 0000000..32f56a7 --- /dev/null +++ b/pcdet/models/detectors/caddn.py @@ -0,0 +1,38 @@ +from .detector3d_template import Detector3DTemplate + + +class CaDDN(Detector3DTemplate): + def __init__(self, model_cfg, num_class, dataset): + super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset) + self.module_list = self.build_networks() + + def forward(self, batch_dict): + for cur_module in self.module_list: + batch_dict = cur_module(batch_dict) + + if self.training: + loss, tb_dict, disp_dict = self.get_training_loss() + + ret_dict = { + 'loss': loss + } + return ret_dict, tb_dict, disp_dict + else: + pred_dicts, recall_dicts = self.post_processing(batch_dict) + return pred_dicts, recall_dicts + + def get_training_loss(self): + disp_dict = {} + + loss_rpn, tb_dict_rpn = self.dense_head.get_loss() + loss_depth, tb_dict_depth = self.vfe.get_loss() + + tb_dict = { + 'loss_rpn': loss_rpn.item(), + 'loss_depth': loss_depth.item(), + **tb_dict_rpn, + **tb_dict_depth + } + + loss = loss_rpn + loss_depth + return loss, tb_dict, disp_dict