Add File
@@ -0,0 +1,210 @@
import numpy as np
import torch

from ....ops.iou3d_nms import iou3d_nms_utils
from ....utils import box_utils


class AxisAlignedTargetAssigner(object):
    def __init__(self, model_cfg, class_names, box_coder, match_height=False):
        super().__init__()

        anchor_generator_cfg = model_cfg.ANCHOR_GENERATOR_CONFIG
        anchor_target_cfg = model_cfg.TARGET_ASSIGNER_CONFIG
        self.box_coder = box_coder
        self.match_height = match_height
        self.class_names = np.array(class_names)
        self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg]
        self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None
        self.sample_size = anchor_target_cfg.SAMPLE_SIZE
        self.norm_by_num_examples = anchor_target_cfg.NORM_BY_NUM_EXAMPLES
        # Per-class IoU thresholds: anchors at or above matched_threshold become
        # positives, anchors below unmatched_threshold become negatives, and
        # anchors in between are ignored during training.
        self.matched_thresholds = {}
        self.unmatched_thresholds = {}
        for config in anchor_generator_cfg:
            self.matched_thresholds[config['class_name']] = config['matched_threshold']
            self.unmatched_thresholds[config['class_name']] = config['unmatched_threshold']

        self.use_multihead = model_cfg.get('USE_MULTIHEAD', False)
        # self.separate_multihead = model_cfg.get('SEPARATE_MULTIHEAD', False)
        # if self.separate_multihead:
        #     rpn_head_cfgs = model_cfg.RPN_HEAD_CFGS
        #     self.gt_remapping = {}
        #     for rpn_head_cfg in rpn_head_cfgs:
        #         for idx, name in enumerate(rpn_head_cfg['HEAD_CLS_NAME']):
        #             self.gt_remapping[name] = idx + 1

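    # A minimal sketch of the config entries this assigner consumes; the class
    # names and threshold values below are illustrative assumptions only:
    #
    #   anchor_generator_cfg = [
    #       {'class_name': 'Car', 'matched_threshold': 0.6, 'unmatched_threshold': 0.45},
    #       {'class_name': 'Pedestrian', 'matched_threshold': 0.5, 'unmatched_threshold': 0.35},
    #   ]
    #
    # giving self.matched_thresholds == {'Car': 0.6, 'Pedestrian': 0.5} and
    # self.unmatched_thresholds == {'Car': 0.45, 'Pedestrian': 0.35}.
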
    def assign_targets(self, all_anchors, gt_boxes_with_classes):
        """
        Args:
            all_anchors: [(N, 7), ...], one anchor tensor per anchor class
            gt_boxes_with_classes: (B, M, 8), box parameters with the class index in the last column
        Returns:
            all_targets_dict: batched 'box_cls_labels', 'box_reg_targets' and 'reg_weights'
        """

        bbox_targets = []
        cls_labels = []
        reg_weights = []

        batch_size = gt_boxes_with_classes.shape[0]
        gt_classes = gt_boxes_with_classes[:, :, -1]
        gt_boxes = gt_boxes_with_classes[:, :, :-1]
        for k in range(batch_size):
            cur_gt = gt_boxes[k]
            # Drop the all-zero rows that pad each sample up to M boxes.
            cnt = len(cur_gt) - 1
            while cnt > 0 and cur_gt[cnt].sum() == 0:
                cnt -= 1
            cur_gt = cur_gt[:cnt + 1]
            cur_gt_classes = gt_classes[k][:cnt + 1].int()

            target_list = []
            for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):
                # Only GT boxes of this anchor's class take part in the matching.
                if cur_gt_classes.shape[0] > 1:
                    mask = torch.from_numpy(self.class_names[cur_gt_classes.cpu() - 1] == anchor_class_name)
                else:
                    mask = torch.tensor([self.class_names[c - 1] == anchor_class_name
                                         for c in cur_gt_classes], dtype=torch.bool)

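                # For instance (values assumed for illustration): with
                # class_names = ['Car', 'Pedestrian'] and cur_gt_classes = [1, 2, 1],
                # anchor_class_name 'Car' gives mask = [True, False, True].
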
                if self.use_multihead:
                    anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
                    # if self.separate_multihead:
                    #     selected_classes = cur_gt_classes[mask].clone()
                    #     if len(selected_classes) > 0:
                    #         new_cls_id = self.gt_remapping[anchor_class_name]
                    #         selected_classes[:] = new_cls_id
                    # else:
                    #     selected_classes = cur_gt_classes[mask]
                    selected_classes = cur_gt_classes[mask]
                else:
                    feature_map_size = anchors.shape[:3]
                    anchors = anchors.view(-1, anchors.shape[-1])
                    selected_classes = cur_gt_classes[mask]

                single_target = self.assign_targets_single(
                    anchors,
                    cur_gt[mask],
                    gt_classes=selected_classes,
                    matched_threshold=self.matched_thresholds[anchor_class_name],
                    unmatched_threshold=self.unmatched_thresholds[anchor_class_name]
                )
                target_list.append(single_target)

            if self.use_multihead:
                target_dict = {
                    'box_cls_labels': [t['box_cls_labels'].view(-1) for t in target_list],
                    'box_reg_targets': [t['box_reg_targets'].view(-1, self.box_coder.code_size) for t in target_list],
                    'reg_weights': [t['reg_weights'].view(-1) for t in target_list]
                }

                target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=0)
                target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=0).view(-1)
                target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=0).view(-1)
            else:
                target_dict = {
                    'box_cls_labels': [t['box_cls_labels'].view(*feature_map_size, -1) for t in target_list],
                    'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size)
                                        for t in target_list],
                    'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
                }
                target_dict['box_reg_targets'] = torch.cat(
                    target_dict['box_reg_targets'], dim=-2
                ).view(-1, self.box_coder.code_size)

                target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
                target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)

            bbox_targets.append(target_dict['box_reg_targets'])
            cls_labels.append(target_dict['box_cls_labels'])
            reg_weights.append(target_dict['reg_weights'])

        bbox_targets = torch.stack(bbox_targets, dim=0)

        cls_labels = torch.stack(cls_labels, dim=0)
        reg_weights = torch.stack(reg_weights, dim=0)
        all_targets_dict = {
            'box_cls_labels': cls_labels,
            'box_reg_targets': bbox_targets,
            'reg_weights': reg_weights
        }
        return all_targets_dict

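    # A rough usage sketch, shapes only. `anchor_list`, `gt`, `model_cfg` and
    # `box_coder` are assumed to exist elsewhere; they are not defined in this file:
    #
    #   assigner = AxisAlignedTargetAssigner(model_cfg, ['Car'], box_coder)
    #   targets = assigner.assign_targets(anchor_list, gt)    # gt: (B, M, 8)
    #   targets['box_cls_labels'].shape    # (B, num_anchors)
    #   targets['box_reg_targets'].shape   # (B, num_anchors, box_coder.code_size)
    #   targets['reg_weights'].shape       # (B, num_anchors)
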
    def assign_targets_single(self, anchors, gt_boxes, gt_classes, matched_threshold=0.6, unmatched_threshold=0.45):
        num_anchors = anchors.shape[0]
        num_gt = gt_boxes.shape[0]

        # -1 marks an anchor as "ignored" until it is assigned as foreground or background.
        labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
        gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1

        if len(gt_boxes) > 0 and anchors.shape[0] > 0:
            anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
                if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])

            # NOTE: The speed of these two versions depends on the environment and the number of anchors
            # anchor_to_gt_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=1)).cuda()
            anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
            anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors, device=anchors.device), anchor_to_gt_argmax]

            # gt_to_anchor_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=0)).cuda()
            gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)
            gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, torch.arange(num_gt, device=anchors.device)]
            empty_gt_mask = gt_to_anchor_max == 0
            gt_to_anchor_max[empty_gt_mask] = -1

            # Force-match every GT box to its best-overlapping anchor so that no GT
            # is left unassigned, even if that overlap is below matched_threshold.
            anchors_with_max_overlap = (anchor_by_gt_overlap == gt_to_anchor_max).nonzero()[:, 0]
            gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
            labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
            gt_ids[anchors_with_max_overlap] = gt_inds_force.int()

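            # A toy illustration of the matching above (overlap values made up):
            # overlaps = [[0.7, 0.1],
            #             [0.2, 0.3],
            #             [0.0, 0.5]]          # 3 anchors x 2 GT boxes
            # anchor_to_gt_argmax = [0, 1, 1], anchor_to_gt_max = [0.7, 0.3, 0.5],
            # gt_to_anchor_max = [0.7, 0.5], so anchors 0 and 2 are force-matched
            # to GT 0 and GT 1 respectively.
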
            # Anchors whose best overlap clears matched_threshold become positives;
            # those below unmatched_threshold become background candidates.
            pos_inds = anchor_to_gt_max >= matched_threshold
            gt_inds_over_thresh = anchor_to_gt_argmax[pos_inds]
            labels[pos_inds] = gt_classes[gt_inds_over_thresh]
            gt_ids[pos_inds] = gt_inds_over_thresh.int()
            bg_inds = (anchor_to_gt_max < unmatched_threshold).nonzero()[:, 0]
        else:
            bg_inds = torch.arange(num_anchors, device=anchors.device)

        fg_inds = (labels > 0).nonzero()[:, 0]

        if self.pos_fraction is not None:
            # Subsample foreground and background anchors to at most sample_size in total.
            num_fg = int(self.pos_fraction * self.sample_size)
            if len(fg_inds) > num_fg:
                num_disabled = len(fg_inds) - num_fg
                # The permutation indexes into fg_inds, so map it back to anchor indices
                # before disabling; otherwise random (mostly background) anchors would be
                # ignored instead of the excess foreground ones.
                disable_inds = fg_inds[torch.randperm(len(fg_inds))[:num_disabled]]
                labels[disable_inds] = -1
                fg_inds = (labels > 0).nonzero()[:, 0]

            num_bg = self.sample_size - (labels > 0).sum()
            if len(bg_inds) > num_bg:
                enable_inds = bg_inds[torch.randint(0, len(bg_inds), size=(num_bg,))]
                labels[enable_inds] = 0
            # bg_inds = torch.nonzero(labels == 0)[:, 0]
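            # For example, with pos_fraction=0.5 and sample_size=512 (typical values,
            # assumed here), num_fg = 256; if only 40 positives survive the matching,
            # num_bg = 512 - 40 = 472 background anchors are sampled from bg_inds.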
        else:
            if len(gt_boxes) == 0 or anchors.shape[0] == 0:
                labels[:] = 0
            else:
                labels[bg_inds] = 0
                # Re-assert the forced GT matches in case a force-matched anchor also
                # fell below unmatched_threshold and was just marked as background.
                labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]

        bbox_targets = anchors.new_zeros((num_anchors, self.box_coder.code_size))
        if len(gt_boxes) > 0 and anchors.shape[0] > 0:
            # Encode regression targets only for foreground anchors, each against its matched GT box.
            fg_gt_boxes = gt_boxes[anchor_to_gt_argmax[fg_inds], :]
            fg_anchors = anchors[fg_inds, :]
            bbox_targets[fg_inds, :] = self.box_coder.encode_torch(fg_gt_boxes, fg_anchors)
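            # For example, with a typical ResidualCoder (code_size = 7; an assumption
            # about the box_coder passed in), these targets are the center offsets,
            # log size ratios and heading residual of each GT box w.r.t. its anchor.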

        reg_weights = anchors.new_zeros((num_anchors,))

        # Optionally normalize the regression weights by the number of non-ignored anchors.
        if self.norm_by_num_examples:
            num_examples = (labels >= 0).sum()
            num_examples = num_examples if num_examples > 1.0 else 1.0
            reg_weights[labels > 0] = 1.0 / num_examples
        else:
            reg_weights[labels > 0] = 1.0

        ret_dict = {
            'box_cls_labels': labels,
            'box_reg_targets': bbox_targets,
            'reg_weights': reg_weights,
        }
        return ret_dict
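
    # Example of the normalization above (counts assumed): with
    # norm_by_num_examples=True and 100 non-ignored anchors (labels >= 0), each
    # positive anchor receives reg_weight = 1/100; otherwise every positive
    # anchor simply gets weight 1.0.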