51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
from .detector3d_template import Detector3DTemplate
|
|
|
|
|
|
class CenterPoint(Detector3DTemplate):
    """Detector that pipes a batch through its sub-modules in order.

    When ``self.training`` is truthy, ``forward`` returns the dense-head
    loss plus logging dicts; otherwise it returns the predicted boxes
    together with accumulated recall records.
    """

    def __init__(self, model_cfg, num_class, dataset):
        """Initialise the base template and build the module pipeline.

        Args:
            model_cfg: Model configuration, forwarded to the base class.
            num_class: Number of classes, forwarded to the base class.
            dataset: Dataset object, forwarded to the base class.
        """
        super().__init__(
            model_cfg=model_cfg,
            num_class=num_class,
            dataset=dataset,
        )
        # The pipeline is whatever the base template's builder returns.
        self.module_list = self.build_networks()

    def forward(self, batch_dict):
        """Run every module on ``batch_dict`` and return mode-specific output.

        Args:
            batch_dict: Batch data; each module receives the previous
                module's return value.

        Returns:
            Training mode: ``({'loss': loss}, tb_dict, disp_dict)``.
            Otherwise: ``(pred_dicts, recall_dicts)`` from
            ``post_processing``.
        """
        for stage in self.module_list:
            batch_dict = stage(batch_dict)

        if not self.training:
            predictions, recalls = self.post_processing(batch_dict)
            return predictions, recalls

        total_loss, tb_dict, disp_dict = self.get_training_loss()
        return {'loss': total_loss}, tb_dict, disp_dict

    def get_training_loss(self):
        """Collect the dense-head loss and its logging entries.

        Returns:
            ``(loss, tb_dict, disp_dict)`` where ``loss`` is the value
            returned by ``self.dense_head.get_loss()``, ``tb_dict`` holds
            ``'loss_rpn'`` followed by the head's own entries, and
            ``disp_dict`` is an empty dict.
        """
        display_entries = {}

        rpn_loss, head_entries = self.dense_head.get_loss()

        # 'loss_rpn' is inserted first; an identically named key coming from
        # the head overrides its value, matching dict-unpacking semantics.
        tb_entries = {'loss_rpn': rpn_loss.item()}
        tb_entries.update(head_entries)

        return rpn_loss, tb_entries, display_entries

    def post_processing(self, batch_dict):
        """Accumulate recall records for each sample's predicted boxes.

        Args:
            batch_dict: Must provide ``'batch_size'`` and
                ``'final_box_dicts'``; the latter is indexed per sample and
                each entry must provide ``'pred_boxes'``.

        Returns:
            ``(final_pred_dict, recall_dict)``: the unmodified
            ``batch_dict['final_box_dicts']`` and the recall record built up
            by ``self.generate_recall_record`` across the batch.
        """
        cfg = self.model_cfg.POST_PROCESSING
        num_samples = batch_dict['batch_size']
        box_dicts = batch_dict['final_box_dicts']

        recall_record = {}
        for sample_idx in range(num_samples):
            sample_boxes = box_dicts[sample_idx]['pred_boxes']
            # Each call receives the running record and returns the updated one.
            recall_record = self.generate_recall_record(
                box_preds=sample_boxes,
                recall_dict=recall_record,
                batch_index=sample_idx,
                data_dict=batch_dict,
                thresh_list=cfg.RECALL_THRESH_LIST,
            )

        return box_dicts, recall_record
|