Add File
pcdet/models/dense_heads/point_head_template.py (new file, 210 lines)
@@ -0,0 +1,210 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import common_utils, loss_utils


class PointHeadTemplate(nn.Module):
    def __init__(self, model_cfg, num_class):
        super().__init__()
        self.model_cfg = model_cfg
        self.num_class = num_class

        self.build_losses(self.model_cfg.LOSS_CONFIG)
        self.forward_ret_dict = None

    def build_losses(self, losses_cfg):
        self.add_module(
            'cls_loss_func',
            loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
        )
        reg_loss_type = losses_cfg.get('LOSS_REG', None)
        if reg_loss_type == 'smooth-l1':
            self.reg_loss_func = F.smooth_l1_loss
        elif reg_loss_type == 'l1':
            self.reg_loss_func = F.l1_loss
        elif reg_loss_type == 'WeightedSmoothL1Loss':
            self.reg_loss_func = loss_utils.WeightedSmoothL1Loss(
                code_weights=losses_cfg.LOSS_WEIGHTS.get('code_weights', None)
            )
        else:
            self.reg_loss_func = F.smooth_l1_loss

    @staticmethod
    def make_fc_layers(fc_cfg, input_channels, output_channels):
        fc_layers = []
        c_in = input_channels
        for k in range(len(fc_cfg)):
            fc_layers.extend([
                nn.Linear(c_in, fc_cfg[k], bias=False),
                nn.BatchNorm1d(fc_cfg[k]),
                nn.ReLU(),
            ])
            c_in = fc_cfg[k]
        fc_layers.append(nn.Linear(c_in, output_channels, bias=True))
        return nn.Sequential(*fc_layers)

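    # make_fc_layers builds a shared per-point MLP head. For example (channel sizes
    # illustrative only), fc_cfg=[128, 128], input_channels=256, output_channels=3 gives
    #   Linear(256, 128, bias=False) -> BatchNorm1d(128) -> ReLU()
    #   -> Linear(128, 128, bias=False) -> BatchNorm1d(128) -> ReLU()
    #   -> Linear(128, 3, bias=True)
    # i.e. every hidden layer is Linear + BatchNorm1d + ReLU, and only the final
    # prediction layer carries a bias.
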
    def assign_stack_targets(self, points, gt_boxes, extend_gt_boxes=None,
                             ret_box_labels=False, ret_part_labels=False,
                             set_ignore_flag=True, use_ball_constraint=False, central_radius=2.0):
        """
        Args:
            points: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
            gt_boxes: (B, M, 8)
            extend_gt_boxes: (B, M, 8)
            ret_box_labels:
            ret_part_labels:
            set_ignore_flag:
            use_ball_constraint:
            central_radius:

        Returns:
            point_cls_labels: (N1 + N2 + N3 + ...), long type, 0: background, -1: ignored
            point_box_labels: (N1 + N2 + N3 + ..., code_size)

        """
        assert len(points.shape) == 2 and points.shape[1] == 4, 'points.shape=%s' % str(points.shape)
        assert len(gt_boxes.shape) == 3 and gt_boxes.shape[2] == 8, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
        assert extend_gt_boxes is None or len(extend_gt_boxes.shape) == 3 and extend_gt_boxes.shape[2] == 8, \
            'extend_gt_boxes.shape=%s' % str(extend_gt_boxes.shape)
        assert set_ignore_flag != use_ball_constraint, 'Choose one only!'
        batch_size = gt_boxes.shape[0]
        bs_idx = points[:, 0]
        point_cls_labels = points.new_zeros(points.shape[0]).long()
        point_box_labels = gt_boxes.new_zeros((points.shape[0], 8)) if ret_box_labels else None
        point_part_labels = gt_boxes.new_zeros((points.shape[0], 3)) if ret_part_labels else None
        for k in range(batch_size):
            bs_mask = (bs_idx == k)
            points_single = points[bs_mask][:, 1:4]
            point_cls_labels_single = point_cls_labels.new_zeros(bs_mask.sum())
            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                points_single.unsqueeze(dim=0), gt_boxes[k:k + 1, :, 0:7].contiguous()
            ).long().squeeze(dim=0)
            box_fg_flag = (box_idxs_of_pts >= 0)
            if set_ignore_flag:
                extend_box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                    points_single.unsqueeze(dim=0), extend_gt_boxes[k:k + 1, :, 0:7].contiguous()
                ).long().squeeze(dim=0)
                fg_flag = box_fg_flag
                ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0)
                point_cls_labels_single[ignore_flag] = -1
            elif use_ball_constraint:
                box_centers = gt_boxes[k][box_idxs_of_pts][:, 0:3].clone()
                box_centers[:, 2] += gt_boxes[k][box_idxs_of_pts][:, 5] / 2
                ball_flag = ((box_centers - points_single).norm(dim=1) < central_radius)
                fg_flag = box_fg_flag & ball_flag
            else:
                raise NotImplementedError

            gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]]
            point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, -1].long()
            point_cls_labels[bs_mask] = point_cls_labels_single

            if ret_box_labels and gt_box_of_fg_points.shape[0] > 0:
                point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8))
                fg_point_box_labels = self.box_coder.encode_torch(
                    gt_boxes=gt_box_of_fg_points[:, :-1], points=points_single[fg_flag],
                    gt_classes=gt_box_of_fg_points[:, -1].long()
                )
                point_box_labels_single[fg_flag] = fg_point_box_labels
                point_box_labels[bs_mask] = point_box_labels_single

            if ret_part_labels:
                point_part_labels_single = point_part_labels.new_zeros((bs_mask.sum(), 3))
                transformed_points = points_single[fg_flag] - gt_box_of_fg_points[:, 0:3]
                transformed_points = common_utils.rotate_points_along_z(
                    transformed_points.view(-1, 1, 3), -gt_box_of_fg_points[:, 6]
                ).view(-1, 3)
                offset = torch.tensor([0.5, 0.5, 0.5]).view(1, 3).type_as(transformed_points)
                point_part_labels_single[fg_flag] = (transformed_points / gt_box_of_fg_points[:, 3:6]) + offset
                point_part_labels[bs_mask] = point_part_labels_single

        targets_dict = {
            'point_cls_labels': point_cls_labels,
            'point_box_labels': point_box_labels,
            'point_part_labels': point_part_labels
        }
        return targets_dict

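    # Typical use of assign_stack_targets (illustrative): extend_gt_boxes is a slightly
    # enlarged copy of gt_boxes (e.g. built with box_utils.enlarge_box3d), so that with
    # set_ignore_flag=True the points lying between a GT box and its enlarged version
    # receive the -1 (ignored) label, points inside the original box become foreground,
    # and everything else stays background (0).
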
    def get_cls_layer_loss(self, tb_dict=None):
        point_cls_labels = self.forward_ret_dict['point_cls_labels'].view(-1)
        point_cls_preds = self.forward_ret_dict['point_cls_preds'].view(-1, self.num_class)

        positives = (point_cls_labels > 0)
        negative_cls_weights = (point_cls_labels == 0) * 1.0
        cls_weights = (negative_cls_weights + 1.0 * positives).float()
        pos_normalizer = positives.sum(dim=0).float()
        cls_weights /= torch.clamp(pos_normalizer, min=1.0)

        one_hot_targets = point_cls_preds.new_zeros(*list(point_cls_labels.shape), self.num_class + 1)
        one_hot_targets.scatter_(-1, (point_cls_labels * (point_cls_labels >= 0).long()).unsqueeze(dim=-1).long(), 1.0)
        one_hot_targets = one_hot_targets[..., 1:]
        cls_loss_src = self.cls_loss_func(point_cls_preds, one_hot_targets, weights=cls_weights)
        point_loss_cls = cls_loss_src.sum()

        loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
        point_loss_cls = point_loss_cls * loss_weights_dict['point_cls_weight']
        if tb_dict is None:
            tb_dict = {}
        tb_dict.update({
            'point_loss_cls': point_loss_cls.item(),
            'point_pos_num': pos_normalizer.item()
        })
        return point_loss_cls, tb_dict

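    # Note on get_cls_layer_loss: the one-hot targets are built over num_class + 1 slots
    # and the background slot (index 0) is dropped afterwards, so background points and
    # ignored points (whose -1 label is clamped to 0 before scatter_) get an all-zero
    # target vector; cls_weights additionally zeroes out the ignored points and
    # normalizes by the number of positives.
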
    def get_part_layer_loss(self, tb_dict=None):
        pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
        pos_normalizer = max(1, (pos_mask > 0).sum().item())
        point_part_labels = self.forward_ret_dict['point_part_labels']
        point_part_preds = self.forward_ret_dict['point_part_preds']
        point_loss_part = F.binary_cross_entropy(torch.sigmoid(point_part_preds), point_part_labels, reduction='none')
        point_loss_part = (point_loss_part.sum(dim=-1) * pos_mask.float()).sum() / (3 * pos_normalizer)

        loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
        point_loss_part = point_loss_part * loss_weights_dict['point_part_weight']
        if tb_dict is None:
            tb_dict = {}
        tb_dict.update({'point_loss_part': point_loss_part.item()})
        return point_loss_part, tb_dict

    def get_box_layer_loss(self, tb_dict=None):
        pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
        point_box_labels = self.forward_ret_dict['point_box_labels']
        point_box_preds = self.forward_ret_dict['point_box_preds']

        reg_weights = pos_mask.float()
        pos_normalizer = pos_mask.sum().float()
        reg_weights /= torch.clamp(pos_normalizer, min=1.0)

        point_loss_box_src = self.reg_loss_func(
            point_box_preds[None, ...], point_box_labels[None, ...], weights=reg_weights[None, ...]
        )
        point_loss_box = point_loss_box_src.sum()

        loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
        point_loss_box = point_loss_box * loss_weights_dict['point_box_weight']
        if tb_dict is None:
            tb_dict = {}
        tb_dict.update({'point_loss_box': point_loss_box.item()})
        return point_loss_box, tb_dict

    def generate_predicted_boxes(self, points, point_cls_preds, point_box_preds):
        """
        Args:
            points: (N, 3)
            point_cls_preds: (N, num_class)
            point_box_preds: (N, box_code_size)
        Returns:
            point_cls_preds: (N, num_class)
            point_box_preds: (N, box_code_size)

        """
        _, pred_classes = point_cls_preds.max(dim=-1)
        point_box_preds = self.box_coder.decode_torch(point_box_preds, points, pred_classes + 1)

        return point_cls_preds, point_box_preds

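    # Note on generate_predicted_boxes: pred_classes from the argmax is 0-indexed, while
    # the box coder and the training targets above use 1-indexed class labels (0 is
    # background), hence the pred_classes + 1 passed to decode_torch.
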
    def forward(self, **kwargs):
        raise NotImplementedError
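

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file): how a
# concrete head is typically derived from PointHeadTemplate. The class name
# PointHeadExample and the config key CLS_FC are placeholders; a real head that
# regresses boxes would also attach a box coder (self.box_coder) before using the
# box-related utilities above.
# ----------------------------------------------------------------------------
class PointHeadExample(PointHeadTemplate):
    def __init__(self, model_cfg, input_channels, num_class, **kwargs):
        super().__init__(model_cfg=model_cfg, num_class=num_class)
        # Per-point classification branch built with the shared MLP helper.
        self.cls_layers = self.make_fc_layers(
            fc_cfg=self.model_cfg.CLS_FC,
            input_channels=input_channels,
            output_channels=num_class
        )

    def assign_targets(self, input_dict):
        # Mark points inside a GT box as foreground; a slightly enlarged copy of the
        # boxes would normally be passed as extend_gt_boxes to create an ignore band.
        return self.assign_stack_targets(
            points=input_dict['point_coords'], gt_boxes=input_dict['gt_boxes'],
            extend_gt_boxes=input_dict['gt_boxes'], set_ignore_flag=True
        )

    def forward(self, batch_dict):
        point_features = batch_dict['point_features']        # (N1 + N2 + ..., C)
        point_cls_preds = self.cls_layers(point_features)    # (N1 + N2 + ..., num_class)
        batch_dict['point_cls_scores'] = torch.sigmoid(point_cls_preds).max(dim=-1)[0]

        self.forward_ret_dict = {'point_cls_preds': point_cls_preds}
        if self.training:
            targets_dict = self.assign_targets(batch_dict)
            self.forward_ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
        return batch_dict

    def get_loss(self, tb_dict=None):
        # Only the classification term is used in this sketch.
        point_loss_cls, tb_dict = self.get_cls_layer_loss(tb_dict)
        return point_loss_cls, tb_dict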