This commit is contained in:
2025-09-21 20:19:25 +08:00
parent 2ae4c63435
commit 202a0e1522

View File

@@ -0,0 +1,31 @@
import torch
try:
from . import ingroup_inds_cuda
# import ingroup_indices
except ImportError:
ingroup_indices = None
print('Can not import ingroup indices')
ingroup_indices = ingroup_inds_cuda
from torch.autograd import Function
class IngroupIndicesFunction(Function):
@staticmethod
def forward(ctx, group_inds):
out_inds = torch.zeros_like(group_inds) - 1
ingroup_indices.forward(group_inds, out_inds)
ctx.mark_non_differentiable(out_inds)
return out_inds
@staticmethod
def backward(ctx, g):
return None
ingroup_inds = IngroupIndicesFunction.apply