Add File
This commit is contained in:
31
pcdet/ops/ingroup_inds/ingroup_inds_op.py
Normal file
31
pcdet/ops/ingroup_inds/ingroup_inds_op.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
try:
|
||||
from . import ingroup_inds_cuda
|
||||
# import ingroup_indices
|
||||
except ImportError:
|
||||
ingroup_indices = None
|
||||
print('Can not import ingroup indices')
|
||||
|
||||
ingroup_indices = ingroup_inds_cuda
|
||||
|
||||
from torch.autograd import Function
|
||||
class IngroupIndicesFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, group_inds):
|
||||
|
||||
out_inds = torch.zeros_like(group_inds) - 1
|
||||
|
||||
ingroup_indices.forward(group_inds, out_inds)
|
||||
|
||||
ctx.mark_non_differentiable(out_inds)
|
||||
|
||||
return out_inds
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, g):
|
||||
|
||||
return None
|
||||
|
||||
ingroup_inds = IngroupIndicesFunction.apply
|
||||
Reference in New Issue
Block a user