Add File
This commit is contained in:
97
pcdet/ops/bev_pool/bev_pool.py
Normal file
97
pcdet/ops/bev_pool/bev_pool.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import torch
|
||||
|
||||
from . import bev_pool_ext
|
||||
|
||||
__all__ = ["bev_pool"]
|
||||
|
||||
|
||||
class QuickCumsum(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, geom_feats, ranks):
|
||||
x = x.cumsum(0)
|
||||
kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
|
||||
kept[:-1] = ranks[1:] != ranks[:-1]
|
||||
|
||||
x, geom_feats = x[kept], geom_feats[kept]
|
||||
x = torch.cat((x[:1], x[1:] - x[:-1]))
|
||||
|
||||
# save kept for backward
|
||||
ctx.save_for_backward(kept)
|
||||
|
||||
# no gradient for geom_feats
|
||||
ctx.mark_non_differentiable(geom_feats)
|
||||
|
||||
return x, geom_feats
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gradx, gradgeom):
|
||||
(kept,) = ctx.saved_tensors
|
||||
back = torch.cumsum(kept, 0)
|
||||
back[kept] -= 1
|
||||
|
||||
val = gradx[back]
|
||||
|
||||
return val, None, None
|
||||
|
||||
|
||||
class QuickCumsumCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, geom_feats, ranks, B, D, H, W):
|
||||
kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
|
||||
kept[1:] = ranks[1:] != ranks[:-1]
|
||||
interval_starts = torch.where(kept)[0].int()
|
||||
interval_lengths = torch.zeros_like(interval_starts)
|
||||
interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1]
|
||||
interval_lengths[-1] = x.shape[0] - interval_starts[-1]
|
||||
geom_feats = geom_feats.int()
|
||||
|
||||
out = bev_pool_ext.bev_pool_forward(
|
||||
x,
|
||||
geom_feats,
|
||||
interval_lengths,
|
||||
interval_starts,
|
||||
B,
|
||||
D,
|
||||
H,
|
||||
W,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(interval_starts, interval_lengths, geom_feats)
|
||||
ctx.saved_shapes = B, D, H, W
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, out_grad):
|
||||
interval_starts, interval_lengths, geom_feats = ctx.saved_tensors
|
||||
B, D, H, W = ctx.saved_shapes
|
||||
|
||||
out_grad = out_grad.contiguous()
|
||||
x_grad = bev_pool_ext.bev_pool_backward(
|
||||
out_grad,
|
||||
geom_feats,
|
||||
interval_lengths,
|
||||
interval_starts,
|
||||
B,
|
||||
D,
|
||||
H,
|
||||
W,
|
||||
)
|
||||
|
||||
return x_grad, None, None, None, None, None, None
|
||||
|
||||
|
||||
def bev_pool(feats, coords, B, D, H, W):
|
||||
assert feats.shape[0] == coords.shape[0]
|
||||
|
||||
ranks = (
|
||||
coords[:, 0] * (W * D * B)
|
||||
+ coords[:, 1] * (D * B)
|
||||
+ coords[:, 2] * B
|
||||
+ coords[:, 3]
|
||||
)
|
||||
indices = ranks.argsort()
|
||||
feats, coords, ranks = feats[indices], coords[indices], ranks[indices]
|
||||
|
||||
x = QuickCumsumCuda.apply(feats, coords, ranks, B, D, H, W)
|
||||
x = x.permute(0, 4, 1, 2, 3).contiguous()
|
||||
return x
|
||||
Reference in New Issue
Block a user