diff --git a/pcdet/models/backbones_3d/focal_sparse_conv/focal_sparse_utils.py b/pcdet/models/backbones_3d/focal_sparse_conv/focal_sparse_utils.py new file mode 100644 index 0000000..d2c2344 --- /dev/null +++ b/pcdet/models/backbones_3d/focal_sparse_conv/focal_sparse_utils.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + + +class FocalLoss(nn.Module): + + def __init__(self, gamma=2.0, eps=1e-7): + super(FocalLoss, self).__init__() + self.gamma = gamma + self.eps = eps + + def one_hot(self, index, classes): + size = index.size() + (classes,) + view = index.size() + (1,) + + mask = torch.Tensor(*size).fill_(0).to(index.device) + + index = index.view(*view) + ones = 1. + + if isinstance(index, Variable): + ones = Variable(torch.Tensor(index.size()).fill_(1).to(index.device)) + mask = Variable(mask, volatile=index.volatile) + + return mask.scatter_(1, index, ones) + + def forward(self, input, target): + y = self.one_hot(target, input.size(-1)) + logit = F.softmax(input, dim=-1) + logit = logit.clamp(self.eps, 1. - self.eps) + + loss = -1 * y * torch.log(logit) # cross entropy + loss = loss * (1 - logit) ** self.gamma # focal loss + + return loss.mean() + +def sort_by_indices(features, indices, features_add=None): + """ + To sort the sparse features with its indices in a convenient manner. + Args: + features: [N, C], sparse features + indices: [N, 4], indices of sparse features + features_add: [N, C], additional features to sort + """ + idx = indices[:, 1:] + idx_sum = idx.select(1, 0) * idx[:, 1].max() * idx[:, 2].max() + idx.select(1, 1) * idx[:, 2].max() + idx.select(1, 2) + _, ind = idx_sum.sort() + features = features[ind] + indices = indices[ind] + if not features_add is None: + features_add = features_add[ind] + return features, indices, features_add + +def check_repeat(features, indices, features_add=None, sort_first=True, flip_first=True): + """ + Check that whether there are replicate indices in the sparse features, + remove the replicate features if any. + """ + if sort_first: + features, indices, features_add = sort_by_indices(features, indices, features_add) + + if flip_first: + features, indices = features.flip([0]), indices.flip([0]) + + if not features_add is None: + features_add=features_add.flip([0]) + + idx = indices[:, 1:].int() + idx_sum = torch.add(torch.add(idx.select(1, 0) * idx[:, 1].max() * idx[:, 2].max(), idx.select(1, 1) * idx[:, 2].max()), idx.select(1, 2)) + _unique, inverse, counts = torch.unique_consecutive(idx_sum, return_inverse=True, return_counts=True, dim=0) + + if _unique.shape[0] < indices.shape[0]: + perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device) + features_new = torch.zeros((_unique.shape[0], features.shape[-1]), device=features.device) + features_new.index_add_(0, inverse.long(), features) + features = features_new + perm_ = inverse.new_empty(_unique.size(0)).scatter_(0, inverse, perm) + indices = indices[perm_].int() + + if not features_add is None: + features_add_new = torch.zeros((_unique.shape[0],), device=features_add.device) + features_add_new.index_add_(0, inverse.long(), features_add) + features_add = features_add_new / counts + return features, indices, features_add + + +def split_voxels(x, b, imps_3d, voxels_3d, kernel_offsets, mask_multi=True, topk=True, threshold=0.5): + """ + Generate and split the voxels into foreground and background sparse features, based on the predicted importance values. + Args: + x: [N, C], input sparse features + b: int, batch size id + imps_3d: [N, kernelsize**3], the prediced importance values + voxels_3d: [N, 3], the 3d positions of voxel centers + kernel_offsets: [kernelsize**3, 3], the offset coords in an kernel + mask_multi: bool, whether to multiply the predicted mask to features + topk: bool, whether to use topk or threshold for selection + threshold: float, threshold value + """ + index = x.indices[:, 0] + batch_index = index==b + indices_ori = x.indices[batch_index] + features_ori = x.features[batch_index] + mask_voxel = imps_3d[batch_index, -1].sigmoid() + mask_kernel = imps_3d[batch_index, :-1].sigmoid() + + if mask_multi: + features_ori *= mask_voxel.unsqueeze(-1) + + if topk: + _, indices = mask_voxel.sort(descending=True) + indices_fore = indices[:int(mask_voxel.shape[0]*threshold)] + indices_back = indices[int(mask_voxel.shape[0]*threshold):] + else: + indices_fore = mask_voxel > threshold + indices_back = mask_voxel <= threshold + + features_fore = features_ori[indices_fore] + coords_fore = indices_ori[indices_fore] + + mask_kernel_fore = mask_kernel[indices_fore] + mask_kernel_bool = mask_kernel_fore>=threshold + voxel_kerels_imp = kernel_offsets.unsqueeze(0).repeat(mask_kernel_bool.shape[0],1, 1) + mask_kernel_fore = mask_kernel[indices_fore][mask_kernel_bool] + indices_fore_kernels = coords_fore[:, 1:].unsqueeze(1).repeat(1, kernel_offsets.shape[0], 1) + indices_with_imp = indices_fore_kernels + voxel_kerels_imp + selected_indices = indices_with_imp[mask_kernel_bool] + spatial_indices = (selected_indices[:, 0] >0) * (selected_indices[:, 1] >0) * (selected_indices[:, 2] >0) * \ + (selected_indices[:, 0] < x.spatial_shape[0]) * (selected_indices[:, 1] < x.spatial_shape[1]) * (selected_indices[:, 2] < x.spatial_shape[2]) + selected_indices = selected_indices[spatial_indices] + mask_kernel_fore = mask_kernel_fore[spatial_indices] + selected_indices = torch.cat([torch.ones((selected_indices.shape[0], 1), device=features_fore.device)*b, selected_indices], dim=1) + + selected_features = torch.zeros((selected_indices.shape[0], features_ori.shape[1]), device=features_fore.device) + + features_fore_cat = torch.cat([features_fore, selected_features], dim=0) + coords_fore = torch.cat([coords_fore, selected_indices], dim=0) + mask_kernel_fore = torch.cat([torch.ones(features_fore.shape[0], device=features_fore.device), mask_kernel_fore], dim=0) + + features_fore, coords_fore, mask_kernel_fore = check_repeat(features_fore_cat, coords_fore, features_add=mask_kernel_fore) + + features_back = features_ori[indices_back] + coords_back = indices_ori[indices_back] + + return features_fore, coords_fore, features_back, coords_back, mask_kernel_fore