import numpy as np import torch import torch.nn as nn from ....ops.iou3d_nms import iou3d_nms_utils class ProposalTargetLayer(nn.Module): def __init__(self, roi_sampler_cfg): super().__init__() self.roi_sampler_cfg = roi_sampler_cfg def forward(self, batch_dict): """ Args: batch_dict: batch_size: rois: (B, num_rois, 7 + C) roi_scores: (B, num_rois) gt_boxes: (B, N, 7 + C + 1) roi_labels: (B, num_rois) Returns: batch_dict: rois: (B, M, 7 + C) gt_of_rois: (B, M, 7 + C) gt_iou_of_rois: (B, M) roi_scores: (B, M) roi_labels: (B, M) reg_valid_mask: (B, M) rcnn_cls_labels: (B, M) """ batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels = self.sample_rois_for_rcnn( batch_dict=batch_dict ) # regression valid mask reg_valid_mask = (batch_roi_ious > self.roi_sampler_cfg.REG_FG_THRESH).long() # classification label if self.roi_sampler_cfg.CLS_SCORE_TYPE == 'cls': batch_cls_labels = (batch_roi_ious > self.roi_sampler_cfg.CLS_FG_THRESH).long() ignore_mask = (batch_roi_ious > self.roi_sampler_cfg.CLS_BG_THRESH) & \ (batch_roi_ious < self.roi_sampler_cfg.CLS_FG_THRESH) batch_cls_labels[ignore_mask > 0] = -1 elif self.roi_sampler_cfg.CLS_SCORE_TYPE == 'roi_iou': iou_bg_thresh = self.roi_sampler_cfg.CLS_BG_THRESH iou_fg_thresh = self.roi_sampler_cfg.CLS_FG_THRESH fg_mask = batch_roi_ious > iou_fg_thresh bg_mask = batch_roi_ious < iou_bg_thresh interval_mask = (fg_mask == 0) & (bg_mask == 0) batch_cls_labels = (fg_mask > 0).float() batch_cls_labels[interval_mask] = \ (batch_roi_ious[interval_mask] - iou_bg_thresh) / (iou_fg_thresh - iou_bg_thresh) else: raise NotImplementedError targets_dict = {'rois': batch_rois, 'gt_of_rois': batch_gt_of_rois, 'gt_iou_of_rois': batch_roi_ious, 'roi_scores': batch_roi_scores, 'roi_labels': batch_roi_labels, 'reg_valid_mask': reg_valid_mask, 'rcnn_cls_labels': batch_cls_labels} return targets_dict def sample_rois_for_rcnn(self, batch_dict): """ Args: batch_dict: batch_size: rois: (B, num_rois, 7 + C) roi_scores: (B, num_rois) gt_boxes: (B, N, 7 + C + 1) roi_labels: (B, num_rois) Returns: """ batch_size = batch_dict['batch_size'] rois = batch_dict['rois'] roi_scores = batch_dict['roi_scores'] roi_labels = batch_dict['roi_labels'] gt_boxes = batch_dict['gt_boxes'] code_size = rois.shape[-1] batch_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size) batch_gt_of_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size + 1) batch_roi_ious = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE) batch_roi_scores = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE) batch_roi_labels = rois.new_zeros((batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE), dtype=torch.long) for index in range(batch_size): cur_roi, cur_gt, cur_roi_labels, cur_roi_scores = \ rois[index], gt_boxes[index], roi_labels[index], roi_scores[index] k = cur_gt.__len__() - 1 while k >= 0 and cur_gt[k].sum() == 0: k -= 1 cur_gt = cur_gt[:k + 1] cur_gt = cur_gt.new_zeros((1, cur_gt.shape[1])) if len(cur_gt) == 0 else cur_gt if self.roi_sampler_cfg.get('SAMPLE_ROI_BY_EACH_CLASS', False): max_overlaps, gt_assignment = self.get_max_iou_with_same_class( rois=cur_roi, roi_labels=cur_roi_labels, gt_boxes=cur_gt[:, 0:7], gt_labels=cur_gt[:, -1].long() ) else: iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi, cur_gt[:, 0:7]) # (M, N) max_overlaps, gt_assignment = torch.max(iou3d, dim=1) sampled_inds = self.subsample_rois(max_overlaps=max_overlaps) batch_rois[index] = cur_roi[sampled_inds] batch_roi_labels[index] = cur_roi_labels[sampled_inds] batch_roi_ious[index] = max_overlaps[sampled_inds] batch_roi_scores[index] = cur_roi_scores[sampled_inds] batch_gt_of_rois[index] = cur_gt[gt_assignment[sampled_inds]] return batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels def subsample_rois(self, max_overlaps): # sample fg, easy_bg, hard_bg fg_rois_per_image = int(np.round(self.roi_sampler_cfg.FG_RATIO * self.roi_sampler_cfg.ROI_PER_IMAGE)) fg_thresh = min(self.roi_sampler_cfg.REG_FG_THRESH, self.roi_sampler_cfg.CLS_FG_THRESH) fg_inds = ((max_overlaps >= fg_thresh)).nonzero().view(-1) easy_bg_inds = ((max_overlaps < self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1) hard_bg_inds = ((max_overlaps < self.roi_sampler_cfg.REG_FG_THRESH) & (max_overlaps >= self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1) fg_num_rois = fg_inds.numel() bg_num_rois = hard_bg_inds.numel() + easy_bg_inds.numel() if fg_num_rois > 0 and bg_num_rois > 0: # sampling fg fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois) rand_num = torch.from_numpy(np.random.permutation(fg_num_rois)).type_as(max_overlaps).long() fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]] # sampling bg bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE - fg_rois_per_this_image bg_inds = self.sample_bg_inds( hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO ) elif fg_num_rois > 0 and bg_num_rois == 0: # sampling fg rand_num = np.floor(np.random.rand(self.roi_sampler_cfg.ROI_PER_IMAGE) * fg_num_rois) rand_num = torch.from_numpy(rand_num).type_as(max_overlaps).long() fg_inds = fg_inds[rand_num] bg_inds = fg_inds[fg_inds < 0] # yield empty tensor elif bg_num_rois > 0 and fg_num_rois == 0: # sampling bg bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE bg_inds = self.sample_bg_inds( hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO ) else: print('maxoverlaps:(min=%f, max=%f)' % (max_overlaps.min().item(), max_overlaps.max().item())) print('ERROR: FG=%d, BG=%d' % (fg_num_rois, bg_num_rois)) raise NotImplementedError sampled_inds = torch.cat((fg_inds, bg_inds), dim=0) return sampled_inds @staticmethod def sample_bg_inds(hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, hard_bg_ratio): if hard_bg_inds.numel() > 0 and easy_bg_inds.numel() > 0: hard_bg_rois_num = min(int(bg_rois_per_this_image * hard_bg_ratio), len(hard_bg_inds)) easy_bg_rois_num = bg_rois_per_this_image - hard_bg_rois_num # sampling hard bg rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long() hard_bg_inds = hard_bg_inds[rand_idx] # sampling easy bg rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long() easy_bg_inds = easy_bg_inds[rand_idx] bg_inds = torch.cat([hard_bg_inds, easy_bg_inds], dim=0) elif hard_bg_inds.numel() > 0 and easy_bg_inds.numel() == 0: hard_bg_rois_num = bg_rois_per_this_image # sampling hard bg rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long() bg_inds = hard_bg_inds[rand_idx] elif hard_bg_inds.numel() == 0 and easy_bg_inds.numel() > 0: easy_bg_rois_num = bg_rois_per_this_image # sampling easy bg rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long() bg_inds = easy_bg_inds[rand_idx] else: raise NotImplementedError return bg_inds @staticmethod def get_max_iou_with_same_class(rois, roi_labels, gt_boxes, gt_labels): """ Args: rois: (N, 7) roi_labels: (N) gt_boxes: (N, ) gt_labels: Returns: """ """ :param rois: (N, 7) :param roi_labels: (N) :param gt_boxes: (N, 8) :return: """ max_overlaps = rois.new_zeros(rois.shape[0]) gt_assignment = roi_labels.new_zeros(roi_labels.shape[0]) for k in range(gt_labels.min().item(), gt_labels.max().item() + 1): roi_mask = (roi_labels == k) gt_mask = (gt_labels == k) if roi_mask.sum() > 0 and gt_mask.sum() > 0: cur_roi = rois[roi_mask] cur_gt = gt_boxes[gt_mask] original_gt_assignment = gt_mask.nonzero().view(-1) iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi[:, :7], cur_gt[:, :7]) # (M, N) cur_max_overlaps, cur_gt_assignment = torch.max(iou3d, dim=1) max_overlaps[roi_mask] = cur_max_overlaps gt_assignment[roi_mask] = original_gt_assignment[cur_gt_assignment] return max_overlaps, gt_assignment