Add File
This commit is contained in:
@@ -0,0 +1,210 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ....ops.iou3d_nms import iou3d_nms_utils
|
||||
from ....utils import box_utils
|
||||
|
||||
|
||||
class AxisAlignedTargetAssigner(object):
|
||||
def __init__(self, model_cfg, class_names, box_coder, match_height=False):
|
||||
super().__init__()
|
||||
|
||||
anchor_generator_cfg = model_cfg.ANCHOR_GENERATOR_CONFIG
|
||||
anchor_target_cfg = model_cfg.TARGET_ASSIGNER_CONFIG
|
||||
self.box_coder = box_coder
|
||||
self.match_height = match_height
|
||||
self.class_names = np.array(class_names)
|
||||
self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg]
|
||||
self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None
|
||||
self.sample_size = anchor_target_cfg.SAMPLE_SIZE
|
||||
self.norm_by_num_examples = anchor_target_cfg.NORM_BY_NUM_EXAMPLES
|
||||
self.matched_thresholds = {}
|
||||
self.unmatched_thresholds = {}
|
||||
for config in anchor_generator_cfg:
|
||||
self.matched_thresholds[config['class_name']] = config['matched_threshold']
|
||||
self.unmatched_thresholds[config['class_name']] = config['unmatched_threshold']
|
||||
|
||||
self.use_multihead = model_cfg.get('USE_MULTIHEAD', False)
|
||||
# self.separate_multihead = model_cfg.get('SEPARATE_MULTIHEAD', False)
|
||||
# if self.seperate_multihead:
|
||||
# rpn_head_cfgs = model_cfg.RPN_HEAD_CFGS
|
||||
# self.gt_remapping = {}
|
||||
# for rpn_head_cfg in rpn_head_cfgs:
|
||||
# for idx, name in enumerate(rpn_head_cfg['HEAD_CLS_NAME']):
|
||||
# self.gt_remapping[name] = idx + 1
|
||||
|
||||
def assign_targets(self, all_anchors, gt_boxes_with_classes):
|
||||
"""
|
||||
Args:
|
||||
all_anchors: [(N, 7), ...]
|
||||
gt_boxes: (B, M, 8)
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
bbox_targets = []
|
||||
cls_labels = []
|
||||
reg_weights = []
|
||||
|
||||
batch_size = gt_boxes_with_classes.shape[0]
|
||||
gt_classes = gt_boxes_with_classes[:, :, -1]
|
||||
gt_boxes = gt_boxes_with_classes[:, :, :-1]
|
||||
for k in range(batch_size):
|
||||
cur_gt = gt_boxes[k]
|
||||
cnt = cur_gt.__len__() - 1
|
||||
while cnt > 0 and cur_gt[cnt].sum() == 0:
|
||||
cnt -= 1
|
||||
cur_gt = cur_gt[:cnt + 1]
|
||||
cur_gt_classes = gt_classes[k][:cnt + 1].int()
|
||||
|
||||
target_list = []
|
||||
for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):
|
||||
if cur_gt_classes.shape[0] > 1:
|
||||
mask = torch.from_numpy(self.class_names[cur_gt_classes.cpu() - 1] == anchor_class_name)
|
||||
else:
|
||||
mask = torch.tensor([self.class_names[c - 1] == anchor_class_name
|
||||
for c in cur_gt_classes], dtype=torch.bool)
|
||||
|
||||
if self.use_multihead:
|
||||
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
|
||||
# if self.seperate_multihead:
|
||||
# selected_classes = cur_gt_classes[mask].clone()
|
||||
# if len(selected_classes) > 0:
|
||||
# new_cls_id = self.gt_remapping[anchor_class_name]
|
||||
# selected_classes[:] = new_cls_id
|
||||
# else:
|
||||
# selected_classes = cur_gt_classes[mask]
|
||||
selected_classes = cur_gt_classes[mask]
|
||||
else:
|
||||
feature_map_size = anchors.shape[:3]
|
||||
anchors = anchors.view(-1, anchors.shape[-1])
|
||||
selected_classes = cur_gt_classes[mask]
|
||||
|
||||
single_target = self.assign_targets_single(
|
||||
anchors,
|
||||
cur_gt[mask],
|
||||
gt_classes=selected_classes,
|
||||
matched_threshold=self.matched_thresholds[anchor_class_name],
|
||||
unmatched_threshold=self.unmatched_thresholds[anchor_class_name]
|
||||
)
|
||||
target_list.append(single_target)
|
||||
|
||||
if self.use_multihead:
|
||||
target_dict = {
|
||||
'box_cls_labels': [t['box_cls_labels'].view(-1) for t in target_list],
|
||||
'box_reg_targets': [t['box_reg_targets'].view(-1, self.box_coder.code_size) for t in target_list],
|
||||
'reg_weights': [t['reg_weights'].view(-1) for t in target_list]
|
||||
}
|
||||
|
||||
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=0)
|
||||
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=0).view(-1)
|
||||
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=0).view(-1)
|
||||
else:
|
||||
target_dict = {
|
||||
'box_cls_labels': [t['box_cls_labels'].view(*feature_map_size, -1) for t in target_list],
|
||||
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size)
|
||||
for t in target_list],
|
||||
'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
|
||||
}
|
||||
target_dict['box_reg_targets'] = torch.cat(
|
||||
target_dict['box_reg_targets'], dim=-2
|
||||
).view(-1, self.box_coder.code_size)
|
||||
|
||||
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
|
||||
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)
|
||||
|
||||
bbox_targets.append(target_dict['box_reg_targets'])
|
||||
cls_labels.append(target_dict['box_cls_labels'])
|
||||
reg_weights.append(target_dict['reg_weights'])
|
||||
|
||||
bbox_targets = torch.stack(bbox_targets, dim=0)
|
||||
|
||||
cls_labels = torch.stack(cls_labels, dim=0)
|
||||
reg_weights = torch.stack(reg_weights, dim=0)
|
||||
all_targets_dict = {
|
||||
'box_cls_labels': cls_labels,
|
||||
'box_reg_targets': bbox_targets,
|
||||
'reg_weights': reg_weights
|
||||
|
||||
}
|
||||
return all_targets_dict
|
||||
|
||||
def assign_targets_single(self, anchors, gt_boxes, gt_classes, matched_threshold=0.6, unmatched_threshold=0.45):
|
||||
|
||||
num_anchors = anchors.shape[0]
|
||||
num_gt = gt_boxes.shape[0]
|
||||
|
||||
labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
|
||||
gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
|
||||
|
||||
if len(gt_boxes) > 0 and anchors.shape[0] > 0:
|
||||
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
|
||||
if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])
|
||||
|
||||
# NOTE: The speed of these two versions depends the environment and the number of anchors
|
||||
# anchor_to_gt_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=1)).cuda()
|
||||
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
|
||||
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors, device=anchors.device), anchor_to_gt_argmax]
|
||||
|
||||
# gt_to_anchor_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=0)).cuda()
|
||||
gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)
|
||||
gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, torch.arange(num_gt, device=anchors.device)]
|
||||
empty_gt_mask = gt_to_anchor_max == 0
|
||||
gt_to_anchor_max[empty_gt_mask] = -1
|
||||
|
||||
anchors_with_max_overlap = (anchor_by_gt_overlap == gt_to_anchor_max).nonzero()[:, 0]
|
||||
gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
|
||||
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
|
||||
gt_ids[anchors_with_max_overlap] = gt_inds_force.int()
|
||||
|
||||
pos_inds = anchor_to_gt_max >= matched_threshold
|
||||
gt_inds_over_thresh = anchor_to_gt_argmax[pos_inds]
|
||||
labels[pos_inds] = gt_classes[gt_inds_over_thresh]
|
||||
gt_ids[pos_inds] = gt_inds_over_thresh.int()
|
||||
bg_inds = (anchor_to_gt_max < unmatched_threshold).nonzero()[:, 0]
|
||||
else:
|
||||
bg_inds = torch.arange(num_anchors, device=anchors.device)
|
||||
|
||||
fg_inds = (labels > 0).nonzero()[:, 0]
|
||||
|
||||
if self.pos_fraction is not None:
|
||||
num_fg = int(self.pos_fraction * self.sample_size)
|
||||
if len(fg_inds) > num_fg:
|
||||
num_disabled = len(fg_inds) - num_fg
|
||||
disable_inds = torch.randperm(len(fg_inds))[:num_disabled]
|
||||
labels[disable_inds] = -1
|
||||
fg_inds = (labels > 0).nonzero()[:, 0]
|
||||
|
||||
num_bg = self.sample_size - (labels > 0).sum()
|
||||
if len(bg_inds) > num_bg:
|
||||
enable_inds = bg_inds[torch.randint(0, len(bg_inds), size=(num_bg,))]
|
||||
labels[enable_inds] = 0
|
||||
# bg_inds = torch.nonzero(labels == 0)[:, 0]
|
||||
else:
|
||||
if len(gt_boxes) == 0 or anchors.shape[0] == 0:
|
||||
labels[:] = 0
|
||||
else:
|
||||
labels[bg_inds] = 0
|
||||
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
|
||||
|
||||
bbox_targets = anchors.new_zeros((num_anchors, self.box_coder.code_size))
|
||||
if len(gt_boxes) > 0 and anchors.shape[0] > 0:
|
||||
fg_gt_boxes = gt_boxes[anchor_to_gt_argmax[fg_inds], :]
|
||||
fg_anchors = anchors[fg_inds, :]
|
||||
bbox_targets[fg_inds, :] = self.box_coder.encode_torch(fg_gt_boxes, fg_anchors)
|
||||
|
||||
reg_weights = anchors.new_zeros((num_anchors,))
|
||||
|
||||
if self.norm_by_num_examples:
|
||||
num_examples = (labels >= 0).sum()
|
||||
num_examples = num_examples if num_examples > 1.0 else 1.0
|
||||
reg_weights[labels > 0] = 1.0 / num_examples
|
||||
else:
|
||||
reg_weights[labels > 0] = 1.0
|
||||
|
||||
ret_dict = {
|
||||
'box_cls_labels': labels,
|
||||
'box_reg_targets': bbox_targets,
|
||||
'reg_weights': reg_weights,
|
||||
}
|
||||
return ret_dict
|
||||
Reference in New Issue
Block a user