This commit is contained in:
2025-09-21 20:19:03 +08:00
parent 33e43a50b6
commit c22d7b90e0

View File

@@ -0,0 +1,228 @@
import numpy as np
import torch
import torch.nn as nn
from ....ops.iou3d_nms import iou3d_nms_utils
class ProposalTargetLayer(nn.Module):
def __init__(self, roi_sampler_cfg):
super().__init__()
self.roi_sampler_cfg = roi_sampler_cfg
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
roi_scores: (B, num_rois)
gt_boxes: (B, N, 7 + C + 1)
roi_labels: (B, num_rois)
Returns:
batch_dict:
rois: (B, M, 7 + C)
gt_of_rois: (B, M, 7 + C)
gt_iou_of_rois: (B, M)
roi_scores: (B, M)
roi_labels: (B, M)
reg_valid_mask: (B, M)
rcnn_cls_labels: (B, M)
"""
batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels = self.sample_rois_for_rcnn(
batch_dict=batch_dict
)
# regression valid mask
reg_valid_mask = (batch_roi_ious > self.roi_sampler_cfg.REG_FG_THRESH).long()
# classification label
if self.roi_sampler_cfg.CLS_SCORE_TYPE == 'cls':
batch_cls_labels = (batch_roi_ious > self.roi_sampler_cfg.CLS_FG_THRESH).long()
ignore_mask = (batch_roi_ious > self.roi_sampler_cfg.CLS_BG_THRESH) & \
(batch_roi_ious < self.roi_sampler_cfg.CLS_FG_THRESH)
batch_cls_labels[ignore_mask > 0] = -1
elif self.roi_sampler_cfg.CLS_SCORE_TYPE == 'roi_iou':
iou_bg_thresh = self.roi_sampler_cfg.CLS_BG_THRESH
iou_fg_thresh = self.roi_sampler_cfg.CLS_FG_THRESH
fg_mask = batch_roi_ious > iou_fg_thresh
bg_mask = batch_roi_ious < iou_bg_thresh
interval_mask = (fg_mask == 0) & (bg_mask == 0)
batch_cls_labels = (fg_mask > 0).float()
batch_cls_labels[interval_mask] = \
(batch_roi_ious[interval_mask] - iou_bg_thresh) / (iou_fg_thresh - iou_bg_thresh)
else:
raise NotImplementedError
targets_dict = {'rois': batch_rois, 'gt_of_rois': batch_gt_of_rois, 'gt_iou_of_rois': batch_roi_ious,
'roi_scores': batch_roi_scores, 'roi_labels': batch_roi_labels,
'reg_valid_mask': reg_valid_mask,
'rcnn_cls_labels': batch_cls_labels}
return targets_dict
def sample_rois_for_rcnn(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
roi_scores: (B, num_rois)
gt_boxes: (B, N, 7 + C + 1)
roi_labels: (B, num_rois)
Returns:
"""
batch_size = batch_dict['batch_size']
rois = batch_dict['rois']
roi_scores = batch_dict['roi_scores']
roi_labels = batch_dict['roi_labels']
gt_boxes = batch_dict['gt_boxes']
code_size = rois.shape[-1]
batch_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size)
batch_gt_of_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size + 1)
batch_roi_ious = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE)
batch_roi_scores = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE)
batch_roi_labels = rois.new_zeros((batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE), dtype=torch.long)
for index in range(batch_size):
cur_roi, cur_gt, cur_roi_labels, cur_roi_scores = \
rois[index], gt_boxes[index], roi_labels[index], roi_scores[index]
k = cur_gt.__len__() - 1
while k >= 0 and cur_gt[k].sum() == 0:
k -= 1
cur_gt = cur_gt[:k + 1]
cur_gt = cur_gt.new_zeros((1, cur_gt.shape[1])) if len(cur_gt) == 0 else cur_gt
if self.roi_sampler_cfg.get('SAMPLE_ROI_BY_EACH_CLASS', False):
max_overlaps, gt_assignment = self.get_max_iou_with_same_class(
rois=cur_roi, roi_labels=cur_roi_labels,
gt_boxes=cur_gt[:, 0:7], gt_labels=cur_gt[:, -1].long()
)
else:
iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi, cur_gt[:, 0:7]) # (M, N)
max_overlaps, gt_assignment = torch.max(iou3d, dim=1)
sampled_inds = self.subsample_rois(max_overlaps=max_overlaps)
batch_rois[index] = cur_roi[sampled_inds]
batch_roi_labels[index] = cur_roi_labels[sampled_inds]
batch_roi_ious[index] = max_overlaps[sampled_inds]
batch_roi_scores[index] = cur_roi_scores[sampled_inds]
batch_gt_of_rois[index] = cur_gt[gt_assignment[sampled_inds]]
return batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels
def subsample_rois(self, max_overlaps):
# sample fg, easy_bg, hard_bg
fg_rois_per_image = int(np.round(self.roi_sampler_cfg.FG_RATIO * self.roi_sampler_cfg.ROI_PER_IMAGE))
fg_thresh = min(self.roi_sampler_cfg.REG_FG_THRESH, self.roi_sampler_cfg.CLS_FG_THRESH)
fg_inds = ((max_overlaps >= fg_thresh)).nonzero().view(-1)
easy_bg_inds = ((max_overlaps < self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1)
hard_bg_inds = ((max_overlaps < self.roi_sampler_cfg.REG_FG_THRESH) &
(max_overlaps >= self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1)
fg_num_rois = fg_inds.numel()
bg_num_rois = hard_bg_inds.numel() + easy_bg_inds.numel()
if fg_num_rois > 0 and bg_num_rois > 0:
# sampling fg
fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois)
rand_num = torch.from_numpy(np.random.permutation(fg_num_rois)).type_as(max_overlaps).long()
fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]]
# sampling bg
bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE - fg_rois_per_this_image
bg_inds = self.sample_bg_inds(
hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO
)
elif fg_num_rois > 0 and bg_num_rois == 0:
# sampling fg
rand_num = np.floor(np.random.rand(self.roi_sampler_cfg.ROI_PER_IMAGE) * fg_num_rois)
rand_num = torch.from_numpy(rand_num).type_as(max_overlaps).long()
fg_inds = fg_inds[rand_num]
bg_inds = fg_inds[fg_inds < 0] # yield empty tensor
elif bg_num_rois > 0 and fg_num_rois == 0:
# sampling bg
bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE
bg_inds = self.sample_bg_inds(
hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO
)
else:
print('maxoverlaps:(min=%f, max=%f)' % (max_overlaps.min().item(), max_overlaps.max().item()))
print('ERROR: FG=%d, BG=%d' % (fg_num_rois, bg_num_rois))
raise NotImplementedError
sampled_inds = torch.cat((fg_inds, bg_inds), dim=0)
return sampled_inds
@staticmethod
def sample_bg_inds(hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, hard_bg_ratio):
if hard_bg_inds.numel() > 0 and easy_bg_inds.numel() > 0:
hard_bg_rois_num = min(int(bg_rois_per_this_image * hard_bg_ratio), len(hard_bg_inds))
easy_bg_rois_num = bg_rois_per_this_image - hard_bg_rois_num
# sampling hard bg
rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long()
hard_bg_inds = hard_bg_inds[rand_idx]
# sampling easy bg
rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long()
easy_bg_inds = easy_bg_inds[rand_idx]
bg_inds = torch.cat([hard_bg_inds, easy_bg_inds], dim=0)
elif hard_bg_inds.numel() > 0 and easy_bg_inds.numel() == 0:
hard_bg_rois_num = bg_rois_per_this_image
# sampling hard bg
rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long()
bg_inds = hard_bg_inds[rand_idx]
elif hard_bg_inds.numel() == 0 and easy_bg_inds.numel() > 0:
easy_bg_rois_num = bg_rois_per_this_image
# sampling easy bg
rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long()
bg_inds = easy_bg_inds[rand_idx]
else:
raise NotImplementedError
return bg_inds
@staticmethod
def get_max_iou_with_same_class(rois, roi_labels, gt_boxes, gt_labels):
"""
Args:
rois: (N, 7)
roi_labels: (N)
gt_boxes: (N, )
gt_labels:
Returns:
"""
"""
:param rois: (N, 7)
:param roi_labels: (N)
:param gt_boxes: (N, 8)
:return:
"""
max_overlaps = rois.new_zeros(rois.shape[0])
gt_assignment = roi_labels.new_zeros(roi_labels.shape[0])
for k in range(gt_labels.min().item(), gt_labels.max().item() + 1):
roi_mask = (roi_labels == k)
gt_mask = (gt_labels == k)
if roi_mask.sum() > 0 and gt_mask.sum() > 0:
cur_roi = rois[roi_mask]
cur_gt = gt_boxes[gt_mask]
original_gt_assignment = gt_mask.nonzero().view(-1)
iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi[:, :7], cur_gt[:, :7]) # (M, N)
cur_max_overlaps, cur_gt_assignment = torch.max(iou3d, dim=1)
max_overlaps[roi_mask] = cur_max_overlaps
gt_assignment[roi_mask] = original_gt_assignment[cur_gt_assignment]
return max_overlaps, gt_assignment