Add File
This commit is contained in:
31
pcdet/ops/ingroup_inds/ingroup_inds_op.py
Normal file
31
pcdet/ops/ingroup_inds/ingroup_inds_op.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import ingroup_inds_cuda
|
||||||
|
# import ingroup_indices
|
||||||
|
except ImportError:
|
||||||
|
ingroup_indices = None
|
||||||
|
print('Can not import ingroup indices')
|
||||||
|
|
||||||
|
ingroup_indices = ingroup_inds_cuda
|
||||||
|
|
||||||
|
from torch.autograd import Function
|
||||||
|
class IngroupIndicesFunction(Function):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, group_inds):
|
||||||
|
|
||||||
|
out_inds = torch.zeros_like(group_inds) - 1
|
||||||
|
|
||||||
|
ingroup_indices.forward(group_inds, out_inds)
|
||||||
|
|
||||||
|
ctx.mark_non_differentiable(out_inds)
|
||||||
|
|
||||||
|
return out_inds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, g):
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
ingroup_inds = IngroupIndicesFunction.apply
|
||||||
Reference in New Issue
Block a user