From 128902bbda6a4e641d6f96d6337a5853720bc912 Mon Sep 17 00:00:00 2001 From: inter Date: Sun, 21 Sep 2025 20:19:05 +0800 Subject: [PATCH] Add File --- .../models/dense_heads/point_head_template.py | 210 ++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 pcdet/models/dense_heads/point_head_template.py diff --git a/pcdet/models/dense_heads/point_head_template.py b/pcdet/models/dense_heads/point_head_template.py new file mode 100644 index 0000000..9ea0af0 --- /dev/null +++ b/pcdet/models/dense_heads/point_head_template.py @@ -0,0 +1,210 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...ops.roiaware_pool3d import roiaware_pool3d_utils +from ...utils import common_utils, loss_utils + + +class PointHeadTemplate(nn.Module): + def __init__(self, model_cfg, num_class): + super().__init__() + self.model_cfg = model_cfg + self.num_class = num_class + + self.build_losses(self.model_cfg.LOSS_CONFIG) + self.forward_ret_dict = None + + def build_losses(self, losses_cfg): + self.add_module( + 'cls_loss_func', + loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0) + ) + reg_loss_type = losses_cfg.get('LOSS_REG', None) + if reg_loss_type == 'smooth-l1': + self.reg_loss_func = F.smooth_l1_loss + elif reg_loss_type == 'l1': + self.reg_loss_func = F.l1_loss + elif reg_loss_type == 'WeightedSmoothL1Loss': + self.reg_loss_func = loss_utils.WeightedSmoothL1Loss( + code_weights=losses_cfg.LOSS_WEIGHTS.get('code_weights', None) + ) + else: + self.reg_loss_func = F.smooth_l1_loss + + @staticmethod + def make_fc_layers(fc_cfg, input_channels, output_channels): + fc_layers = [] + c_in = input_channels + for k in range(0, fc_cfg.__len__()): + fc_layers.extend([ + nn.Linear(c_in, fc_cfg[k], bias=False), + nn.BatchNorm1d(fc_cfg[k]), + nn.ReLU(), + ]) + c_in = fc_cfg[k] + fc_layers.append(nn.Linear(c_in, output_channels, bias=True)) + return nn.Sequential(*fc_layers) + + def assign_stack_targets(self, points, gt_boxes, extend_gt_boxes=None, + ret_box_labels=False, ret_part_labels=False, + set_ignore_flag=True, use_ball_constraint=False, central_radius=2.0): + """ + Args: + points: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z] + gt_boxes: (B, M, 8) + extend_gt_boxes: [B, M, 8] + ret_box_labels: + ret_part_labels: + set_ignore_flag: + use_ball_constraint: + central_radius: + + Returns: + point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored + point_box_labels: (N1 + N2 + N3 + ..., code_size) + + """ + assert len(points.shape) == 2 and points.shape[1] == 4, 'points.shape=%s' % str(points.shape) + assert len(gt_boxes.shape) == 3 and gt_boxes.shape[2] == 8, 'gt_boxes.shape=%s' % str(gt_boxes.shape) + assert extend_gt_boxes is None or len(extend_gt_boxes.shape) == 3 and extend_gt_boxes.shape[2] == 8, \ + 'extend_gt_boxes.shape=%s' % str(extend_gt_boxes.shape) + assert set_ignore_flag != use_ball_constraint, 'Choose one only!' + batch_size = gt_boxes.shape[0] + bs_idx = points[:, 0] + point_cls_labels = points.new_zeros(points.shape[0]).long() + point_box_labels = gt_boxes.new_zeros((points.shape[0], 8)) if ret_box_labels else None + point_part_labels = gt_boxes.new_zeros((points.shape[0], 3)) if ret_part_labels else None + for k in range(batch_size): + bs_mask = (bs_idx == k) + points_single = points[bs_mask][:, 1:4] + point_cls_labels_single = point_cls_labels.new_zeros(bs_mask.sum()) + box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu( + points_single.unsqueeze(dim=0), gt_boxes[k:k + 1, :, 0:7].contiguous() + ).long().squeeze(dim=0) + box_fg_flag = (box_idxs_of_pts >= 0) + if set_ignore_flag: + extend_box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu( + points_single.unsqueeze(dim=0), extend_gt_boxes[k:k+1, :, 0:7].contiguous() + ).long().squeeze(dim=0) + fg_flag = box_fg_flag + ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0) + point_cls_labels_single[ignore_flag] = -1 + elif use_ball_constraint: + box_centers = gt_boxes[k][box_idxs_of_pts][:, 0:3].clone() + box_centers[:, 2] += gt_boxes[k][box_idxs_of_pts][:, 5] / 2 + ball_flag = ((box_centers - points_single).norm(dim=1) < central_radius) + fg_flag = box_fg_flag & ball_flag + else: + raise NotImplementedError + + gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]] + point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, -1].long() + point_cls_labels[bs_mask] = point_cls_labels_single + + if ret_box_labels and gt_box_of_fg_points.shape[0] > 0: + point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8)) + fg_point_box_labels = self.box_coder.encode_torch( + gt_boxes=gt_box_of_fg_points[:, :-1], points=points_single[fg_flag], + gt_classes=gt_box_of_fg_points[:, -1].long() + ) + point_box_labels_single[fg_flag] = fg_point_box_labels + point_box_labels[bs_mask] = point_box_labels_single + + if ret_part_labels: + point_part_labels_single = point_part_labels.new_zeros((bs_mask.sum(), 3)) + transformed_points = points_single[fg_flag] - gt_box_of_fg_points[:, 0:3] + transformed_points = common_utils.rotate_points_along_z( + transformed_points.view(-1, 1, 3), -gt_box_of_fg_points[:, 6] + ).view(-1, 3) + offset = torch.tensor([0.5, 0.5, 0.5]).view(1, 3).type_as(transformed_points) + point_part_labels_single[fg_flag] = (transformed_points / gt_box_of_fg_points[:, 3:6]) + offset + point_part_labels[bs_mask] = point_part_labels_single + + targets_dict = { + 'point_cls_labels': point_cls_labels, + 'point_box_labels': point_box_labels, + 'point_part_labels': point_part_labels + } + return targets_dict + + def get_cls_layer_loss(self, tb_dict=None): + point_cls_labels = self.forward_ret_dict['point_cls_labels'].view(-1) + point_cls_preds = self.forward_ret_dict['point_cls_preds'].view(-1, self.num_class) + + positives = (point_cls_labels > 0) + negative_cls_weights = (point_cls_labels == 0) * 1.0 + cls_weights = (negative_cls_weights + 1.0 * positives).float() + pos_normalizer = positives.sum(dim=0).float() + cls_weights /= torch.clamp(pos_normalizer, min=1.0) + + one_hot_targets = point_cls_preds.new_zeros(*list(point_cls_labels.shape), self.num_class + 1) + one_hot_targets.scatter_(-1, (point_cls_labels * (point_cls_labels >= 0).long()).unsqueeze(dim=-1).long(), 1.0) + one_hot_targets = one_hot_targets[..., 1:] + cls_loss_src = self.cls_loss_func(point_cls_preds, one_hot_targets, weights=cls_weights) + point_loss_cls = cls_loss_src.sum() + + loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS + point_loss_cls = point_loss_cls * loss_weights_dict['point_cls_weight'] + if tb_dict is None: + tb_dict = {} + tb_dict.update({ + 'point_loss_cls': point_loss_cls.item(), + 'point_pos_num': pos_normalizer.item() + }) + return point_loss_cls, tb_dict + + def get_part_layer_loss(self, tb_dict=None): + pos_mask = self.forward_ret_dict['point_cls_labels'] > 0 + pos_normalizer = max(1, (pos_mask > 0).sum().item()) + point_part_labels = self.forward_ret_dict['point_part_labels'] + point_part_preds = self.forward_ret_dict['point_part_preds'] + point_loss_part = F.binary_cross_entropy(torch.sigmoid(point_part_preds), point_part_labels, reduction='none') + point_loss_part = (point_loss_part.sum(dim=-1) * pos_mask.float()).sum() / (3 * pos_normalizer) + + loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS + point_loss_part = point_loss_part * loss_weights_dict['point_part_weight'] + if tb_dict is None: + tb_dict = {} + tb_dict.update({'point_loss_part': point_loss_part.item()}) + return point_loss_part, tb_dict + + def get_box_layer_loss(self, tb_dict=None): + pos_mask = self.forward_ret_dict['point_cls_labels'] > 0 + point_box_labels = self.forward_ret_dict['point_box_labels'] + point_box_preds = self.forward_ret_dict['point_box_preds'] + + reg_weights = pos_mask.float() + pos_normalizer = pos_mask.sum().float() + reg_weights /= torch.clamp(pos_normalizer, min=1.0) + + point_loss_box_src = self.reg_loss_func( + point_box_preds[None, ...], point_box_labels[None, ...], weights=reg_weights[None, ...] + ) + point_loss_box = point_loss_box_src.sum() + + loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS + point_loss_box = point_loss_box * loss_weights_dict['point_box_weight'] + if tb_dict is None: + tb_dict = {} + tb_dict.update({'point_loss_box': point_loss_box.item()}) + return point_loss_box, tb_dict + + def generate_predicted_boxes(self, points, point_cls_preds, point_box_preds): + """ + Args: + points: (N, 3) + point_cls_preds: (N, num_class) + point_box_preds: (N, box_code_size) + Returns: + point_cls_preds: (N, num_class) + point_box_preds: (N, box_code_size) + + """ + _, pred_classes = point_cls_preds.max(dim=-1) + point_box_preds = self.box_coder.decode_torch(point_box_preds, points, pred_classes + 1) + + return point_cls_preds, point_box_preds + + def forward(self, **kwargs): + raise NotImplementedError