Files
OpenPCDet/pcdet/models/dense_heads/point_head_template.py

211 lines
9.5 KiB
Python
Raw Normal View History

2025-09-21 20:19:05 +08:00
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import common_utils, loss_utils
class PointHeadTemplate(nn.Module):
    """Base class for point-wise dense heads.

    Provides loss construction, an MLP builder, per-point target assignment and
    the classification / part / box loss terms. ``forward`` is abstract here.
    The ``get_*_layer_loss`` methods read prediction and target tensors from
    ``self.forward_ret_dict``, which is ``None`` after ``__init__`` — presumably
    a subclass ``forward`` fills it (verify in subclasses). ``self.box_coder``
    (used by ``assign_stack_targets`` and ``generate_predicted_boxes``) is never
    set in this class and must be provided by the subclass.
    """

    def __init__(self, model_cfg, num_class):
        """
        Args:
            model_cfg: head config; must provide ``LOSS_CONFIG`` (with ``LOSS_WEIGHTS``).
            num_class: number of foreground classes predicted per point.
        """
        super().__init__()
        self.model_cfg = model_cfg
        self.num_class = num_class

        self.build_losses(self.model_cfg.LOSS_CONFIG)
        # NOTE(review): None until something (presumably a subclass forward) assigns it.
        self.forward_ret_dict = None

    def build_losses(self, losses_cfg):
        """Create ``cls_loss_func`` (sigmoid focal loss) and ``reg_loss_func``.

        ``losses_cfg.LOSS_REG`` selects the regression loss:
            'WeightedSmoothL1Loss' -> loss_utils.WeightedSmoothL1Loss module, using
                                      ``LOSS_WEIGHTS.code_weights`` if present;
            'l1'                   -> F.l1_loss;
            'smooth-l1', missing or any other value -> F.smooth_l1_loss.
        The functional choices take no ``weights`` argument; ``get_box_layer_loss``
        applies the per-point weights for them itself.
        """
        self.add_module(
            'cls_loss_func',
            loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
        )

        reg_loss_type = losses_cfg.get('LOSS_REG', None)
        if reg_loss_type == 'WeightedSmoothL1Loss':
            self.reg_loss_func = loss_utils.WeightedSmoothL1Loss(
                code_weights=losses_cfg.LOSS_WEIGHTS.get('code_weights', None)
            )
        elif reg_loss_type == 'l1':
            self.reg_loss_func = F.l1_loss
        else:
            # 'smooth-l1' as well as a missing / unrecognised value.
            self.reg_loss_func = F.smooth_l1_loss

    @staticmethod
    def make_fc_layers(fc_cfg, input_channels, output_channels):
        """Build an MLP: one (Linear -> BatchNorm1d -> ReLU) stage per hidden width,
        followed by a final biased Linear.

        Args:
            fc_cfg: sequence of hidden layer widths (may be empty).
            input_channels: size of the input features.
            output_channels: size of the final layer's output.
        Returns:
            nn.Sequential
        """
        fc_layers = []
        c_in = input_channels
        for c_out in fc_cfg:
            # bias=False: the following BatchNorm1d has its own learnable shift.
            fc_layers.extend([
                nn.Linear(c_in, c_out, bias=False),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ])
            c_in = c_out
        fc_layers.append(nn.Linear(c_in, output_channels, bias=True))
        return nn.Sequential(*fc_layers)

    def assign_stack_targets(self, points, gt_boxes, extend_gt_boxes=None,
                             ret_box_labels=False, ret_part_labels=False,
                             set_ignore_flag=True, use_ball_constraint=False, central_radius=2.0):
        """Assign per-point classification (and optionally box / part) targets.

        Args:
            points: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]; points of all batch
                samples stacked, column 0 holding the sample index.
            gt_boxes: (B, M, 8); columns 0:7 are passed to ``points_in_boxes_gpu``,
                the last column is the class label.
            extend_gt_boxes: (B, M, 8) enlarged boxes; required when
                ``set_ignore_flag`` is True.
            ret_box_labels: also encode box regression targets with ``self.box_coder``.
            ret_part_labels: also compute intra-box part locations.
            set_ignore_flag: points inside exactly one of {gt box, enlarged box}
                are labelled -1 (ignored), unless they are foreground.
            use_ball_constraint: a point inside a gt box stays foreground only if it
                is within ``central_radius`` of box columns 0:3 with z shifted by
                half of column 5.
            central_radius: radius used by ``use_ball_constraint``.
            Exactly one of ``set_ignore_flag`` / ``use_ball_constraint`` must be True.
        Returns:
            dict with
                point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
                point_box_labels: (N1 + N2 + N3 + ..., 8) or None
                    (NOTE(review): width is hard-coded to 8; box_coder must emit 8-dim codes)
                point_part_labels: (N1 + N2 + N3 + ..., 3) or None
        """
        assert len(points.shape) == 2 and points.shape[1] == 4, 'points.shape=%s' % str(points.shape)
        assert len(gt_boxes.shape) == 3 and gt_boxes.shape[2] == 8, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
        assert extend_gt_boxes is None or len(extend_gt_boxes.shape) == 3 and extend_gt_boxes.shape[2] == 8, \
            'extend_gt_boxes.shape=%s' % str(extend_gt_boxes.shape)
        assert set_ignore_flag != use_ball_constraint, 'Choose one only!'
        # Fail early with a clear message rather than a TypeError on None[...] in the loop.
        assert not set_ignore_flag or extend_gt_boxes is not None, \
            'extend_gt_boxes is required when set_ignore_flag=True'

        batch_size = gt_boxes.shape[0]
        bs_idx = points[:, 0]
        point_cls_labels = points.new_zeros(points.shape[0]).long()
        point_box_labels = gt_boxes.new_zeros((points.shape[0], 8)) if ret_box_labels else None
        point_part_labels = gt_boxes.new_zeros((points.shape[0], 3)) if ret_part_labels else None

        for k in range(batch_size):
            bs_mask = (bs_idx == k)
            num_points_single = int(bs_mask.sum())
            points_single = points[bs_mask][:, 1:4]
            point_cls_labels_single = point_cls_labels.new_zeros(num_points_single)

            # Index of the gt box containing each point; negative when in no box.
            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                points_single.unsqueeze(dim=0), gt_boxes[k:k + 1, :, 0:7].contiguous()
            ).long().squeeze(dim=0)
            box_fg_flag = (box_idxs_of_pts >= 0)

            if set_ignore_flag:
                extend_box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                    points_single.unsqueeze(dim=0), extend_gt_boxes[k:k + 1, :, 0:7].contiguous()
                ).long().squeeze(dim=0)
                fg_flag = box_fg_flag
                # In exactly one of (gt box, enlarged box) -> -1; foreground points are
                # relabelled below, so effectively only the enlarged-only margin stays -1.
                ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0)
                point_cls_labels_single[ignore_flag] = -1
            elif use_ball_constraint:
                # Negative (background) indices wrap around to some box here; those
                # points are discarded again by `& box_fg_flag`.
                box_centers = gt_boxes[k][box_idxs_of_pts][:, 0:3].clone()
                box_centers[:, 2] += gt_boxes[k][box_idxs_of_pts][:, 5] / 2
                ball_flag = ((box_centers - points_single).norm(dim=1) < central_radius)
                fg_flag = box_fg_flag & ball_flag
            else:
                raise NotImplementedError

            gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]]
            point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, -1].long()
            point_cls_labels[bs_mask] = point_cls_labels_single

            if ret_box_labels and gt_box_of_fg_points.shape[0] > 0:
                point_box_labels_single = point_box_labels.new_zeros((num_points_single, 8))
                fg_point_box_labels = self.box_coder.encode_torch(
                    gt_boxes=gt_box_of_fg_points[:, :-1], points=points_single[fg_flag],
                    gt_classes=gt_box_of_fg_points[:, -1].long()
                )
                point_box_labels_single[fg_flag] = fg_point_box_labels
                point_box_labels[bs_mask] = point_box_labels_single

            if ret_part_labels:
                point_part_labels_single = point_part_labels.new_zeros((num_points_single, 3))
                # Move foreground points into their box's frame: subtract box columns 0:3,
                # then rotate about z by minus column 6 (presumably center / heading).
                transformed_points = points_single[fg_flag] - gt_box_of_fg_points[:, 0:3]
                transformed_points = common_utils.rotate_points_along_z(
                    transformed_points.view(-1, 1, 3), -gt_box_of_fg_points[:, 6]
                ).view(-1, 3)
                # Normalise by box columns 3:6 and shift by 0.5 (scalar add, same result
                # as the former 0.5-filled tensor without a host->device copy).
                point_part_labels_single[fg_flag] = (transformed_points / gt_box_of_fg_points[:, 3:6]) + 0.5
                point_part_labels[bs_mask] = point_part_labels_single

        targets_dict = {
            'point_cls_labels': point_cls_labels,
            'point_box_labels': point_box_labels,
            'point_part_labels': point_part_labels
        }
        return targets_dict

    def get_cls_layer_loss(self, tb_dict=None):
        """Classification loss from ``cls_loss_func`` over all non-ignored points.

        Reads ``point_cls_labels`` and ``point_cls_preds`` from ``self.forward_ret_dict``.
        Points labelled >= 0 get weight 1 and ignored (-1) points weight 0; all weights
        are divided by the number of foreground (> 0) points, clamped to at least 1.

        Returns:
            (point_loss_cls, tb_dict); tb_dict gains 'point_loss_cls' and 'point_pos_num'.
        """
        point_cls_labels = self.forward_ret_dict['point_cls_labels'].view(-1)
        point_cls_preds = self.forward_ret_dict['point_cls_preds'].view(-1, self.num_class)

        positives = (point_cls_labels > 0)
        cls_weights = (point_cls_labels >= 0).float()
        pos_normalizer = positives.sum(dim=0).float()
        cls_weights /= torch.clamp(pos_normalizer, min=1.0)

        # One-hot over [background, class 1..num_class]: ignored labels are clamped to
        # background, and the background column is dropped, so prediction column i
        # corresponds to label i + 1.
        one_hot_targets = point_cls_preds.new_zeros((point_cls_labels.shape[0], self.num_class + 1))
        one_hot_targets.scatter_(-1, point_cls_labels.clamp(min=0).unsqueeze(dim=-1).long(), 1.0)
        one_hot_targets = one_hot_targets[..., 1:]
        cls_loss_src = self.cls_loss_func(point_cls_preds, one_hot_targets, weights=cls_weights)
        point_loss_cls = cls_loss_src.sum()

        loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
        point_loss_cls = point_loss_cls * loss_weights_dict['point_cls_weight']
        if tb_dict is None:
            tb_dict = {}
        tb_dict.update({
            'point_loss_cls': point_loss_cls.item(),
            'point_pos_num': pos_normalizer.item()
        })
        return point_loss_cls, tb_dict

    def get_part_layer_loss(self, tb_dict=None):
        """Binary cross-entropy on part-location logits of foreground points.

        Reads ``point_cls_labels``, ``point_part_labels`` and ``point_part_preds``
        (logits) from ``self.forward_ret_dict``. The loss is summed over the 3 part
        coordinates and the foreground points, then divided by 3 * max(1, #foreground).

        Returns:
            (point_loss_part, tb_dict); tb_dict gains 'point_loss_part'.
        """
        pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
        pos_normalizer = max(1, pos_mask.sum().item())
        point_part_labels = self.forward_ret_dict['point_part_labels']
        point_part_preds = self.forward_ret_dict['point_part_preds']
        # Same value as BCE(sigmoid(preds), labels), but computed stably from the
        # logits (no saturation / log clamping, and safe under autocast).
        point_loss_part = F.binary_cross_entropy_with_logits(
            point_part_preds, point_part_labels, reduction='none'
        )
        point_loss_part = (point_loss_part.sum(dim=-1) * pos_mask.float()).sum() / (3 * pos_normalizer)

        loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
        point_loss_part = point_loss_part * loss_weights_dict['point_part_weight']
        if tb_dict is None:
            tb_dict = {}
        tb_dict.update({'point_loss_part': point_loss_part.item()})
        return point_loss_part, tb_dict

    def get_box_layer_loss(self, tb_dict=None):
        """Box regression loss on foreground points.

        Reads ``point_cls_labels``, ``point_box_labels`` and ``point_box_preds`` from
        ``self.forward_ret_dict``. Each foreground point has weight 1 / max(1, #foreground);
        other points have weight 0.

        Returns:
            (point_loss_box, tb_dict); tb_dict gains 'point_loss_box'.
        """
        pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
        point_box_labels = self.forward_ret_dict['point_box_labels']
        point_box_preds = self.forward_ret_dict['point_box_preds']

        reg_weights = pos_mask.float()
        pos_normalizer = pos_mask.sum().float()
        reg_weights /= torch.clamp(pos_normalizer, min=1.0)

        if isinstance(self.reg_loss_func, nn.Module):
            # Module losses (e.g. loss_utils.WeightedSmoothL1Loss) are called as before:
            # with a leading batch dim and a `weights` keyword.
            point_loss_box_src = self.reg_loss_func(
                point_box_preds[None, ...], point_box_labels[None, ...], weights=reg_weights[None, ...]
            )
        else:
            # F.smooth_l1_loss / F.l1_loss have no `weights` parameter (passing one raised
            # TypeError): take the element-wise loss and apply the per-point weights here.
            point_loss_box_src = self.reg_loss_func(
                point_box_preds, point_box_labels, reduction='none'
            ) * reg_weights.unsqueeze(dim=-1)
        point_loss_box = point_loss_box_src.sum()

        loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
        point_loss_box = point_loss_box * loss_weights_dict['point_box_weight']
        if tb_dict is None:
            tb_dict = {}
        tb_dict.update({'point_loss_box': point_loss_box.item()})
        return point_loss_box, tb_dict

    def generate_predicted_boxes(self, points, point_cls_preds, point_box_preds):
        """Decode per-point box predictions using each point's argmax class.

        Args:
            points: (N, 3)
            point_cls_preds: (N, num_class)
            point_box_preds: (N, box_code_size)
        Returns:
            point_cls_preds: (N, num_class), returned unchanged
            point_box_preds: (N, box_code_size), output of ``self.box_coder.decode_torch``
        """
        _, pred_classes = point_cls_preds.max(dim=-1)
        # +1: prediction column i corresponds to class label i + 1 (0 is background).
        point_box_preds = self.box_coder.decode_torch(point_box_preds, points, pred_classes + 1)

        return point_cls_preds, point_box_preds

    def forward(self, **kwargs):
        """Abstract; must be implemented by subclasses."""
        raise NotImplementedError