This commit is contained in:
2025-09-21 20:19:05 +08:00
parent 5be1bd76ea
commit 1fbdae9961

View File

@@ -0,0 +1,479 @@
import copy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import kaiming_normal_
from ..model_utils.transfusion_utils import clip_sigmoid
from ..model_utils.basic_block_2d import BasicBlock2D
from ..model_utils.transfusion_utils import PositionEmbeddingLearned, TransformerDecoderLayer
from .target_assigner.hungarian_assigner import HungarianAssigner3D
from ...utils import loss_utils
from ..model_utils import centernet_utils
class SeparateHead_Transfusion(nn.Module):
def __init__(self, input_channels, head_channels, kernel_size, sep_head_dict, init_bias=-2.19, use_bias=False):
super().__init__()
self.sep_head_dict = sep_head_dict
for cur_name in self.sep_head_dict:
output_channels = self.sep_head_dict[cur_name]['out_channels']
num_conv = self.sep_head_dict[cur_name]['num_conv']
fc_list = []
for k in range(num_conv - 1):
fc_list.append(nn.Sequential(
nn.Conv1d(input_channels, head_channels, kernel_size, stride=1, padding=kernel_size//2, bias=use_bias),
nn.BatchNorm1d(head_channels),
nn.ReLU()
))
fc_list.append(nn.Conv1d(head_channels, output_channels, kernel_size, stride=1, padding=kernel_size//2, bias=True))
fc = nn.Sequential(*fc_list)
if 'hm' in cur_name:
fc[-1].bias.data.fill_(init_bias)
else:
for m in fc.modules():
if isinstance(m, nn.Conv2d):
kaiming_normal_(m.weight.data)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias, 0)
self.__setattr__(cur_name, fc)
def forward(self, x):
ret_dict = {}
for cur_name in self.sep_head_dict:
ret_dict[cur_name] = self.__getattr__(cur_name)(x)
return ret_dict
class TransFusionHead(nn.Module):
"""
This module implements TransFusionHead.
The code is adapted from https://github.com/mit-han-lab/bevfusion/ with minimal modifications.
"""
def __init__(
self,
model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, voxel_size, predict_boxes_when_training=True,
):
super(TransFusionHead, self).__init__()
self.grid_size = grid_size
self.point_cloud_range = point_cloud_range
self.voxel_size = voxel_size
self.num_classes = num_class
self.model_cfg = model_cfg
self.feature_map_stride = self.model_cfg.TARGET_ASSIGNER_CONFIG.get('FEATURE_MAP_STRIDE', None)
self.dataset_name = self.model_cfg.TARGET_ASSIGNER_CONFIG.get('DATASET', 'nuScenes')
hidden_channel=self.model_cfg.HIDDEN_CHANNEL
self.num_proposals = self.model_cfg.NUM_PROPOSALS
self.bn_momentum = self.model_cfg.BN_MOMENTUM
self.nms_kernel_size = self.model_cfg.NMS_KERNEL_SIZE
num_heads = self.model_cfg.NUM_HEADS
dropout = self.model_cfg.DROPOUT
activation = self.model_cfg.ACTIVATION
ffn_channel = self.model_cfg.FFN_CHANNEL
bias = self.model_cfg.get('USE_BIAS_BEFORE_NORM', False)
loss_cls = self.model_cfg.LOSS_CONFIG.LOSS_CLS
self.use_sigmoid_cls = loss_cls.get("use_sigmoid", False)
if not self.use_sigmoid_cls:
self.num_classes += 1
self.loss_cls = loss_utils.SigmoidFocalClassificationLoss(gamma=loss_cls.gamma,alpha=loss_cls.alpha)
self.loss_cls_weight = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
self.loss_bbox = loss_utils.L1Loss()
self.loss_bbox_weight = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['bbox_weight']
self.loss_heatmap = loss_utils.GaussianFocalLoss()
self.loss_heatmap_weight = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['hm_weight']
self.code_size = 10
# a shared convolution
self.shared_conv = nn.Conv2d(in_channels=input_channels,out_channels=hidden_channel,kernel_size=3,padding=1)
layers = []
layers.append(BasicBlock2D(hidden_channel,hidden_channel, kernel_size=3,padding=1,bias=bias))
layers.append(nn.Conv2d(in_channels=hidden_channel,out_channels=num_class,kernel_size=3,padding=1))
self.heatmap_head = nn.Sequential(*layers)
self.class_encoding = nn.Conv1d(num_class, hidden_channel, 1)
# transformer decoder layers for object query with LiDAR feature
self.decoder = TransformerDecoderLayer(hidden_channel, num_heads, ffn_channel, dropout, activation,
self_posembed=PositionEmbeddingLearned(2, hidden_channel),
cross_posembed=PositionEmbeddingLearned(2, hidden_channel),
)
# Prediction Head
heads = copy.deepcopy(self.model_cfg.SEPARATE_HEAD_CFG.HEAD_DICT)
heads['heatmap'] = dict(out_channels=self.num_classes, num_conv=self.model_cfg.NUM_HM_CONV)
self.prediction_head = SeparateHead_Transfusion(hidden_channel, 64, 1, heads, use_bias=bias)
self.init_weights()
self.bbox_assigner = HungarianAssigner3D(**self.model_cfg.TARGET_ASSIGNER_CONFIG.HUNGARIAN_ASSIGNER)
# Position Embedding for Cross-Attention, which is re-used during training
x_size = self.grid_size[0] // self.feature_map_stride
y_size = self.grid_size[1] // self.feature_map_stride
self.bev_pos = self.create_2D_grid(x_size, y_size)
self.forward_ret_dict = {}
def create_2D_grid(self, x_size, y_size):
meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
# NOTE: modified
batch_x, batch_y = torch.meshgrid(
*[torch.linspace(it[0], it[1], it[2]) for it in meshgrid]
)
batch_x = batch_x + 0.5
batch_y = batch_y + 0.5
coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)[None]
coord_base = coord_base.view(1, 2, -1).permute(0, 2, 1)
return coord_base
def init_weights(self):
# initialize transformer
for m in self.decoder.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)
if hasattr(self, "query"):
nn.init.xavier_normal_(self.query)
self.init_bn_momentum()
def init_bn_momentum(self):
for m in self.modules():
if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
m.momentum = self.bn_momentum
def predict(self, inputs):
batch_size = inputs.shape[0]
lidar_feat = self.shared_conv(inputs)
lidar_feat_flatten = lidar_feat.view(
batch_size, lidar_feat.shape[1], -1
)
bev_pos = self.bev_pos.repeat(batch_size, 1, 1).to(lidar_feat.device)
# query initialization
dense_heatmap = self.heatmap_head(lidar_feat)
heatmap = dense_heatmap.detach().sigmoid()
padding = self.nms_kernel_size // 2
local_max = torch.zeros_like(heatmap)
local_max_inner = F.max_pool2d(
heatmap, kernel_size=self.nms_kernel_size, stride=1, padding=0
)
local_max[:, :, padding:(-padding), padding:(-padding)] = local_max_inner
# for Pedestrian & Traffic_cone in nuScenes
if self.dataset_name == "nuScenes":
local_max[ :, 8, ] = F.max_pool2d(heatmap[:, 8], kernel_size=1, stride=1, padding=0)
local_max[ :, 9, ] = F.max_pool2d(heatmap[:, 9], kernel_size=1, stride=1, padding=0)
# for Pedestrian & Cyclist in Waymo
elif self.dataset_name == "Waymo":
local_max[ :, 1, ] = F.max_pool2d(heatmap[:, 1], kernel_size=1, stride=1, padding=0)
local_max[ :, 2, ] = F.max_pool2d(heatmap[:, 2], kernel_size=1, stride=1, padding=0)
heatmap = heatmap * (heatmap == local_max)
heatmap = heatmap.view(batch_size, heatmap.shape[1], -1)
# top num_proposals among all classes
top_proposals = heatmap.view(batch_size, -1).argsort(dim=-1, descending=True)[
..., : self.num_proposals
]
top_proposals_class = top_proposals // heatmap.shape[-1]
top_proposals_index = top_proposals % heatmap.shape[-1]
query_feat = lidar_feat_flatten.gather(
index=top_proposals_index[:, None, :].expand(-1, lidar_feat_flatten.shape[1], -1),
dim=-1,
)
self.query_labels = top_proposals_class
# add category embedding
one_hot = F.one_hot(top_proposals_class, num_classes=self.num_classes).permute(0, 2, 1)
query_cat_encoding = self.class_encoding(one_hot.float())
query_feat += query_cat_encoding
query_pos = bev_pos.gather(
index=top_proposals_index[:, None, :].permute(0, 2, 1).expand(-1, -1, bev_pos.shape[-1]),
dim=1,
)
# convert to xy
query_pos = query_pos.flip(dims=[-1])
bev_pos = bev_pos.flip(dims=[-1])
query_feat = self.decoder(
query_feat, lidar_feat_flatten, query_pos, bev_pos
)
res_layer = self.prediction_head(query_feat)
res_layer["center"] = res_layer["center"] + query_pos.permute(0, 2, 1)
res_layer["query_heatmap_score"] = heatmap.gather(
index=top_proposals_index[:, None, :].expand(-1, self.num_classes, -1),
dim=-1,
)
res_layer["dense_heatmap"] = dense_heatmap
return res_layer
def forward(self, batch_dict):
feats = batch_dict['spatial_features_2d']
res = self.predict(feats)
if not self.training:
bboxes = self.get_bboxes(res)
batch_dict['final_box_dicts'] = bboxes
else:
gt_boxes = batch_dict['gt_boxes']
gt_bboxes_3d = gt_boxes[...,:-1]
gt_labels_3d = gt_boxes[...,-1].long() - 1
loss, tb_dict = self.loss(gt_bboxes_3d, gt_labels_3d, res)
batch_dict['loss'] = loss
batch_dict['tb_dict'] = tb_dict
return batch_dict
def get_targets(self, gt_bboxes_3d, gt_labels_3d, pred_dicts):
assign_results = []
for batch_idx in range(len(gt_bboxes_3d)):
pred_dict = {}
for key in pred_dicts.keys():
pred_dict[key] = pred_dicts[key][batch_idx : batch_idx + 1]
gt_bboxes = gt_bboxes_3d[batch_idx]
valid_idx = []
# filter empty boxes
for i in range(len(gt_bboxes)):
if gt_bboxes[i][3] > 0 and gt_bboxes[i][4] > 0:
valid_idx.append(i)
assign_result = self.get_targets_single(gt_bboxes[valid_idx], gt_labels_3d[batch_idx][valid_idx], pred_dict)
assign_results.append(assign_result)
res_tuple = tuple(map(list, zip(*assign_results)))
labels = torch.cat(res_tuple[0], dim=0)
label_weights = torch.cat(res_tuple[1], dim=0)
bbox_targets = torch.cat(res_tuple[2], dim=0)
bbox_weights = torch.cat(res_tuple[3], dim=0)
num_pos = np.sum(res_tuple[4])
matched_ious = np.mean(res_tuple[5])
heatmap = torch.cat(res_tuple[6], dim=0)
return labels, label_weights, bbox_targets, bbox_weights, num_pos, matched_ious, heatmap
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d, preds_dict):
num_proposals = preds_dict["center"].shape[-1]
score = copy.deepcopy(preds_dict["heatmap"].detach())
center = copy.deepcopy(preds_dict["center"].detach())
height = copy.deepcopy(preds_dict["height"].detach())
dim = copy.deepcopy(preds_dict["dim"].detach())
rot = copy.deepcopy(preds_dict["rot"].detach())
if "vel" in preds_dict.keys():
vel = copy.deepcopy(preds_dict["vel"].detach())
else:
vel = None
boxes_dict = self.decode_bbox(score, rot, dim, center, height, vel)
bboxes_tensor = boxes_dict[0]["pred_boxes"]
gt_bboxes_tensor = gt_bboxes_3d.to(score.device)
assigned_gt_inds, ious = self.bbox_assigner.assign(
bboxes_tensor, gt_bboxes_tensor, gt_labels_3d,
score, self.point_cloud_range,
)
pos_inds = torch.nonzero(assigned_gt_inds > 0, as_tuple=False).squeeze(-1).unique()
neg_inds = torch.nonzero(assigned_gt_inds == 0, as_tuple=False).squeeze(-1).unique()
pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
if gt_bboxes_3d.numel() == 0:
assert pos_inds.numel() == 0
pos_gt_bboxes = torch.empty_like(gt_bboxes_3d).view(-1, 9)
else:
pos_gt_bboxes = gt_bboxes_3d[pos_assigned_gt_inds.long(), :]
# create target for loss computation
bbox_targets = torch.zeros([num_proposals, self.code_size]).to(center.device)
bbox_weights = torch.zeros([num_proposals, self.code_size]).to(center.device)
ious = torch.clamp(ious, min=0.0, max=1.0)
labels = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
label_weights = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
if gt_labels_3d is not None: # default label is -1
labels += self.num_classes
# both pos and neg have classification loss, only pos has regression and iou loss
if len(pos_inds) > 0:
pos_bbox_targets = self.encode_bbox(pos_gt_bboxes)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
if gt_labels_3d is None:
labels[pos_inds] = 1
else:
labels[pos_inds] = gt_labels_3d[pos_assigned_gt_inds]
label_weights[pos_inds] = 1.0
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# compute dense heatmap targets
device = labels.device
target_assigner_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
feature_map_size = (self.grid_size[:2] // self.feature_map_stride)
heatmap = gt_bboxes_3d.new_zeros(self.num_classes, feature_map_size[1], feature_map_size[0])
for idx in range(len(gt_bboxes_3d)):
width = gt_bboxes_3d[idx][3]
length = gt_bboxes_3d[idx][4]
width = width / self.voxel_size[0] / self.feature_map_stride
length = length / self.voxel_size[1] / self.feature_map_stride
if width > 0 and length > 0:
radius = centernet_utils.gaussian_radius(length.view(-1), width.view(-1), target_assigner_cfg.GAUSSIAN_OVERLAP)[0]
radius = max(target_assigner_cfg.MIN_RADIUS, int(radius))
x, y = gt_bboxes_3d[idx][0], gt_bboxes_3d[idx][1]
coor_x = (x - self.point_cloud_range[0]) / self.voxel_size[0] / self.feature_map_stride
coor_y = (y - self.point_cloud_range[1]) / self.voxel_size[1] / self.feature_map_stride
center = torch.tensor([coor_x, coor_y], dtype=torch.float32, device=device)
center_int = center.to(torch.int32)
centernet_utils.draw_gaussian_to_heatmap(heatmap[gt_labels_3d[idx]], center_int, radius)
mean_iou = ious[pos_inds].sum() / max(len(pos_inds), 1)
return (labels[None], label_weights[None], bbox_targets[None], bbox_weights[None], int(pos_inds.shape[0]), float(mean_iou), heatmap[None])
def loss(self, gt_bboxes_3d, gt_labels_3d, pred_dicts, **kwargs):
labels, label_weights, bbox_targets, bbox_weights, num_pos, matched_ious, heatmap = \
self.get_targets(gt_bboxes_3d, gt_labels_3d, pred_dicts)
loss_dict = dict()
loss_all = 0
# compute heatmap loss
loss_heatmap = self.loss_heatmap(
clip_sigmoid(pred_dicts["dense_heatmap"]),
heatmap,
).sum() / max(heatmap.eq(1).float().sum().item(), 1)
loss_dict["loss_heatmap"] = loss_heatmap.item() * self.loss_heatmap_weight
loss_all += loss_heatmap * self.loss_heatmap_weight
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = pred_dicts["heatmap"].permute(0, 2, 1).reshape(-1, self.num_classes)
one_hot_targets = torch.zeros(*list(labels.shape), self.num_classes+1, dtype=cls_score.dtype, device=labels.device)
one_hot_targets.scatter_(-1, labels.unsqueeze(dim=-1).long(), 1.0)
one_hot_targets = one_hot_targets[..., :-1]
loss_cls = self.loss_cls(
cls_score, one_hot_targets, label_weights
).sum() / max(num_pos, 1)
preds = torch.cat([pred_dicts[head_name] for head_name in self.model_cfg.SEPARATE_HEAD_CFG.HEAD_ORDER], dim=1).permute(0, 2, 1)
code_weights = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['code_weights']
reg_weights = bbox_weights * bbox_weights.new_tensor(code_weights)
loss_bbox = self.loss_bbox(preds, bbox_targets)
loss_bbox = (loss_bbox * reg_weights).sum() / max(num_pos, 1)
loss_dict["loss_cls"] = loss_cls.item() * self.loss_cls_weight
loss_dict["loss_bbox"] = loss_bbox.item() * self.loss_bbox_weight
loss_all = loss_all + loss_cls * self.loss_cls_weight + loss_bbox * self.loss_bbox_weight
loss_dict[f"matched_ious"] = loss_cls.new_tensor(matched_ious)
loss_dict['loss_trans'] = loss_all
return loss_all,loss_dict
def encode_bbox(self, bboxes):
code_size = 10
targets = torch.zeros([bboxes.shape[0], code_size]).to(bboxes.device)
targets[:, 0] = (bboxes[:, 0] - self.point_cloud_range[0]) / (self.feature_map_stride * self.voxel_size[0])
targets[:, 1] = (bboxes[:, 1] - self.point_cloud_range[1]) / (self.feature_map_stride * self.voxel_size[1])
targets[:, 3:6] = bboxes[:, 3:6].log()
targets[:, 2] = bboxes[:, 2]
targets[:, 6] = torch.sin(bboxes[:, 6])
targets[:, 7] = torch.cos(bboxes[:, 6])
if code_size == 10:
targets[:, 8:10] = bboxes[:, 7:]
return targets
def decode_bbox(self, heatmap, rot, dim, center, height, vel, filter=False):
post_process_cfg = self.model_cfg.POST_PROCESSING
score_thresh = post_process_cfg.SCORE_THRESH
post_center_range = post_process_cfg.POST_CENTER_RANGE
post_center_range = torch.tensor(post_center_range).cuda().float()
# class label
final_preds = heatmap.max(1, keepdims=False).indices
final_scores = heatmap.max(1, keepdims=False).values
center[:, 0, :] = center[:, 0, :] * self.feature_map_stride * self.voxel_size[0] + self.point_cloud_range[0]
center[:, 1, :] = center[:, 1, :] * self.feature_map_stride * self.voxel_size[1] + self.point_cloud_range[1]
dim = dim.exp()
rots, rotc = rot[:, 0:1, :], rot[:, 1:2, :]
rot = torch.atan2(rots, rotc)
if vel is None:
final_box_preds = torch.cat([center, height, dim, rot], dim=1).permute(0, 2, 1)
else:
final_box_preds = torch.cat([center, height, dim, rot, vel], dim=1).permute(0, 2, 1)
predictions_dicts = []
for i in range(heatmap.shape[0]):
boxes3d = final_box_preds[i]
scores = final_scores[i]
labels = final_preds[i]
predictions_dict = {
'pred_boxes': boxes3d,
'pred_scores': scores,
'pred_labels': labels
}
predictions_dicts.append(predictions_dict)
if filter is False:
return predictions_dicts
thresh_mask = final_scores > score_thresh
mask = (final_box_preds[..., :3] >= post_center_range[:3]).all(2)
mask &= (final_box_preds[..., :3] <= post_center_range[3:]).all(2)
predictions_dicts = []
for i in range(heatmap.shape[0]):
cmask = mask[i, :]
cmask &= thresh_mask[i]
boxes3d = final_box_preds[i, cmask]
scores = final_scores[i, cmask]
labels = final_preds[i, cmask]
predictions_dict = {
'pred_boxes': boxes3d,
'pred_scores': scores,
'pred_labels': labels,
}
predictions_dicts.append(predictions_dict)
return predictions_dicts
def get_bboxes(self, preds_dicts):
batch_size = preds_dicts["heatmap"].shape[0]
batch_score = preds_dicts["heatmap"].sigmoid()
one_hot = F.one_hot(
self.query_labels, num_classes=self.num_classes
).permute(0, 2, 1)
batch_score = batch_score * preds_dicts["query_heatmap_score"] * one_hot
batch_center = preds_dicts["center"]
batch_height = preds_dicts["height"]
batch_dim = preds_dicts["dim"]
batch_rot = preds_dicts["rot"]
batch_vel = None
if "vel" in preds_dicts:
batch_vel = preds_dicts["vel"]
ret_dict = self.decode_bbox(
batch_score, batch_rot, batch_dim,
batch_center, batch_height, batch_vel,
filter=True,
)
for k in range(batch_size):
ret_dict[k]['pred_labels'] = ret_dict[k]['pred_labels'].int() + 1
return ret_dict