2025-09-21 20:19:07 +08:00
parent af2f146e7e
commit 3a66ba728f


@@ -0,0 +1,210 @@
import numpy as np
import torch
from ....ops.iou3d_nms import iou3d_nms_utils
from ....utils import box_utils
class AxisAlignedTargetAssigner(object):
def __init__(self, model_cfg, class_names, box_coder, match_height=False):
super().__init__()
anchor_generator_cfg = model_cfg.ANCHOR_GENERATOR_CONFIG
anchor_target_cfg = model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = box_coder
self.match_height = match_height
self.class_names = np.array(class_names)
self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg]
self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None
self.sample_size = anchor_target_cfg.SAMPLE_SIZE
self.norm_by_num_examples = anchor_target_cfg.NORM_BY_NUM_EXAMPLES
self.matched_thresholds = {}
self.unmatched_thresholds = {}
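# Per-class IoU thresholds: anchors above matched_threshold become positives, anchors below unmatched_threshold become background.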
for config in anchor_generator_cfg:
self.matched_thresholds[config['class_name']] = config['matched_threshold']
self.unmatched_thresholds[config['class_name']] = config['unmatched_threshold']
self.use_multihead = model_cfg.get('USE_MULTIHEAD', False)
# self.separate_multihead = model_cfg.get('SEPARATE_MULTIHEAD', False)
# if self.separate_multihead:
# rpn_head_cfgs = model_cfg.RPN_HEAD_CFGS
# self.gt_remapping = {}
# for rpn_head_cfg in rpn_head_cfgs:
# for idx, name in enumerate(rpn_head_cfg['HEAD_CLS_NAME']):
# self.gt_remapping[name] = idx + 1
def assign_targets(self, all_anchors, gt_boxes_with_classes):
"""
Args:
all_anchors: list with one anchor tensor per anchor class, each of shape (*feature_map_size, ..., 7)
gt_boxes_with_classes: (B, M, 8), zero-padded GT boxes; the last channel is the 1-based class index
Returns:
all_targets_dict: dict with 'box_cls_labels' (B, num_anchors), 'box_reg_targets' (B, num_anchors, code_size) and 'reg_weights' (B, num_anchors)
"""
bbox_targets = []
cls_labels = []
reg_weights = []
batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, -1]
gt_boxes = gt_boxes_with_classes[:, :, :-1]
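# Targets are assigned per sample. GT boxes are zero-padded to a fixed length M, so trailing all-zero boxes are trimmed before matching.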
for k in range(batch_size):
cur_gt = gt_boxes[k]
cnt = len(cur_gt) - 1
while cnt > 0 and cur_gt[cnt].sum() == 0:
cnt -= 1
cur_gt = cur_gt[:cnt + 1]
cur_gt_classes = gt_classes[k][:cnt + 1].int()
target_list = []
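# Each anchor class is matched only against GT boxes of the same class, using its own thresholds.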
for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):
if cur_gt_classes.shape[0] > 1:
mask = torch.from_numpy(self.class_names[cur_gt_classes.cpu() - 1] == anchor_class_name)
else:
mask = torch.tensor([self.class_names[c - 1] == anchor_class_name
for c in cur_gt_classes], dtype=torch.bool)
if self.use_multihead:
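# Multi-head anchors: bring the size/rotation dims in front of the spatial dims before flattening so the anchor order matches the per-head predictions.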
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
# if self.separate_multihead:
# selected_classes = cur_gt_classes[mask].clone()
# if len(selected_classes) > 0:
# new_cls_id = self.gt_remapping[anchor_class_name]
# selected_classes[:] = new_cls_id
# else:
# selected_classes = cur_gt_classes[mask]
selected_classes = cur_gt_classes[mask]
else:
feature_map_size = anchors.shape[:3]
anchors = anchors.view(-1, anchors.shape[-1])
selected_classes = cur_gt_classes[mask]
single_target = self.assign_targets_single(
anchors,
cur_gt[mask],
gt_classes=selected_classes,
matched_threshold=self.matched_thresholds[anchor_class_name],
unmatched_threshold=self.unmatched_thresholds[anchor_class_name]
)
target_list.append(single_target)
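# Merge the per-class target lists of this sample into flat tensors covering all anchors.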
if self.use_multihead:
target_dict = {
'box_cls_labels': [t['box_cls_labels'].view(-1) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(-1, self.box_coder.code_size) for t in target_list],
'reg_weights': [t['reg_weights'].view(-1) for t in target_list]
}
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=0)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=0).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=0).view(-1)
else:
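# Single-head layout: concatenate along the per-location anchor dim while keeping the spatial layout, then flatten.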
target_dict = {
'box_cls_labels': [t['box_cls_labels'].view(*feature_map_size, -1) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size)
for t in target_list],
'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
}
target_dict['box_reg_targets'] = torch.cat(
target_dict['box_reg_targets'], dim=-2
).view(-1, self.box_coder.code_size)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)
bbox_targets.append(target_dict['box_reg_targets'])
cls_labels.append(target_dict['box_cls_labels'])
reg_weights.append(target_dict['reg_weights'])
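# Stack the per-sample targets into (B, num_anchors, ...) batch tensors.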
bbox_targets = torch.stack(bbox_targets, dim=0)
cls_labels = torch.stack(cls_labels, dim=0)
reg_weights = torch.stack(reg_weights, dim=0)
all_targets_dict = {
'box_cls_labels': cls_labels,
'box_reg_targets': bbox_targets,
'reg_weights': reg_weights
}
return all_targets_dict
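# Assign the GT boxes of a single class to one flattened anchor set.
# labels: -1 = ignored, 0 = background, >0 = class index; gt_ids stores the index of the matched GT box for each anchor.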
def assign_targets_single(self, anchors, gt_boxes, gt_classes, matched_threshold=0.6, unmatched_threshold=0.45):
num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0]
labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
if len(gt_boxes) > 0 and anchors.shape[0] > 0:
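# Anchor/GT overlap: full 3D IoU when matching height, otherwise axis-aligned nearest-BEV IoU.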
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])
# NOTE: The speed of these two versions depends on the environment and the number of anchors
# anchor_to_gt_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=1)).cuda()
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors, device=anchors.device), anchor_to_gt_argmax]
# gt_to_anchor_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=0)).cuda()
gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)
gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, torch.arange(num_gt, device=anchors.device)]
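# Force-match: each GT keeps its best-overlapping anchor(s) as positives even if the IoU is below matched_threshold; GTs with zero overlap everywhere are excluded via the -1 sentinel.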
empty_gt_mask = gt_to_anchor_max == 0
gt_to_anchor_max[empty_gt_mask] = -1
anchors_with_max_overlap = (anchor_by_gt_overlap == gt_to_anchor_max).nonzero()[:, 0]
gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
gt_ids[anchors_with_max_overlap] = gt_inds_force.int()
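# Threshold-match: anchors whose best overlap reaches matched_threshold are also marked positive.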
pos_inds = anchor_to_gt_max >= matched_threshold
gt_inds_over_thresh = anchor_to_gt_argmax[pos_inds]
labels[pos_inds] = gt_classes[gt_inds_over_thresh]
gt_ids[pos_inds] = gt_inds_over_thresh.int()
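# Anchors whose best overlap stays below unmatched_threshold become background candidates; anchors in between remain ignored (-1).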
bg_inds = (anchor_to_gt_max < unmatched_threshold).nonzero()[:, 0]
else:
bg_inds = torch.arange(num_anchors, device=anchors.device)
fg_inds = (labels > 0).nonzero()[:, 0]
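# Optional fixed-size sampling (POS_FRACTION >= 0): cap the number of positives and randomly enable background anchors to fill up SAMPLE_SIZE.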
if self.pos_fraction is not None:
num_fg = int(self.pos_fraction * self.sample_size)
if len(fg_inds) > num_fg:
num_disabled = len(fg_inds) - num_fg
disable_inds = fg_inds[torch.randperm(len(fg_inds))[:num_disabled]]
labels[disable_inds] = -1
fg_inds = (labels > 0).nonzero()[:, 0]
num_bg = self.sample_size - (labels > 0).sum()
if len(bg_inds) > num_bg:
enable_inds = bg_inds[torch.randint(0, len(bg_inds), size=(num_bg,))]
labels[enable_inds] = 0
# bg_inds = torch.nonzero(labels == 0)[:, 0]
else:
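# No sampling: all background candidates get label 0, then the forced matches are re-applied in case a low-overlap forced positive was overwritten.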
if len(gt_boxes) == 0 or anchors.shape[0] == 0:
labels[:] = 0
else:
labels[bg_inds] = 0
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
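# Box regression targets are encoded only for foreground anchors, relative to their matched GT boxes.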
bbox_targets = anchors.new_zeros((num_anchors, self.box_coder.code_size))
if len(gt_boxes) > 0 and anchors.shape[0] > 0:
fg_gt_boxes = gt_boxes[anchor_to_gt_argmax[fg_inds], :]
fg_anchors = anchors[fg_inds, :]
bbox_targets[fg_inds, :] = self.box_coder.encode_torch(fg_gt_boxes, fg_anchors)
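# Regression weights: 1 for each foreground anchor, optionally normalized by the number of non-ignored anchors.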
reg_weights = anchors.new_zeros((num_anchors,))
if self.norm_by_num_examples:
num_examples = (labels >= 0).sum()
num_examples = num_examples if num_examples > 1.0 else 1.0
reg_weights[labels > 0] = 1.0 / num_examples
else:
reg_weights[labels > 0] = 1.0
ret_dict = {
'box_cls_labels': labels,
'box_reg_targets': bbox_targets,
'reg_weights': reg_weights,
}
return ret_dict
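# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hedged example of how this assigner can be driven.
# It assumes a built OpenPCDet installation (the iou3d_nms import above needs
# the compiled ops), EasyDict for the config, and ResidualCoder from
# pcdet.utils.box_coder_utils; the class list, thresholds and tensor shapes
# below are illustrative placeholders, not values from this commit.
#
#   from easydict import EasyDict
#   from pcdet.utils.box_coder_utils import ResidualCoder
#
#   model_cfg = EasyDict({
#       'ANCHOR_GENERATOR_CONFIG': [
#           {'class_name': 'Car', 'matched_threshold': 0.6, 'unmatched_threshold': 0.45},
#       ],
#       'TARGET_ASSIGNER_CONFIG': {
#           'POS_FRACTION': -1.0, 'SAMPLE_SIZE': 512, 'NORM_BY_NUM_EXAMPLES': False,
#       },
#   })
#   assigner = AxisAlignedTargetAssigner(model_cfg, ['Car'], box_coder=ResidualCoder())
#
#   # One anchor tensor per class: (nz, ny, nx, num_sizes, num_rotations, 7).
#   all_anchors = [torch.zeros(1, 200, 176, 1, 2, 7)]
#   # GT boxes zero-padded to M per sample; the last channel is the 1-based class index.
#   gt_boxes_with_classes = torch.zeros(2, 10, 8)
#   gt_boxes_with_classes[:, 0, :] = torch.tensor([10.0, 5.0, -1.0, 3.9, 1.6, 1.56, 0.0, 1.0])
#
#   targets = assigner.assign_targets(all_anchors, gt_boxes_with_classes)
#   # targets['box_cls_labels']:  (B, num_anchors)
#   # targets['box_reg_targets']: (B, num_anchors, code_size)
#   # targets['reg_weights']:     (B, num_anchors)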