diff --git a/pcdet/ops/ingroup_inds/ingroup_inds_op.py b/pcdet/ops/ingroup_inds/ingroup_inds_op.py new file mode 100644 index 0000000..5c9b6e0 --- /dev/null +++ b/pcdet/ops/ingroup_inds/ingroup_inds_op.py @@ -0,0 +1,31 @@ +import torch + +try: + from . import ingroup_inds_cuda + # import ingroup_indices +except ImportError: + ingroup_indices = None + print('Can not import ingroup indices') + +ingroup_indices = ingroup_inds_cuda + +from torch.autograd import Function +class IngroupIndicesFunction(Function): + + @staticmethod + def forward(ctx, group_inds): + + out_inds = torch.zeros_like(group_inds) - 1 + + ingroup_indices.forward(group_inds, out_inds) + + ctx.mark_non_differentiable(out_inds) + + return out_inds + + @staticmethod + def backward(ctx, g): + + return None + +ingroup_inds = IngroupIndicesFunction.apply \ No newline at end of file