Add File
This commit is contained in:
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
|
||||
def __init__(self, gamma=2.0, eps=1e-7):
|
||||
super(FocalLoss, self).__init__()
|
||||
self.gamma = gamma
|
||||
self.eps = eps
|
||||
|
||||
def one_hot(self, index, classes):
|
||||
size = index.size() + (classes,)
|
||||
view = index.size() + (1,)
|
||||
|
||||
mask = torch.Tensor(*size).fill_(0).to(index.device)
|
||||
|
||||
index = index.view(*view)
|
||||
ones = 1.
|
||||
|
||||
if isinstance(index, Variable):
|
||||
ones = Variable(torch.Tensor(index.size()).fill_(1).to(index.device))
|
||||
mask = Variable(mask, volatile=index.volatile)
|
||||
|
||||
return mask.scatter_(1, index, ones)
|
||||
|
||||
def forward(self, input, target):
|
||||
y = self.one_hot(target, input.size(-1))
|
||||
logit = F.softmax(input, dim=-1)
|
||||
logit = logit.clamp(self.eps, 1. - self.eps)
|
||||
|
||||
loss = -1 * y * torch.log(logit) # cross entropy
|
||||
loss = loss * (1 - logit) ** self.gamma # focal loss
|
||||
|
||||
return loss.mean()
|
||||
|
||||
def sort_by_indices(features, indices, features_add=None):
|
||||
"""
|
||||
To sort the sparse features with its indices in a convenient manner.
|
||||
Args:
|
||||
features: [N, C], sparse features
|
||||
indices: [N, 4], indices of sparse features
|
||||
features_add: [N, C], additional features to sort
|
||||
"""
|
||||
idx = indices[:, 1:]
|
||||
idx_sum = idx.select(1, 0) * idx[:, 1].max() * idx[:, 2].max() + idx.select(1, 1) * idx[:, 2].max() + idx.select(1, 2)
|
||||
_, ind = idx_sum.sort()
|
||||
features = features[ind]
|
||||
indices = indices[ind]
|
||||
if not features_add is None:
|
||||
features_add = features_add[ind]
|
||||
return features, indices, features_add
|
||||
|
||||
def check_repeat(features, indices, features_add=None, sort_first=True, flip_first=True):
|
||||
"""
|
||||
Check that whether there are replicate indices in the sparse features,
|
||||
remove the replicate features if any.
|
||||
"""
|
||||
if sort_first:
|
||||
features, indices, features_add = sort_by_indices(features, indices, features_add)
|
||||
|
||||
if flip_first:
|
||||
features, indices = features.flip([0]), indices.flip([0])
|
||||
|
||||
if not features_add is None:
|
||||
features_add=features_add.flip([0])
|
||||
|
||||
idx = indices[:, 1:].int()
|
||||
idx_sum = torch.add(torch.add(idx.select(1, 0) * idx[:, 1].max() * idx[:, 2].max(), idx.select(1, 1) * idx[:, 2].max()), idx.select(1, 2))
|
||||
_unique, inverse, counts = torch.unique_consecutive(idx_sum, return_inverse=True, return_counts=True, dim=0)
|
||||
|
||||
if _unique.shape[0] < indices.shape[0]:
|
||||
perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
|
||||
features_new = torch.zeros((_unique.shape[0], features.shape[-1]), device=features.device)
|
||||
features_new.index_add_(0, inverse.long(), features)
|
||||
features = features_new
|
||||
perm_ = inverse.new_empty(_unique.size(0)).scatter_(0, inverse, perm)
|
||||
indices = indices[perm_].int()
|
||||
|
||||
if not features_add is None:
|
||||
features_add_new = torch.zeros((_unique.shape[0],), device=features_add.device)
|
||||
features_add_new.index_add_(0, inverse.long(), features_add)
|
||||
features_add = features_add_new / counts
|
||||
return features, indices, features_add
|
||||
|
||||
|
||||
def split_voxels(x, b, imps_3d, voxels_3d, kernel_offsets, mask_multi=True, topk=True, threshold=0.5):
|
||||
"""
|
||||
Generate and split the voxels into foreground and background sparse features, based on the predicted importance values.
|
||||
Args:
|
||||
x: [N, C], input sparse features
|
||||
b: int, batch size id
|
||||
imps_3d: [N, kernelsize**3], the prediced importance values
|
||||
voxels_3d: [N, 3], the 3d positions of voxel centers
|
||||
kernel_offsets: [kernelsize**3, 3], the offset coords in an kernel
|
||||
mask_multi: bool, whether to multiply the predicted mask to features
|
||||
topk: bool, whether to use topk or threshold for selection
|
||||
threshold: float, threshold value
|
||||
"""
|
||||
index = x.indices[:, 0]
|
||||
batch_index = index==b
|
||||
indices_ori = x.indices[batch_index]
|
||||
features_ori = x.features[batch_index]
|
||||
mask_voxel = imps_3d[batch_index, -1].sigmoid()
|
||||
mask_kernel = imps_3d[batch_index, :-1].sigmoid()
|
||||
|
||||
if mask_multi:
|
||||
features_ori *= mask_voxel.unsqueeze(-1)
|
||||
|
||||
if topk:
|
||||
_, indices = mask_voxel.sort(descending=True)
|
||||
indices_fore = indices[:int(mask_voxel.shape[0]*threshold)]
|
||||
indices_back = indices[int(mask_voxel.shape[0]*threshold):]
|
||||
else:
|
||||
indices_fore = mask_voxel > threshold
|
||||
indices_back = mask_voxel <= threshold
|
||||
|
||||
features_fore = features_ori[indices_fore]
|
||||
coords_fore = indices_ori[indices_fore]
|
||||
|
||||
mask_kernel_fore = mask_kernel[indices_fore]
|
||||
mask_kernel_bool = mask_kernel_fore>=threshold
|
||||
voxel_kerels_imp = kernel_offsets.unsqueeze(0).repeat(mask_kernel_bool.shape[0],1, 1)
|
||||
mask_kernel_fore = mask_kernel[indices_fore][mask_kernel_bool]
|
||||
indices_fore_kernels = coords_fore[:, 1:].unsqueeze(1).repeat(1, kernel_offsets.shape[0], 1)
|
||||
indices_with_imp = indices_fore_kernels + voxel_kerels_imp
|
||||
selected_indices = indices_with_imp[mask_kernel_bool]
|
||||
spatial_indices = (selected_indices[:, 0] >0) * (selected_indices[:, 1] >0) * (selected_indices[:, 2] >0) * \
|
||||
(selected_indices[:, 0] < x.spatial_shape[0]) * (selected_indices[:, 1] < x.spatial_shape[1]) * (selected_indices[:, 2] < x.spatial_shape[2])
|
||||
selected_indices = selected_indices[spatial_indices]
|
||||
mask_kernel_fore = mask_kernel_fore[spatial_indices]
|
||||
selected_indices = torch.cat([torch.ones((selected_indices.shape[0], 1), device=features_fore.device)*b, selected_indices], dim=1)
|
||||
|
||||
selected_features = torch.zeros((selected_indices.shape[0], features_ori.shape[1]), device=features_fore.device)
|
||||
|
||||
features_fore_cat = torch.cat([features_fore, selected_features], dim=0)
|
||||
coords_fore = torch.cat([coords_fore, selected_indices], dim=0)
|
||||
mask_kernel_fore = torch.cat([torch.ones(features_fore.shape[0], device=features_fore.device), mask_kernel_fore], dim=0)
|
||||
|
||||
features_fore, coords_fore, mask_kernel_fore = check_repeat(features_fore_cat, coords_fore, features_add=mask_kernel_fore)
|
||||
|
||||
features_back = features_ori[indices_back]
|
||||
coords_back = indices_ori[indices_back]
|
||||
|
||||
return features_fore, coords_fore, features_back, coords_back, mask_kernel_fore
|
||||
Reference in New Issue
Block a user