Add File
This commit is contained in:
@@ -0,0 +1,37 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, mode="bilinear", padding_mode="zeros"):
|
||||||
|
"""
|
||||||
|
Initializes module
|
||||||
|
Args:
|
||||||
|
mode: string, Sampling mode [bilinear/nearest]
|
||||||
|
padding_mode: string, Padding mode for outside grid values [zeros/border/reflection]
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.mode = mode
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
|
||||||
|
if torch.__version__ >= '1.3':
|
||||||
|
self.grid_sample = partial(F.grid_sample, align_corners=True)
|
||||||
|
else:
|
||||||
|
self.grid_sample = F.grid_sample
|
||||||
|
|
||||||
|
def forward(self, input_features, grid):
|
||||||
|
"""
|
||||||
|
Samples input using sampling grid
|
||||||
|
Args:
|
||||||
|
input_features: (B, C, D, H, W), Input frustum features
|
||||||
|
grid: (B, X, Y, Z, 3), Sampling grids for input features
|
||||||
|
Returns
|
||||||
|
output_features: (B, C, X, Y, Z) Output voxel features
|
||||||
|
"""
|
||||||
|
# Sample from grid
|
||||||
|
output = self.grid_sample(input=input_features, grid=grid, mode=self.mode, padding_mode=self.padding_mode)
|
||||||
|
return output
|
||||||
Reference in New Issue
Block a user