import numpy as np
import torch
import torch.nn as nn

from ...utils import box_coder_utils, common_utils, loss_utils
from .target_assigner.anchor_generator import AnchorGenerator
from .target_assigner.atss_target_assigner import ATSSTargetAssigner
from .target_assigner.axis_aligned_target_assigner import AxisAlignedTargetAssigner


class AnchorHeadTemplate(nn.Module):
    def __init__(self, model_cfg, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training):
        super().__init__()
        self.model_cfg = model_cfg
        self.num_class = num_class
        self.class_names = class_names
        self.predict_boxes_when_training = predict_boxes_when_training
        self.use_multihead = self.model_cfg.get('USE_MULTIHEAD', False)

        anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
        self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)(
            num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6),
            **anchor_target_cfg.get('BOX_CODER_CONFIG', {})
        )
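        # code_size is typically 7 (x, y, z, dx, dy, dz, heading) plus any
        # extra values the chosen box coder appends; it sets anchor_ndim below.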
        anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG
        anchors, self.num_anchors_per_location = self.generate_anchors(
            anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range,
            anchor_ndim=self.box_coder.code_size
        )
        self.anchors = [x.cuda() for x in anchors]
        self.target_assigner = self.get_target_assigner(anchor_target_cfg)

        self.forward_ret_dict = {}
        self.build_losses(self.model_cfg.LOSS_CONFIG)

    @staticmethod
    def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range, anchor_ndim=7):
        anchor_generator = AnchorGenerator(
            anchor_range=point_cloud_range,
            anchor_generator_config=anchor_generator_cfg
        )
        feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg]
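        # Illustrative example: a 432x496 BEV grid with feature_map_stride=2
        # (assumed values) yields a [216, 248] feature map for that config.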
        anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size)

        if anchor_ndim != 7:
            # Pad each anchor with zeros so it matches box codings that carry
            # more than the 7 standard values.
            for idx, anchors in enumerate(anchors_list):
                pad_zeros = anchors.new_zeros([*anchors.shape[0:-1], anchor_ndim - 7])
                new_anchors = torch.cat((anchors, pad_zeros), dim=-1)
                anchors_list[idx] = new_anchors

        return anchors_list, num_anchors_per_location_list

    def get_target_assigner(self, anchor_target_cfg):
        if anchor_target_cfg.NAME == 'ATSS':
            target_assigner = ATSSTargetAssigner(
                topk=anchor_target_cfg.TOPK,
                box_coder=self.box_coder,
                use_multihead=self.use_multihead,
                match_height=anchor_target_cfg.MATCH_HEIGHT
            )
        elif anchor_target_cfg.NAME == 'AxisAlignedTargetAssigner':
            target_assigner = AxisAlignedTargetAssigner(
                model_cfg=self.model_cfg,
                class_names=self.class_names,
                box_coder=self.box_coder,
                match_height=anchor_target_cfg.MATCH_HEIGHT
            )
        else:
            raise NotImplementedError
        return target_assigner

    def build_losses(self, losses_cfg):
        # Focal loss with alpha=0.25, gamma=2.0 (the RetinaNet defaults) for
        # anchor classification.
        self.add_module(
            'cls_loss_func',
            loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
        )
        reg_loss_name = 'WeightedSmoothL1Loss' if losses_cfg.get('REG_LOSS_TYPE', None) is None \
            else losses_cfg.REG_LOSS_TYPE
        self.add_module(
            'reg_loss_func',
            getattr(loss_utils, reg_loss_name)(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights'])
        )
        self.add_module(
            'dir_loss_func',
            loss_utils.WeightedCrossEntropyLoss()
        )

    def assign_targets(self, gt_boxes):
        """
        Args:
            gt_boxes: (B, M, 8)
        Returns:
            targets_dict: dict holding the assigned classification labels and
                box regression targets for every anchor
        """
        targets_dict = self.target_assigner.assign_targets(
            self.anchors, gt_boxes
        )
        return targets_dict

    def get_cls_layer_loss(self):
        cls_preds = self.forward_ret_dict['cls_preds']
        box_cls_labels = self.forward_ret_dict['box_cls_labels']
        batch_size = int(cls_preds.shape[0])
        cared = box_cls_labels >= 0  # [N, num_anchors]
        positives = box_cls_labels > 0
        negatives = box_cls_labels == 0
        negative_cls_weights = negatives * 1.0
        cls_weights = (negative_cls_weights + 1.0 * positives).float()
        reg_weights = positives.float()
        if self.num_class == 1:
            # class agnostic
            box_cls_labels[positives] = 1

        pos_normalizer = positives.sum(1, keepdim=True).float()
        reg_weights /= torch.clamp(pos_normalizer, min=1.0)
        cls_weights /= torch.clamp(pos_normalizer, min=1.0)
        cls_targets = box_cls_labels * cared.type_as(box_cls_labels)
        one_hot_targets = torch.zeros(
            *list(cls_targets.shape), self.num_class + 1, dtype=cls_preds.dtype, device=cls_targets.device
        )
        one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
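        # Illustrative example: with num_class=3, a label of 2 becomes the
        # one-hot row [0, 0, 1, 0] over (background + 3 classes); slicing off
        # column 0 below leaves [0, 1, 0].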
        cls_preds = cls_preds.view(batch_size, -1, self.num_class)
        one_hot_targets = one_hot_targets[..., 1:]
        cls_loss_src = self.cls_loss_func(cls_preds, one_hot_targets, weights=cls_weights)  # [N, M]
        cls_loss = cls_loss_src.sum() / batch_size

        cls_loss = cls_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
        tb_dict = {
            'rpn_loss_cls': cls_loss.item()
        }
        return cls_loss, tb_dict

    @staticmethod
    def add_sin_difference(boxes1, boxes2, dim=6):
        # Replace the heading channel of both boxes with sin(a)*cos(b) and
        # cos(a)*sin(b), so the regression loss on their difference acts on
        # sin(a - b) and stays smooth across the +/-pi boundary.
        assert dim != -1
        rad_pred_encoding = torch.sin(boxes1[..., dim:dim + 1]) * torch.cos(boxes2[..., dim:dim + 1])
        rad_tg_encoding = torch.cos(boxes1[..., dim:dim + 1]) * torch.sin(boxes2[..., dim:dim + 1])
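        # Illustrative check: for a = 0.1 + 2*pi and b = 0.1, the encoded
        # difference equals sin(a - b) ~= 0, while the raw difference a - b
        # would be 2*pi.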
        boxes1 = torch.cat([boxes1[..., :dim], rad_pred_encoding, boxes1[..., dim + 1:]], dim=-1)
        boxes2 = torch.cat([boxes2[..., :dim], rad_tg_encoding, boxes2[..., dim + 1:]], dim=-1)
        return boxes1, boxes2

    @staticmethod
    def get_direction_target(anchors, reg_targets, one_hot=True, dir_offset=0, num_bins=2):
        batch_size = reg_targets.shape[0]
        anchors = anchors.view(batch_size, -1, anchors.shape[-1])
        # Recover the absolute ground-truth heading, shift it by dir_offset
        # into [0, 2*pi), and bucket it into num_bins equal bins.
        rot_gt = reg_targets[..., 6] + anchors[..., 6]
        offset_rot = common_utils.limit_period(rot_gt - dir_offset, 0, 2 * np.pi)
        dir_cls_targets = torch.floor(offset_rot / (2 * np.pi / num_bins)).long()
        dir_cls_targets = torch.clamp(dir_cls_targets, min=0, max=num_bins - 1)
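        # Illustrative example: with num_bins=2 and dir_offset=pi/4 (assumed
        # config values), headings in [pi/4, pi/4 + pi) land in bin 0 and the
        # opposite half-circle lands in bin 1.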
        if one_hot:
            dir_targets = torch.zeros(*list(dir_cls_targets.shape), num_bins, dtype=anchors.dtype,
                                      device=dir_cls_targets.device)
            dir_targets.scatter_(-1, dir_cls_targets.unsqueeze(dim=-1).long(), 1.0)
            dir_cls_targets = dir_targets
        return dir_cls_targets

    def get_box_reg_layer_loss(self):
        box_preds = self.forward_ret_dict['box_preds']
        box_dir_cls_preds = self.forward_ret_dict.get('dir_cls_preds', None)
        box_reg_targets = self.forward_ret_dict['box_reg_targets']
        box_cls_labels = self.forward_ret_dict['box_cls_labels']
        batch_size = int(box_preds.shape[0])

        positives = box_cls_labels > 0
        reg_weights = positives.float()
        pos_normalizer = positives.sum(1, keepdim=True).float()
        reg_weights /= torch.clamp(pos_normalizer, min=1.0)

        if isinstance(self.anchors, list):
            if self.use_multihead:
                anchors = torch.cat(
                    [anchor.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchor.shape[-1]) for anchor in
                     self.anchors], dim=0)
            else:
                anchors = torch.cat(self.anchors, dim=-3)
        else:
            anchors = self.anchors
        anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
        box_preds = box_preds.view(batch_size, -1,
                                   box_preds.shape[-1] // self.num_anchors_per_location if not self.use_multihead
                                   else box_preds.shape[-1])
        # sin(a - b) = sin(a)cos(b) - cos(a)sin(b)
        box_preds_sin, reg_targets_sin = self.add_sin_difference(box_preds, box_reg_targets)
        loc_loss_src = self.reg_loss_func(box_preds_sin, reg_targets_sin, weights=reg_weights)  # [N, M]
        loc_loss = loc_loss_src.sum() / batch_size

        loc_loss = loc_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['loc_weight']
        box_loss = loc_loss
        tb_dict = {
            'rpn_loss_loc': loc_loss.item()
        }
        if box_dir_cls_preds is not None:
            dir_targets = self.get_direction_target(
                anchors, box_reg_targets,
                dir_offset=self.model_cfg.DIR_OFFSET,
                num_bins=self.model_cfg.NUM_DIR_BINS
            )
            dir_logits = box_dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
            weights = positives.type_as(dir_logits)
            weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)
            dir_loss = self.dir_loss_func(dir_logits, dir_targets, weights=weights)
            dir_loss = dir_loss.sum() / batch_size
            dir_loss = dir_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['dir_weight']
            box_loss += dir_loss
            tb_dict['rpn_loss_dir'] = dir_loss.item()

        return box_loss, tb_dict

    def get_loss(self):
        cls_loss, tb_dict = self.get_cls_layer_loss()
        box_loss, tb_dict_box = self.get_box_reg_layer_loss()
        tb_dict.update(tb_dict_box)
        rpn_loss = cls_loss + box_loss

        tb_dict['rpn_loss'] = rpn_loss.item()
        return rpn_loss, tb_dict

    def generate_predicted_boxes(self, batch_size, cls_preds, box_preds, dir_cls_preds=None):
        """
        Args:
            batch_size:
            cls_preds: (N, H, W, C1)
            box_preds: (N, H, W, C2)
            dir_cls_preds: (N, H, W, C3)
        Returns:
            batch_cls_preds: (B, num_boxes, num_classes)
            batch_box_preds: (B, num_boxes, 7+C)
        """
        if isinstance(self.anchors, list):
            if self.use_multihead:
                anchors = torch.cat([anchor.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchor.shape[-1])
                                     for anchor in self.anchors], dim=0)
            else:
                anchors = torch.cat(self.anchors, dim=-3)
        else:
            anchors = self.anchors
        num_anchors = anchors.view(-1, anchors.shape[-1]).shape[0]
        batch_anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
        batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float() \
            if not isinstance(cls_preds, list) else cls_preds
        batch_box_preds = box_preds.view(batch_size, num_anchors, -1) if not isinstance(box_preds, list) \
            else torch.cat(box_preds, dim=1).view(batch_size, num_anchors, -1)
        batch_box_preds = self.box_coder.decode_torch(batch_box_preds, batch_anchors)
        if dir_cls_preds is not None:
            dir_offset = self.model_cfg.DIR_OFFSET
            dir_limit_offset = self.model_cfg.DIR_LIMIT_OFFSET
            dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1) if not isinstance(dir_cls_preds, list) \
                else torch.cat(dir_cls_preds, dim=1).view(batch_size, num_anchors, -1)
            dir_labels = torch.max(dir_cls_preds, dim=-1)[1]

            period = (2 * np.pi / self.model_cfg.NUM_DIR_BINS)
            dir_rot = common_utils.limit_period(
                batch_box_preds[..., 6] - dir_offset, dir_limit_offset, period
            )
            batch_box_preds[..., 6] = dir_rot + dir_offset + period * dir_labels.to(batch_box_preds.dtype)
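            # The regression branch only determines heading up to pi (via the
            # sin encoding), so the classified direction bin resolves which
            # half-circle the box points into.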
        if isinstance(self.box_coder, box_coder_utils.PreviousResidualDecoder):
            batch_box_preds[..., 6] = common_utils.limit_period(
                -(batch_box_preds[..., 6] + np.pi / 2), offset=0.5, period=np.pi * 2
            )
        return batch_cls_preds, batch_box_preds

    def forward(self, **kwargs):
        raise NotImplementedError
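

# A minimal subclass sketch (illustrative only; the conv layers and the
# 'spatial_features_2d' key are assumptions, not defined by this template):
#
#     class MyAnchorHead(AnchorHeadTemplate):
#         def forward(self, data_dict):
#             feats = data_dict['spatial_features_2d']
#             # (B, C, H, W) -> (B, H, W, C) to match the template's reshapes
#             self.forward_ret_dict['cls_preds'] = self.conv_cls(feats).permute(0, 2, 3, 1)
#             self.forward_ret_dict['box_preds'] = self.conv_box(feats).permute(0, 2, 3, 1)
#             if self.training:
#                 self.forward_ret_dict.update(self.assign_targets(data_dict['gt_boxes']))
#             return data_dict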