import copy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import kaiming_normal_
from ..model_utils.transfusion_utils import clip_sigmoid
from ..model_utils.basic_block_2d import BasicBlock2D
from ..model_utils.transfusion_utils import PositionEmbeddingLearned, TransformerDecoderLayer
from .target_assigner.hungarian_assigner import HungarianAssigner3D
from ...utils import loss_utils
from ..model_utils import centernet_utils


class SeparateHead_Transfusion(nn.Module):
    """Per-proposal prediction heads built from 1-D convolutions.

    One ``nn.Sequential`` is created per entry of ``sep_head_dict`` and stored
    as an attribute under that entry's name.

    Args:
        input_channels: channel count of the features passed to ``forward``.
        head_channels: channel count of the hidden conv layers.
        kernel_size: kernel size of every ``Conv1d`` (padding keeps length).
        sep_head_dict: mapping ``name -> {'out_channels': int, 'num_conv': int}``;
            ``num_conv`` is the total number of convs including the final one.
        init_bias: value the final-layer bias is filled with for heads whose
            name contains ``'hm'``.
        use_bias: whether the hidden convs (each followed by BatchNorm1d) use a bias.
    """

    def __init__(self, input_channels, head_channels, kernel_size, sep_head_dict, init_bias=-2.19, use_bias=False):
        super().__init__()
        self.sep_head_dict = sep_head_dict

        for cur_name in self.sep_head_dict:
            output_channels = self.sep_head_dict[cur_name]['out_channels']
            num_conv = self.sep_head_dict[cur_name]['num_conv']

            fc_list = []
            # Chain channels from layer to layer: the first conv consumes
            # ``input_channels`` and every later conv consumes ``head_channels``.
            # Using fixed in-channels for every layer only works for num_conv == 2.
            in_channels = input_channels
            for _ in range(num_conv - 1):
                fc_list.append(nn.Sequential(
                    nn.Conv1d(in_channels, head_channels, kernel_size, stride=1, padding=kernel_size // 2, bias=use_bias),
                    nn.BatchNorm1d(head_channels),
                    nn.ReLU()
                ))
                in_channels = head_channels
            fc_list.append(nn.Conv1d(in_channels, output_channels, kernel_size, stride=1, padding=kernel_size // 2, bias=True))
            fc = nn.Sequential(*fc_list)
            if 'hm' in cur_name:
                # NOTE(review): TransFusionHead names its classification head
                # 'heatmap', which does not contain 'hm', so this branch is not
                # taken for it. Kept unchanged to preserve reference training.
                fc[-1].bias.data.fill_(init_bias)
            else:
                # NOTE(review): every conv built above is Conv1d, so this Conv2d
                # check never matches and PyTorch's default init is kept.
                # Kept unchanged to preserve reference training; confirm intent.
                for m in fc.modules():
                    if isinstance(m, nn.Conv2d):
                        kaiming_normal_(m.weight.data)
                        if hasattr(m, "bias") and m.bias is not None:
                            nn.init.constant_(m.bias, 0)

            self.__setattr__(cur_name, fc)

    def forward(self, x):
        """Apply every head to ``x`` and return ``{head_name: output}``."""
        ret_dict = {}
        for cur_name in self.sep_head_dict:
            ret_dict[cur_name] = self.__getattr__(cur_name)(x)

        return ret_dict


class TransFusionHead(nn.Module):
    """
    This module implements TransFusionHead.
    The code is adapted from https://github.com/mit-han-lab/bevfusion/ with minimal modifications.
    """

    def __init__(
        self,
        model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, voxel_size, predict_boxes_when_training=True,
    ):
        super(TransFusionHead, self).__init__()

        self.grid_size = grid_size
        self.point_cloud_range = point_cloud_range
        self.voxel_size = voxel_size
        self.num_classes = num_class

        self.model_cfg = model_cfg
        self.feature_map_stride = self.model_cfg.TARGET_ASSIGNER_CONFIG.get('FEATURE_MAP_STRIDE', None)
        self.dataset_name = self.model_cfg.TARGET_ASSIGNER_CONFIG.get('DATASET', 'nuScenes')

        hidden_channel = self.model_cfg.HIDDEN_CHANNEL
        self.num_proposals = self.model_cfg.NUM_PROPOSALS
        self.bn_momentum = self.model_cfg.BN_MOMENTUM
        self.nms_kernel_size = self.model_cfg.NMS_KERNEL_SIZE

        num_heads = self.model_cfg.NUM_HEADS
        dropout = self.model_cfg.DROPOUT
        activation = self.model_cfg.ACTIVATION
        ffn_channel = self.model_cfg.FFN_CHANNEL
        bias = self.model_cfg.get('USE_BIAS_BEFORE_NORM', False)

        loss_cls = self.model_cfg.LOSS_CONFIG.LOSS_CLS
        self.use_sigmoid_cls = loss_cls.get("use_sigmoid", False)
        if not self.use_sigmoid_cls:
            # NOTE(review): heatmap_head and class_encoding below are sized with
            # ``num_class``, while predict() one-hot encodes with
            # ``self.num_classes``; with use_sigmoid False these disagree, so
            # that path looks unsupported -- confirm before relying on it.
            self.num_classes += 1
        self.loss_cls = loss_utils.SigmoidFocalClassificationLoss(gamma=loss_cls.gamma, alpha=loss_cls.alpha)
        self.loss_cls_weight = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
        self.loss_bbox = loss_utils.L1Loss()
        self.loss_bbox_weight = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['bbox_weight']
        self.loss_heatmap = loss_utils.GaussianFocalLoss()
        self.loss_heatmap_weight = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['hm_weight']

        # Number of regression target columns produced by encode_bbox().
        self.code_size = 10

        # a shared convolution
        self.shared_conv = nn.Conv2d(in_channels=input_channels, out_channels=hidden_channel, kernel_size=3, padding=1)
        layers = []
        layers.append(BasicBlock2D(hidden_channel, hidden_channel, kernel_size=3, padding=1, bias=bias))
        layers.append(nn.Conv2d(in_channels=hidden_channel, out_channels=num_class, kernel_size=3, padding=1))
        self.heatmap_head = nn.Sequential(*layers)
        self.class_encoding = nn.Conv1d(num_class, hidden_channel, 1)

        # transformer decoder layers for object query with LiDAR feature
        self.decoder = TransformerDecoderLayer(
            hidden_channel, num_heads, ffn_channel, dropout, activation,
            self_posembed=PositionEmbeddingLearned(2, hidden_channel),
            cross_posembed=PositionEmbeddingLearned(2, hidden_channel),
        )
        # Prediction Head
        heads = copy.deepcopy(self.model_cfg.SEPARATE_HEAD_CFG.HEAD_DICT)
        heads['heatmap'] = dict(out_channels=self.num_classes, num_conv=self.model_cfg.NUM_HM_CONV)
        self.prediction_head = SeparateHead_Transfusion(hidden_channel, 64, 1, heads, use_bias=bias)

        self.init_weights()
        self.bbox_assigner = HungarianAssigner3D(**self.model_cfg.TARGET_ASSIGNER_CONFIG.HUNGARIAN_ASSIGNER)

        # Position Embedding for Cross-Attention, which is re-used during training
        x_size = self.grid_size[0] // self.feature_map_stride
        y_size = self.grid_size[1] // self.feature_map_stride
        self.bev_pos = self.create_2D_grid(x_size, y_size)

        self.forward_ret_dict = {}

    def create_2D_grid(self, x_size, y_size):
        """Return cell-centre coordinates of an ``x_size`` x ``y_size`` grid.

        The result has shape ``(1, x_size * y_size, 2)``; each row is
        ``(i + 0.5, j + 0.5)`` for grid indices ``i`` (x) and ``j`` (y), ordered
        with ``j`` varying fastest (``torch.meshgrid`` default ``'ij'`` layout).
        """
        meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
        # NOTE: modified
        batch_x, batch_y = torch.meshgrid(
            *[torch.linspace(it[0], it[1], it[2]) for it in meshgrid]
        )
        batch_x = batch_x + 0.5
        batch_y = batch_y + 0.5
        coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)[None]
        coord_base = coord_base.view(1, 2, -1).permute(0, 2, 1)
        return coord_base

    def init_weights(self):
        """Xavier-init the decoder's matrix parameters and set BN momentum."""
        # initialize transformer
        for m in self.decoder.parameters():
            if m.dim() > 1:
                nn.init.xavier_uniform_(m)
        if hasattr(self, "query"):
            nn.init.xavier_normal_(self.query)
        self.init_bn_momentum()

    def init_bn_momentum(self):
        """Set ``momentum`` of every BatchNorm1d/2d submodule to ``self.bn_momentum``."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = self.bn_momentum

    def predict(self, inputs):
        """Select object queries from the dense heatmap and decode them.

        Steps: shared conv -> dense class heatmap -> local-maximum suppression
        (``nms_kernel_size``) -> top ``num_proposals`` peaks over all classes ->
        category-embedded queries refined by the transformer decoder ->
        separate prediction heads.

        Args:
            inputs: BEV feature map; indexed as ``(batch, channels, H, W)``
                (assumed from the ``view``/``max_pool2d`` usage below).

        Returns:
            dict with one entry per prediction head (``center`` is offset by the
            query position), plus ``query_heatmap_score`` (masked heatmap scores
            at the selected positions) and ``dense_heatmap`` (raw logits).

        Side effect:
            stores the selected proposal classes in ``self.query_labels``;
            ``get_bboxes`` reads it.
        """
        batch_size = inputs.shape[0]
        lidar_feat = self.shared_conv(inputs)

        lidar_feat_flatten = lidar_feat.view(
            batch_size, lidar_feat.shape[1], -1
        )
        bev_pos = self.bev_pos.repeat(batch_size, 1, 1).to(lidar_feat.device)

        # query initialization
        dense_heatmap = self.heatmap_head(lidar_feat)
        heatmap = dense_heatmap.detach().sigmoid()
        padding = self.nms_kernel_size // 2
        local_max = torch.zeros_like(heatmap)
        local_max_inner = F.max_pool2d(
            heatmap, kernel_size=self.nms_kernel_size, stride=1, padding=0
        )
        if padding > 0:
            # Unpadded pooling shrinks each spatial side by 2 * padding, so the
            # result fills only the interior; border cells keep local_max == 0
            # and end up with score 0 after masking.
            local_max[:, :, padding:(-padding), padding:(-padding)] = local_max_inner
        else:
            # Kernel size 1: the pooled map already has full size, and the
            # slice ``0:-0`` would select nothing and make the assignment fail.
            local_max.copy_(local_max_inner)
        # for Pedestrian & Traffic_cone in nuScenes
        if self.dataset_name == "nuScenes":
            local_max[:, 8] = F.max_pool2d(heatmap[:, 8], kernel_size=1, stride=1, padding=0)
            local_max[:, 9] = F.max_pool2d(heatmap[:, 9], kernel_size=1, stride=1, padding=0)
        # for Pedestrian & Cyclist in Waymo
        elif self.dataset_name == "Waymo":
            local_max[:, 1] = F.max_pool2d(heatmap[:, 1], kernel_size=1, stride=1, padding=0)
            local_max[:, 2] = F.max_pool2d(heatmap[:, 2], kernel_size=1, stride=1, padding=0)
        # Keep only local maxima.
        heatmap = heatmap * (heatmap == local_max)
        heatmap = heatmap.view(batch_size, heatmap.shape[1], -1)

        # top num_proposals among all classes
        top_proposals = heatmap.view(batch_size, -1).argsort(dim=-1, descending=True)[
            ..., : self.num_proposals
        ]
        # Flattened index = class * (H * W) + spatial index.
        top_proposals_class = top_proposals // heatmap.shape[-1]
        top_proposals_index = top_proposals % heatmap.shape[-1]
        query_feat = lidar_feat_flatten.gather(
            index=top_proposals_index[:, None, :].expand(-1, lidar_feat_flatten.shape[1], -1),
            dim=-1,
        )
        self.query_labels = top_proposals_class

        # add category embedding
        one_hot = F.one_hot(top_proposals_class, num_classes=self.num_classes).permute(0, 2, 1)
        query_cat_encoding = self.class_encoding(one_hot.float())
        query_feat += query_cat_encoding

        query_pos = bev_pos.gather(
            index=top_proposals_index[:, None, :].permute(0, 2, 1).expand(-1, -1, bev_pos.shape[-1]),
            dim=1,
        )
        # convert to xy
        query_pos = query_pos.flip(dims=[-1])
        bev_pos = bev_pos.flip(dims=[-1])

        query_feat = self.decoder(
            query_feat, lidar_feat_flatten, query_pos, bev_pos
        )
        res_layer = self.prediction_head(query_feat)
        # The center head predicts an offset from the query position.
        res_layer["center"] = res_layer["center"] + query_pos.permute(0, 2, 1)

        res_layer["query_heatmap_score"] = heatmap.gather(
            index=top_proposals_index[:, None, :].expand(-1, self.num_classes, -1),
            dim=-1,
        )
        res_layer["dense_heatmap"] = dense_heatmap

        return res_layer

    def forward(self, batch_dict):
        """Run the head on ``batch_dict['spatial_features_2d']``.

        In eval mode, writes decoded boxes to ``batch_dict['final_box_dicts']``.
        In training mode, reads ``batch_dict['gt_boxes']`` (last column is the
        1-based class id; converted to 0-based here) and writes
        ``batch_dict['loss']`` and ``batch_dict['tb_dict']``.
        """
        feats = batch_dict['spatial_features_2d']
        res = self.predict(feats)
        if not self.training:
            bboxes = self.get_bboxes(res)
            batch_dict['final_box_dicts'] = bboxes
        else:
            gt_boxes = batch_dict['gt_boxes']
            gt_bboxes_3d = gt_boxes[..., :-1]
            gt_labels_3d = gt_boxes[..., -1].long() - 1
            loss, tb_dict = self.loss(gt_bboxes_3d, gt_labels_3d, res)
            batch_dict['loss'] = loss
            batch_dict['tb_dict'] = tb_dict
        return batch_dict

    def get_targets(self, gt_bboxes_3d, gt_labels_3d, pred_dicts):
        """Build training targets for every sample in the batch.

        Returns:
            tuple ``(labels, label_weights, bbox_targets, bbox_weights, num_pos,
            matched_ious, heatmap)``; tensors are concatenated over the batch,
            ``num_pos`` is the total positive count and ``matched_ious`` the
            mean of the per-sample mean IoUs.
        """
        assign_results = []
        for batch_idx in range(len(gt_bboxes_3d)):
            pred_dict = {key: value[batch_idx:batch_idx + 1] for key, value in pred_dicts.items()}
            gt_bboxes = gt_bboxes_3d[batch_idx]
            # filter empty boxes (non-positive size in columns 3 and 4);
            # vectorised instead of a per-element Python loop over the tensor.
            valid_mask = (gt_bboxes[:, 3] > 0) & (gt_bboxes[:, 4] > 0)
            assign_result = self.get_targets_single(gt_bboxes[valid_mask], gt_labels_3d[batch_idx][valid_mask], pred_dict)
            assign_results.append(assign_result)

        res_tuple = tuple(map(list, zip(*assign_results)))
        labels = torch.cat(res_tuple[0], dim=0)
        label_weights = torch.cat(res_tuple[1], dim=0)
        bbox_targets = torch.cat(res_tuple[2], dim=0)
        bbox_weights = torch.cat(res_tuple[3], dim=0)
        num_pos = np.sum(res_tuple[4])
        matched_ious = np.mean(res_tuple[5])
        heatmap = torch.cat(res_tuple[6], dim=0)
        return labels, label_weights, bbox_targets, bbox_weights, num_pos, matched_ious, heatmap

    def get_targets_single(self, gt_bboxes_3d, gt_labels_3d, preds_dict):
        """Assign proposals of one sample to ground truth and build its targets.

        Proposals are decoded (on detached copies, so the network outputs are
        untouched), matched with ``self.bbox_assigner``, and turned into class
        labels/weights, encoded box targets/weights and a Gaussian heatmap target.

        Returns:
            tuple of ``labels``, ``label_weights``, ``bbox_targets``,
            ``bbox_weights`` (each with a leading batch dim of 1), the number of
            positives (int), the mean IoU of positives (float) and the dense
            heatmap target (leading dim 1).
        """
        num_proposals = preds_dict["center"].shape[-1]
        score = copy.deepcopy(preds_dict["heatmap"].detach())
        center = copy.deepcopy(preds_dict["center"].detach())
        height = copy.deepcopy(preds_dict["height"].detach())
        dim = copy.deepcopy(preds_dict["dim"].detach())
        rot = copy.deepcopy(preds_dict["rot"].detach())
        if "vel" in preds_dict.keys():
            vel = copy.deepcopy(preds_dict["vel"].detach())
        else:
            vel = None

        boxes_dict = self.decode_bbox(score, rot, dim, center, height, vel)
        bboxes_tensor = boxes_dict[0]["pred_boxes"]
        gt_bboxes_tensor = gt_bboxes_3d.to(score.device)

        assigned_gt_inds, ious = self.bbox_assigner.assign(
            bboxes_tensor, gt_bboxes_tensor, gt_labels_3d,
            score, self.point_cloud_range,
        )
        # assigned_gt_inds: 0 = negative, k > 0 = matched to gt k - 1.
        pos_inds = torch.nonzero(assigned_gt_inds > 0, as_tuple=False).squeeze(-1).unique()
        neg_inds = torch.nonzero(assigned_gt_inds == 0, as_tuple=False).squeeze(-1).unique()
        pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
        if gt_bboxes_3d.numel() == 0:
            assert pos_inds.numel() == 0
            pos_gt_bboxes = torch.empty_like(gt_bboxes_3d).view(-1, 9)
        else:
            pos_gt_bboxes = gt_bboxes_3d[pos_assigned_gt_inds.long(), :]

        # create target for loss computation
        bbox_targets = torch.zeros([num_proposals, self.code_size]).to(center.device)
        bbox_weights = torch.zeros([num_proposals, self.code_size]).to(center.device)
        ious = torch.clamp(ious, min=0.0, max=1.0)
        labels = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
        label_weights = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)

        if gt_labels_3d is not None:
            # Default label is ``self.num_classes`` (the background index,
            # dropped from the one-hot target in loss()).
            labels += self.num_classes

        # both pos and neg have classification loss, only pos has regression and iou loss
        if len(pos_inds) > 0:
            pos_bbox_targets = self.encode_bbox(pos_gt_bboxes)
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            if gt_labels_3d is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels_3d[pos_assigned_gt_inds]
            label_weights[pos_inds] = 1.0

        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # compute dense heatmap targets
        device = labels.device
        target_assigner_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
        feature_map_size = (self.grid_size[:2] // self.feature_map_stride)
        heatmap = gt_bboxes_3d.new_zeros(self.num_classes, feature_map_size[1], feature_map_size[0])
        for idx in range(len(gt_bboxes_3d)):
            width = gt_bboxes_3d[idx][3]
            length = gt_bboxes_3d[idx][4]
            # Box size in feature-map cells.
            width = width / self.voxel_size[0] / self.feature_map_stride
            length = length / self.voxel_size[1] / self.feature_map_stride
            if width > 0 and length > 0:
                radius = centernet_utils.gaussian_radius(length.view(-1), width.view(-1), target_assigner_cfg.GAUSSIAN_OVERLAP)[0]
                radius = max(target_assigner_cfg.MIN_RADIUS, int(radius))
                x, y = gt_bboxes_3d[idx][0], gt_bboxes_3d[idx][1]

                coor_x = (x - self.point_cloud_range[0]) / self.voxel_size[0] / self.feature_map_stride
                coor_y = (y - self.point_cloud_range[1]) / self.voxel_size[1] / self.feature_map_stride

                center = torch.tensor([coor_x, coor_y], dtype=torch.float32, device=device)
                center_int = center.to(torch.int32)
                centernet_utils.draw_gaussian_to_heatmap(heatmap[gt_labels_3d[idx]], center_int, radius)

        mean_iou = ious[pos_inds].sum() / max(len(pos_inds), 1)
        return (labels[None], label_weights[None], bbox_targets[None], bbox_weights[None], int(pos_inds.shape[0]), float(mean_iou), heatmap[None])

    def loss(self, gt_bboxes_3d, gt_labels_3d, pred_dicts, **kwargs):
        """Compute the weighted heatmap, classification and box-regression losses.

        Returns:
            ``(loss_all, loss_dict)``: the total loss tensor and a dict of scalar
            logging values (``loss_heatmap``, ``loss_cls``, ``loss_bbox``,
            ``matched_ious``, ``loss_trans``).
        """
        labels, label_weights, bbox_targets, bbox_weights, num_pos, matched_ious, heatmap = \
            self.get_targets(gt_bboxes_3d, gt_labels_3d, pred_dicts)
        loss_dict = dict()
        loss_all = 0

        # compute heatmap loss, normalised by the number of heatmap peaks (== 1)
        loss_heatmap = self.loss_heatmap(
            clip_sigmoid(pred_dicts["dense_heatmap"]),
            heatmap,
        ).sum() / max(heatmap.eq(1).float().sum().item(), 1)
        loss_dict["loss_heatmap"] = loss_heatmap.item() * self.loss_heatmap_weight
        loss_all += loss_heatmap * self.loss_heatmap_weight

        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = pred_dicts["heatmap"].permute(0, 2, 1).reshape(-1, self.num_classes)

        # One extra column for the background label, dropped afterwards so that
        # background proposals get an all-zero target.
        one_hot_targets = torch.zeros(*list(labels.shape), self.num_classes + 1, dtype=cls_score.dtype, device=labels.device)
        one_hot_targets.scatter_(-1, labels.unsqueeze(dim=-1).long(), 1.0)
        one_hot_targets = one_hot_targets[..., :-1]
        loss_cls = self.loss_cls(
            cls_score, one_hot_targets, label_weights
        ).sum() / max(num_pos, 1)

        preds = torch.cat([pred_dicts[head_name] for head_name in self.model_cfg.SEPARATE_HEAD_CFG.HEAD_ORDER], dim=1).permute(0, 2, 1)
        code_weights = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['code_weights']
        reg_weights = bbox_weights * bbox_weights.new_tensor(code_weights)
        loss_bbox = self.loss_bbox(preds, bbox_targets)
        loss_bbox = (loss_bbox * reg_weights).sum() / max(num_pos, 1)

        loss_dict["loss_cls"] = loss_cls.item() * self.loss_cls_weight
        loss_dict["loss_bbox"] = loss_bbox.item() * self.loss_bbox_weight
        loss_all = loss_all + loss_cls * self.loss_cls_weight + loss_bbox * self.loss_bbox_weight

        loss_dict["matched_ious"] = loss_cls.new_tensor(matched_ious)
        loss_dict['loss_trans'] = loss_all

        return loss_all, loss_dict

    def encode_bbox(self, bboxes):
        """Encode ground-truth boxes into regression targets.

        Target columns: 0-1 x/y in feature-map cells, 2 the box's column 2 as-is,
        3-5 log of columns 3-5, 6-7 sin/cos of column 6, 8-9 columns 7+
        (presumably velocity).
        """
        code_size = 10
        targets = torch.zeros([bboxes.shape[0], code_size]).to(bboxes.device)
        targets[:, 0] = (bboxes[:, 0] - self.point_cloud_range[0]) / (self.feature_map_stride * self.voxel_size[0])
        targets[:, 1] = (bboxes[:, 1] - self.point_cloud_range[1]) / (self.feature_map_stride * self.voxel_size[1])
        targets[:, 3:6] = bboxes[:, 3:6].log()
        targets[:, 2] = bboxes[:, 2]
        targets[:, 6] = torch.sin(bboxes[:, 6])
        targets[:, 7] = torch.cos(bboxes[:, 6])
        if code_size == 10:
            targets[:, 8:10] = bboxes[:, 7:]
        return targets

    def decode_bbox(self, heatmap, rot, dim, center, height, vel, filter=False):
        """Decode head outputs into boxes, scores and labels per sample.

        ``center`` is mapped from feature-map cells back to metric coordinates,
        ``dim`` is exponentiated and ``rot`` (sin, cos) is converted with atan2.
        The input tensors are not modified.

        Args:
            filter: when not ``False``, drop boxes whose score is not above
                ``SCORE_THRESH`` or whose centre lies outside ``POST_CENTER_RANGE``.

        Returns:
            list (one entry per sample) of dicts with ``pred_boxes``,
            ``pred_scores`` and ``pred_labels`` (0-based class indices).
        """
        post_process_cfg = self.model_cfg.POST_PROCESSING
        score_thresh = post_process_cfg.SCORE_THRESH

        # class label
        final_scores, final_preds = heatmap.max(dim=1)

        # Work on a copy so the caller's ``center`` tensor is not rescaled in place.
        center = center.clone()
        center[:, 0, :] = center[:, 0, :] * self.feature_map_stride * self.voxel_size[0] + self.point_cloud_range[0]
        center[:, 1, :] = center[:, 1, :] * self.feature_map_stride * self.voxel_size[1] + self.point_cloud_range[1]
        dim = dim.exp()
        rots, rotc = rot[:, 0:1, :], rot[:, 1:2, :]
        rot = torch.atan2(rots, rotc)

        if vel is None:
            final_box_preds = torch.cat([center, height, dim, rot], dim=1).permute(0, 2, 1)
        else:
            final_box_preds = torch.cat([center, height, dim, rot, vel], dim=1).permute(0, 2, 1)

        if filter is False:
            return [
                {
                    'pred_boxes': final_box_preds[i],
                    'pred_scores': final_scores[i],
                    'pred_labels': final_preds[i],
                }
                for i in range(heatmap.shape[0])
            ]

        # Build the range on the input's device (not a hard-coded .cuda()), so
        # decoding also works on CPU and on whichever GPU holds the outputs.
        post_center_range = torch.tensor(post_process_cfg.POST_CENTER_RANGE, device=heatmap.device).float()
        thresh_mask = final_scores > score_thresh
        mask = (final_box_preds[..., :3] >= post_center_range[:3]).all(2)
        mask &= (final_box_preds[..., :3] <= post_center_range[3:]).all(2)

        predictions_dicts = []
        for i in range(heatmap.shape[0]):
            cmask = mask[i, :]
            cmask &= thresh_mask[i]

            boxes3d = final_box_preds[i, cmask]
            scores = final_scores[i, cmask]
            labels = final_preds[i, cmask]
            predictions_dict = {
                'pred_boxes': boxes3d,
                'pred_scores': scores,
                'pred_labels': labels,
            }

            predictions_dicts.append(predictions_dict)

        return predictions_dicts

    def get_bboxes(self, preds_dicts):
        """Produce filtered final detections from the outputs of ``predict``.

        The score is ``sigmoid(heatmap) * query_heatmap_score`` restricted to
        each query's selected class (``self.query_labels`` set by ``predict``).
        Returned ``pred_labels`` are shifted to 1-based ints.
        """
        batch_size = preds_dicts["heatmap"].shape[0]
        batch_score = preds_dicts["heatmap"].sigmoid()
        one_hot = F.one_hot(
            self.query_labels, num_classes=self.num_classes
        ).permute(0, 2, 1)
        batch_score = batch_score * preds_dicts["query_heatmap_score"] * one_hot
        batch_center = preds_dicts["center"]
        batch_height = preds_dicts["height"]
        batch_dim = preds_dicts["dim"]
        batch_rot = preds_dicts["rot"]
        batch_vel = None
        if "vel" in preds_dicts:
            batch_vel = preds_dicts["vel"]

        ret_dict = self.decode_bbox(
            batch_score, batch_rot, batch_dim,
            batch_center, batch_height, batch_vel,
            filter=True,
        )
        for k in range(batch_size):
            ret_dict[k]['pred_labels'] = ret_dict[k]['pred_labels'].int() + 1

        return ret_dict