diff --git a/pcdet/models/dense_heads/anchor_head_template.py b/pcdet/models/dense_heads/anchor_head_template.py
new file mode 100644
index 0000000..db8167a
--- /dev/null
+++ b/pcdet/models/dense_heads/anchor_head_template.py
@@ -0,0 +1,275 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...utils import box_coder_utils, common_utils, loss_utils
+from .target_assigner.anchor_generator import AnchorGenerator
+from .target_assigner.atss_target_assigner import ATSSTargetAssigner
+from .target_assigner.axis_aligned_target_assigner import AxisAlignedTargetAssigner
+
+
+class AnchorHeadTemplate(nn.Module):
+    def __init__(self, model_cfg, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training):
+        super().__init__()
+        self.model_cfg = model_cfg
+        self.num_class = num_class
+        self.class_names = class_names
+        self.predict_boxes_when_training = predict_boxes_when_training
+        self.use_multihead = self.model_cfg.get('USE_MULTIHEAD', False)
+
+        anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
+        self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)(
+            num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6),
+            **anchor_target_cfg.get('BOX_CODER_CONFIG', {})
+        )
+
+        anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG
+        anchors, self.num_anchors_per_location = self.generate_anchors(
+            anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range,
+            anchor_ndim=self.box_coder.code_size
+        )
+        self.anchors = [x.cuda() for x in anchors]
+        self.target_assigner = self.get_target_assigner(anchor_target_cfg)
+
+        self.forward_ret_dict = {}
+        self.build_losses(self.model_cfg.LOSS_CONFIG)
+
+    @staticmethod
+    def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range, anchor_ndim=7):
+        anchor_generator = AnchorGenerator(
+            anchor_range=point_cloud_range,
+            anchor_generator_config=anchor_generator_cfg
+        )
+        feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg]
+        anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size)
+
+        if anchor_ndim != 7:
+            for idx, anchors in enumerate(anchors_list):
+                pad_zeros = anchors.new_zeros([*anchors.shape[0:-1], anchor_ndim - 7])
+                new_anchors = torch.cat((anchors, pad_zeros), dim=-1)
+                anchors_list[idx] = new_anchors
+
+        return anchors_list, num_anchors_per_location_list
+
+    def get_target_assigner(self, anchor_target_cfg):
+        if anchor_target_cfg.NAME == 'ATSS':
+            target_assigner = ATSSTargetAssigner(
+                topk=anchor_target_cfg.TOPK,
+                box_coder=self.box_coder,
+                use_multihead=self.use_multihead,
+                match_height=anchor_target_cfg.MATCH_HEIGHT
+            )
+        elif anchor_target_cfg.NAME == 'AxisAlignedTargetAssigner':
+            target_assigner = AxisAlignedTargetAssigner(
+                model_cfg=self.model_cfg,
+                class_names=self.class_names,
+                box_coder=self.box_coder,
+                match_height=anchor_target_cfg.MATCH_HEIGHT
+            )
+        else:
+            raise NotImplementedError
+        return target_assigner
+
+    def build_losses(self, losses_cfg):
+        self.add_module(
+            'cls_loss_func',
+            loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
+        )
+        reg_loss_name = 'WeightedSmoothL1Loss' if losses_cfg.get('REG_LOSS_TYPE', None) is None \
+            else losses_cfg.REG_LOSS_TYPE
+        self.add_module(
+            'reg_loss_func',
+            getattr(loss_utils, reg_loss_name)(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights'])
+        )
+        self.add_module(
+            'dir_loss_func',
+            loss_utils.WeightedCrossEntropyLoss()
+        )
+
+    def assign_targets(self, gt_boxes):
+        """
+        Args:
+            gt_boxes: (B, M, 8)
+        Returns:
+
+        """
+        targets_dict = self.target_assigner.assign_targets(
+            self.anchors, gt_boxes
+        )
+        return targets_dict
+
+    def get_cls_layer_loss(self):
+        cls_preds = self.forward_ret_dict['cls_preds']
+        box_cls_labels = self.forward_ret_dict['box_cls_labels']
+        batch_size = int(cls_preds.shape[0])
+        cared = box_cls_labels >= 0  # [N, num_anchors]
+        positives = box_cls_labels > 0
+        negatives = box_cls_labels == 0
+        negative_cls_weights = negatives * 1.0
+        cls_weights = (negative_cls_weights + 1.0 * positives).float()
+        reg_weights = positives.float()
+        if self.num_class == 1:
+            # class agnostic
+            box_cls_labels[positives] = 1
+
+        pos_normalizer = positives.sum(1, keepdim=True).float()
+        reg_weights /= torch.clamp(pos_normalizer, min=1.0)
+        cls_weights /= torch.clamp(pos_normalizer, min=1.0)
+        cls_targets = box_cls_labels * cared.type_as(box_cls_labels)
+        cls_targets = cls_targets.unsqueeze(dim=-1)
+
+        cls_targets = cls_targets.squeeze(dim=-1)
+        one_hot_targets = torch.zeros(
+            *list(cls_targets.shape), self.num_class + 1, dtype=cls_preds.dtype, device=cls_targets.device
+        )
+        one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
+        cls_preds = cls_preds.view(batch_size, -1, self.num_class)
+        one_hot_targets = one_hot_targets[..., 1:]
+        cls_loss_src = self.cls_loss_func(cls_preds, one_hot_targets, weights=cls_weights)  # [N, M]
+        cls_loss = cls_loss_src.sum() / batch_size
+
+        cls_loss = cls_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
+        tb_dict = {
+            'rpn_loss_cls': cls_loss.item()
+        }
+        return cls_loss, tb_dict
+
+    @staticmethod
+    def add_sin_difference(boxes1, boxes2, dim=6):
+        assert dim != -1
+        rad_pred_encoding = torch.sin(boxes1[..., dim:dim + 1]) * torch.cos(boxes2[..., dim:dim + 1])
+        rad_tg_encoding = torch.cos(boxes1[..., dim:dim + 1]) * torch.sin(boxes2[..., dim:dim + 1])
+        boxes1 = torch.cat([boxes1[..., :dim], rad_pred_encoding, boxes1[..., dim + 1:]], dim=-1)
+        boxes2 = torch.cat([boxes2[..., :dim], rad_tg_encoding, boxes2[..., dim + 1:]], dim=-1)
+        return boxes1, boxes2
+
+    @staticmethod
+    def get_direction_target(anchors, reg_targets, one_hot=True, dir_offset=0, num_bins=2):
+        batch_size = reg_targets.shape[0]
+        anchors = anchors.view(batch_size, -1, anchors.shape[-1])
+        rot_gt = reg_targets[..., 6] + anchors[..., 6]
+        offset_rot = common_utils.limit_period(rot_gt - dir_offset, 0, 2 * np.pi)
+        dir_cls_targets = torch.floor(offset_rot / (2 * np.pi / num_bins)).long()
+        dir_cls_targets = torch.clamp(dir_cls_targets, min=0, max=num_bins - 1)
+
+        if one_hot:
+            dir_targets = torch.zeros(*list(dir_cls_targets.shape), num_bins, dtype=anchors.dtype,
+                                      device=dir_cls_targets.device)
+            dir_targets.scatter_(-1, dir_cls_targets.unsqueeze(dim=-1).long(), 1.0)
+            dir_cls_targets = dir_targets
+        return dir_cls_targets
+
+    def get_box_reg_layer_loss(self):
+        box_preds = self.forward_ret_dict['box_preds']
+        box_dir_cls_preds = self.forward_ret_dict.get('dir_cls_preds', None)
+        box_reg_targets = self.forward_ret_dict['box_reg_targets']
+        box_cls_labels = self.forward_ret_dict['box_cls_labels']
+        batch_size = int(box_preds.shape[0])
+
+        positives = box_cls_labels > 0
+        reg_weights = positives.float()
+        pos_normalizer = positives.sum(1, keepdim=True).float()
+        reg_weights /= torch.clamp(pos_normalizer, min=1.0)
+
+        if isinstance(self.anchors, list):
+            if self.use_multihead:
+                anchors = torch.cat(
+                    [anchor.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchor.shape[-1]) for anchor in
+                     self.anchors], dim=0)
+            else:
+                anchors = torch.cat(self.anchors, dim=-3)
+        else:
+            anchors = self.anchors
+        anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
+        box_preds = box_preds.view(batch_size, -1,
+                                   box_preds.shape[-1] // self.num_anchors_per_location if not self.use_multihead else
+                                   box_preds.shape[-1])
+        # sin(a - b) = sin(a)cos(b) - cos(a)sin(b)
+        box_preds_sin, reg_targets_sin = self.add_sin_difference(box_preds, box_reg_targets)
+        loc_loss_src = self.reg_loss_func(box_preds_sin, reg_targets_sin, weights=reg_weights)  # [N, M]
+        loc_loss = loc_loss_src.sum() / batch_size
+
+        loc_loss = loc_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['loc_weight']
+        box_loss = loc_loss
+        tb_dict = {
+            'rpn_loss_loc': loc_loss.item()
+        }
+
+        if box_dir_cls_preds is not None:
+            dir_targets = self.get_direction_target(
+                anchors, box_reg_targets,
+                dir_offset=self.model_cfg.DIR_OFFSET,
+                num_bins=self.model_cfg.NUM_DIR_BINS
+            )
+
+            dir_logits = box_dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
+            weights = positives.type_as(dir_logits)
+            weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)
+            dir_loss = self.dir_loss_func(dir_logits, dir_targets, weights=weights)
+            dir_loss = dir_loss.sum() / batch_size
+            dir_loss = dir_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['dir_weight']
+            box_loss += dir_loss
+            tb_dict['rpn_loss_dir'] = dir_loss.item()
+
+        return box_loss, tb_dict
+
+    def get_loss(self):
+        cls_loss, tb_dict = self.get_cls_layer_loss()
+        box_loss, tb_dict_box = self.get_box_reg_layer_loss()
+        tb_dict.update(tb_dict_box)
+        rpn_loss = cls_loss + box_loss
+
+        tb_dict['rpn_loss'] = rpn_loss.item()
+        return rpn_loss, tb_dict
+
+    def generate_predicted_boxes(self, batch_size, cls_preds, box_preds, dir_cls_preds=None):
+        """
+        Args:
+            batch_size:
+            cls_preds: (N, H, W, C1)
+            box_preds: (N, H, W, C2)
+            dir_cls_preds: (N, H, W, C3)
+
+        Returns:
+            batch_cls_preds: (B, num_boxes, num_classes)
+            batch_box_preds: (B, num_boxes, 7+C)
+
+        """
+        if isinstance(self.anchors, list):
+            if self.use_multihead:
+                anchors = torch.cat([anchor.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchor.shape[-1])
+                                     for anchor in self.anchors], dim=0)
+            else:
+                anchors = torch.cat(self.anchors, dim=-3)
+        else:
+            anchors = self.anchors
+        num_anchors = anchors.view(-1, anchors.shape[-1]).shape[0]
+        batch_anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
+        batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float() \
+            if not isinstance(cls_preds, list) else cls_preds
+        batch_box_preds = box_preds.view(batch_size, num_anchors, -1) if not isinstance(box_preds, list) \
+            else torch.cat(box_preds, dim=1).view(batch_size, num_anchors, -1)
+        batch_box_preds = self.box_coder.decode_torch(batch_box_preds, batch_anchors)
+
+        if dir_cls_preds is not None:
+            dir_offset = self.model_cfg.DIR_OFFSET
+            dir_limit_offset = self.model_cfg.DIR_LIMIT_OFFSET
+            dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1) if not isinstance(dir_cls_preds, list) \
+                else torch.cat(dir_cls_preds, dim=1).view(batch_size, num_anchors, -1)
+            dir_labels = torch.max(dir_cls_preds, dim=-1)[1]
+
+            period = (2 * np.pi / self.model_cfg.NUM_DIR_BINS)
+            dir_rot = common_utils.limit_period(
+                batch_box_preds[..., 6] - dir_offset, dir_limit_offset, period
+            )
+            batch_box_preds[..., 6] = dir_rot + dir_offset + period * dir_labels.to(batch_box_preds.dtype)
+
+        if isinstance(self.box_coder, box_coder_utils.PreviousResidualDecoder):
+            batch_box_preds[..., 6] = common_utils.limit_period(
+                -(batch_box_preds[..., 6] + np.pi / 2), offset=0.5, period=np.pi * 2
+            )
+
+        return batch_cls_preds, batch_box_preds
+
+    def forward(self, **kwargs):
+        raise NotImplementedError
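
Several of the numerical tricks in this head are easy to sanity-check in isolation. First, the one-hot construction in get_cls_layer_loss: index 0 is reserved for background and then that column is dropped, so background anchors end up with an all-zero target row under the focal loss. A minimal standalone sketch in plain PyTorch (the labels below are illustrative):

import torch

# Label 0 means background, labels 1..num_class are foreground classes,
# mirroring box_cls_labels in the patch; num_class here is illustrative.
num_class = 3
cls_targets = torch.tensor([[0, 1, 3, 2]])               # (batch, num_anchors)
one_hot = torch.zeros(*cls_targets.shape, num_class + 1)
one_hot.scatter_(-1, cls_targets.unsqueeze(-1).long(), 1.0)
one_hot = one_hot[..., 1:]                               # drop the background column
print(one_hot)
# tensor([[[0., 0., 0.],    <- background anchor: all-zero target row
#          [1., 0., 0.],    <- class 1
#          [0., 0., 1.],    <- class 3
#          [0., 1., 0.]]])  <- class 2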
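
Second, the heading regression: add_sin_difference swaps the raw angle channel of the two boxes for sin(a)cos(b) and cos(a)sin(b) respectively, so the elementwise difference the smooth-L1 loss sees equals sin(a - b), which is smooth across the ±pi wrap-around. A quick standalone check; the helper below is a free-standing copy of the patch's static method, not an import from it:

import torch

def add_sin_difference(boxes1, boxes2, dim=6):
    # Free-standing copy of AnchorHeadTemplate.add_sin_difference for illustration.
    rad_pred = torch.sin(boxes1[..., dim:dim + 1]) * torch.cos(boxes2[..., dim:dim + 1])
    rad_tg = torch.cos(boxes1[..., dim:dim + 1]) * torch.sin(boxes2[..., dim:dim + 1])
    boxes1 = torch.cat([boxes1[..., :dim], rad_pred, boxes1[..., dim + 1:]], dim=-1)
    boxes2 = torch.cat([boxes2[..., :dim], rad_tg, boxes2[..., dim + 1:]], dim=-1)
    return boxes1, boxes2

preds, targets = torch.zeros(1, 1, 7), torch.zeros(1, 1, 7)
preds[..., 6], targets[..., 6] = 0.3, 0.1                # heading channels a and b
p, t = add_sin_difference(preds, targets)
# The residual on the heading channel is sin(a)cos(b) - cos(a)sin(b) = sin(a - b).
assert torch.allclose(p[..., 6] - t[..., 6], torch.sin(preds[..., 6] - targets[..., 6]))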
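
get_direction_target then recovers the coarse orientation the sin encoding discards: it wraps the absolute ground-truth heading into [0, 2*pi) and buckets it into NUM_DIR_BINS classes. A sketch assuming common_utils.limit_period follows its upstream formula, val - floor(val / period + offset) * period; the constants are typical config values, not read from this patch:

import numpy as np
import torch

def limit_period(val, offset=0.5, period=np.pi):
    # Assumed to match common_utils.limit_period; wraps val into a period-wide window.
    return val - torch.floor(val / period + offset) * period

num_bins, dir_offset = 2, 0.78539                            # e.g. NUM_DIR_BINS, DIR_OFFSET ~ pi/4
rot_gt = torch.tensor([0.1, np.pi / 2, np.pi, -np.pi / 2])   # absolute headings (rad)
offset_rot = limit_period(rot_gt - dir_offset, 0, 2 * np.pi) # wrap into [0, 2*pi)
bins = torch.clamp(torch.floor(offset_rot / (2 * np.pi / num_bins)).long(), 0, num_bins - 1)
print(bins)  # tensor([1, 0, 0, 1]): which half-circle each heading points into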
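
At inference time, generate_predicted_boxes inverts this: the decoded heading is folded into a single bin's range and the classified bin index adds the coarse orientation back, so with two bins, bin 1 lands pi away from bin 0. A sketch with the same assumed limit_period and illustrative constants:

import numpy as np
import torch

def limit_period(val, offset=0.5, period=np.pi):
    # Same assumed formula as above.
    return val - torch.floor(val / period + offset) * period

num_dir_bins, dir_offset, dir_limit_offset = 2, 0.78539, 0.0  # illustrative config values
period = 2 * np.pi / num_dir_bins
heading = torch.tensor([0.2, 0.2])       # headings produced by the box coder
dir_labels = torch.tensor([0, 1])        # argmax over the direction logits
dir_rot = limit_period(heading - dir_offset, dir_limit_offset, period)
final = dir_rot + dir_offset + period * dir_labels.to(heading.dtype)
print(final)  # same regressed angle, but bin 1 is flipped by pi relative to bin 0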