diff --git a/pcdet/ops/bev_pool/bev_pool.py b/pcdet/ops/bev_pool/bev_pool.py new file mode 100644 index 0000000..5769a40 --- /dev/null +++ b/pcdet/ops/bev_pool/bev_pool.py @@ -0,0 +1,97 @@ +import torch + +from . import bev_pool_ext + +__all__ = ["bev_pool"] + + +class QuickCumsum(torch.autograd.Function): + @staticmethod + def forward(ctx, x, geom_feats, ranks): + x = x.cumsum(0) + kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool) + kept[:-1] = ranks[1:] != ranks[:-1] + + x, geom_feats = x[kept], geom_feats[kept] + x = torch.cat((x[:1], x[1:] - x[:-1])) + + # save kept for backward + ctx.save_for_backward(kept) + + # no gradient for geom_feats + ctx.mark_non_differentiable(geom_feats) + + return x, geom_feats + + @staticmethod + def backward(ctx, gradx, gradgeom): + (kept,) = ctx.saved_tensors + back = torch.cumsum(kept, 0) + back[kept] -= 1 + + val = gradx[back] + + return val, None, None + + +class QuickCumsumCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, geom_feats, ranks, B, D, H, W): + kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool) + kept[1:] = ranks[1:] != ranks[:-1] + interval_starts = torch.where(kept)[0].int() + interval_lengths = torch.zeros_like(interval_starts) + interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1] + interval_lengths[-1] = x.shape[0] - interval_starts[-1] + geom_feats = geom_feats.int() + + out = bev_pool_ext.bev_pool_forward( + x, + geom_feats, + interval_lengths, + interval_starts, + B, + D, + H, + W, + ) + + ctx.save_for_backward(interval_starts, interval_lengths, geom_feats) + ctx.saved_shapes = B, D, H, W + return out + + @staticmethod + def backward(ctx, out_grad): + interval_starts, interval_lengths, geom_feats = ctx.saved_tensors + B, D, H, W = ctx.saved_shapes + + out_grad = out_grad.contiguous() + x_grad = bev_pool_ext.bev_pool_backward( + out_grad, + geom_feats, + interval_lengths, + interval_starts, + B, + D, + H, + W, + ) + + return x_grad, None, None, None, None, None, None + + +def bev_pool(feats, coords, B, D, H, W): + assert feats.shape[0] == coords.shape[0] + + ranks = ( + coords[:, 0] * (W * D * B) + + coords[:, 1] * (D * B) + + coords[:, 2] * B + + coords[:, 3] + ) + indices = ranks.argsort() + feats, coords, ranks = feats[indices], coords[indices], ranks[indices] + + x = QuickCumsumCuda.apply(feats, coords, ranks, B, D, H, W) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x