From c22d7b90e048aadda3c9e22d7418aedc56f07110 Mon Sep 17 00:00:00 2001 From: inter Date: Sun, 21 Sep 2025 20:19:03 +0800 Subject: [PATCH] Add File --- .../target_assigner/proposal_target_layer.py | 228 ++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 pcdet/models/roi_heads/target_assigner/proposal_target_layer.py diff --git a/pcdet/models/roi_heads/target_assigner/proposal_target_layer.py b/pcdet/models/roi_heads/target_assigner/proposal_target_layer.py new file mode 100644 index 0000000..49f5f0a --- /dev/null +++ b/pcdet/models/roi_heads/target_assigner/proposal_target_layer.py @@ -0,0 +1,228 @@ +import numpy as np +import torch +import torch.nn as nn + +from ....ops.iou3d_nms import iou3d_nms_utils + + +class ProposalTargetLayer(nn.Module): + def __init__(self, roi_sampler_cfg): + super().__init__() + self.roi_sampler_cfg = roi_sampler_cfg + + def forward(self, batch_dict): + """ + Args: + batch_dict: + batch_size: + rois: (B, num_rois, 7 + C) + roi_scores: (B, num_rois) + gt_boxes: (B, N, 7 + C + 1) + roi_labels: (B, num_rois) + Returns: + batch_dict: + rois: (B, M, 7 + C) + gt_of_rois: (B, M, 7 + C) + gt_iou_of_rois: (B, M) + roi_scores: (B, M) + roi_labels: (B, M) + reg_valid_mask: (B, M) + rcnn_cls_labels: (B, M) + """ + batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels = self.sample_rois_for_rcnn( + batch_dict=batch_dict + ) + # regression valid mask + reg_valid_mask = (batch_roi_ious > self.roi_sampler_cfg.REG_FG_THRESH).long() + + # classification label + if self.roi_sampler_cfg.CLS_SCORE_TYPE == 'cls': + batch_cls_labels = (batch_roi_ious > self.roi_sampler_cfg.CLS_FG_THRESH).long() + ignore_mask = (batch_roi_ious > self.roi_sampler_cfg.CLS_BG_THRESH) & \ + (batch_roi_ious < self.roi_sampler_cfg.CLS_FG_THRESH) + batch_cls_labels[ignore_mask > 0] = -1 + elif self.roi_sampler_cfg.CLS_SCORE_TYPE == 'roi_iou': + iou_bg_thresh = self.roi_sampler_cfg.CLS_BG_THRESH + iou_fg_thresh = self.roi_sampler_cfg.CLS_FG_THRESH + fg_mask = batch_roi_ious > iou_fg_thresh + bg_mask = batch_roi_ious < iou_bg_thresh + interval_mask = (fg_mask == 0) & (bg_mask == 0) + + batch_cls_labels = (fg_mask > 0).float() + batch_cls_labels[interval_mask] = \ + (batch_roi_ious[interval_mask] - iou_bg_thresh) / (iou_fg_thresh - iou_bg_thresh) + else: + raise NotImplementedError + + targets_dict = {'rois': batch_rois, 'gt_of_rois': batch_gt_of_rois, 'gt_iou_of_rois': batch_roi_ious, + 'roi_scores': batch_roi_scores, 'roi_labels': batch_roi_labels, + 'reg_valid_mask': reg_valid_mask, + 'rcnn_cls_labels': batch_cls_labels} + + return targets_dict + + def sample_rois_for_rcnn(self, batch_dict): + """ + Args: + batch_dict: + batch_size: + rois: (B, num_rois, 7 + C) + roi_scores: (B, num_rois) + gt_boxes: (B, N, 7 + C + 1) + roi_labels: (B, num_rois) + Returns: + + """ + batch_size = batch_dict['batch_size'] + rois = batch_dict['rois'] + roi_scores = batch_dict['roi_scores'] + roi_labels = batch_dict['roi_labels'] + gt_boxes = batch_dict['gt_boxes'] + + code_size = rois.shape[-1] + batch_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size) + batch_gt_of_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size + 1) + batch_roi_ious = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE) + batch_roi_scores = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE) + batch_roi_labels = rois.new_zeros((batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE), dtype=torch.long) + + for index in range(batch_size): + cur_roi, cur_gt, cur_roi_labels, cur_roi_scores = \ + rois[index], gt_boxes[index], roi_labels[index], roi_scores[index] + k = cur_gt.__len__() - 1 + while k >= 0 and cur_gt[k].sum() == 0: + k -= 1 + cur_gt = cur_gt[:k + 1] + cur_gt = cur_gt.new_zeros((1, cur_gt.shape[1])) if len(cur_gt) == 0 else cur_gt + + if self.roi_sampler_cfg.get('SAMPLE_ROI_BY_EACH_CLASS', False): + max_overlaps, gt_assignment = self.get_max_iou_with_same_class( + rois=cur_roi, roi_labels=cur_roi_labels, + gt_boxes=cur_gt[:, 0:7], gt_labels=cur_gt[:, -1].long() + ) + else: + iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi, cur_gt[:, 0:7]) # (M, N) + max_overlaps, gt_assignment = torch.max(iou3d, dim=1) + + sampled_inds = self.subsample_rois(max_overlaps=max_overlaps) + + batch_rois[index] = cur_roi[sampled_inds] + batch_roi_labels[index] = cur_roi_labels[sampled_inds] + batch_roi_ious[index] = max_overlaps[sampled_inds] + batch_roi_scores[index] = cur_roi_scores[sampled_inds] + batch_gt_of_rois[index] = cur_gt[gt_assignment[sampled_inds]] + + return batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels + + def subsample_rois(self, max_overlaps): + # sample fg, easy_bg, hard_bg + fg_rois_per_image = int(np.round(self.roi_sampler_cfg.FG_RATIO * self.roi_sampler_cfg.ROI_PER_IMAGE)) + fg_thresh = min(self.roi_sampler_cfg.REG_FG_THRESH, self.roi_sampler_cfg.CLS_FG_THRESH) + + fg_inds = ((max_overlaps >= fg_thresh)).nonzero().view(-1) + easy_bg_inds = ((max_overlaps < self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1) + hard_bg_inds = ((max_overlaps < self.roi_sampler_cfg.REG_FG_THRESH) & + (max_overlaps >= self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1) + + fg_num_rois = fg_inds.numel() + bg_num_rois = hard_bg_inds.numel() + easy_bg_inds.numel() + + if fg_num_rois > 0 and bg_num_rois > 0: + # sampling fg + fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois) + + rand_num = torch.from_numpy(np.random.permutation(fg_num_rois)).type_as(max_overlaps).long() + fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]] + + # sampling bg + bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE - fg_rois_per_this_image + bg_inds = self.sample_bg_inds( + hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO + ) + + elif fg_num_rois > 0 and bg_num_rois == 0: + # sampling fg + rand_num = np.floor(np.random.rand(self.roi_sampler_cfg.ROI_PER_IMAGE) * fg_num_rois) + rand_num = torch.from_numpy(rand_num).type_as(max_overlaps).long() + fg_inds = fg_inds[rand_num] + bg_inds = fg_inds[fg_inds < 0] # yield empty tensor + + elif bg_num_rois > 0 and fg_num_rois == 0: + # sampling bg + bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE + bg_inds = self.sample_bg_inds( + hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO + ) + else: + print('maxoverlaps:(min=%f, max=%f)' % (max_overlaps.min().item(), max_overlaps.max().item())) + print('ERROR: FG=%d, BG=%d' % (fg_num_rois, bg_num_rois)) + raise NotImplementedError + + sampled_inds = torch.cat((fg_inds, bg_inds), dim=0) + return sampled_inds + + @staticmethod + def sample_bg_inds(hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, hard_bg_ratio): + if hard_bg_inds.numel() > 0 and easy_bg_inds.numel() > 0: + hard_bg_rois_num = min(int(bg_rois_per_this_image * hard_bg_ratio), len(hard_bg_inds)) + easy_bg_rois_num = bg_rois_per_this_image - hard_bg_rois_num + + # sampling hard bg + rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long() + hard_bg_inds = hard_bg_inds[rand_idx] + + # sampling easy bg + rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long() + easy_bg_inds = easy_bg_inds[rand_idx] + + bg_inds = torch.cat([hard_bg_inds, easy_bg_inds], dim=0) + elif hard_bg_inds.numel() > 0 and easy_bg_inds.numel() == 0: + hard_bg_rois_num = bg_rois_per_this_image + # sampling hard bg + rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long() + bg_inds = hard_bg_inds[rand_idx] + elif hard_bg_inds.numel() == 0 and easy_bg_inds.numel() > 0: + easy_bg_rois_num = bg_rois_per_this_image + # sampling easy bg + rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long() + bg_inds = easy_bg_inds[rand_idx] + else: + raise NotImplementedError + + return bg_inds + + @staticmethod + def get_max_iou_with_same_class(rois, roi_labels, gt_boxes, gt_labels): + """ + Args: + rois: (N, 7) + roi_labels: (N) + gt_boxes: (N, ) + gt_labels: + + Returns: + + """ + """ + :param rois: (N, 7) + :param roi_labels: (N) + :param gt_boxes: (N, 8) + :return: + """ + max_overlaps = rois.new_zeros(rois.shape[0]) + gt_assignment = roi_labels.new_zeros(roi_labels.shape[0]) + + for k in range(gt_labels.min().item(), gt_labels.max().item() + 1): + roi_mask = (roi_labels == k) + gt_mask = (gt_labels == k) + if roi_mask.sum() > 0 and gt_mask.sum() > 0: + cur_roi = rois[roi_mask] + cur_gt = gt_boxes[gt_mask] + original_gt_assignment = gt_mask.nonzero().view(-1) + + iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi[:, :7], cur_gt[:, :7]) # (M, N) + cur_max_overlaps, cur_gt_assignment = torch.max(iou3d, dim=1) + max_overlaps[roi_mask] = cur_max_overlaps + gt_assignment[roi_mask] = original_gt_assignment[cur_gt_assignment] + + return max_overlaps, gt_assignment