Add File
This commit is contained in:
76
pcdet/models/backbones_3d/vfe/dynamic_mean_vfe.py
Normal file
76
pcdet/models/backbones_3d/vfe/dynamic_mean_vfe.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
|
||||
from .vfe_template import VFETemplate
|
||||
|
||||
try:
|
||||
import torch_scatter
|
||||
except Exception as e:
|
||||
# Incase someone doesn't want to use dynamic pillar vfe and hasn't installed torch_scatter
|
||||
pass
|
||||
|
||||
from .vfe_template import VFETemplate
|
||||
|
||||
|
||||
class DynamicMeanVFE(VFETemplate):
|
||||
def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_cloud_range, **kwargs):
|
||||
super().__init__(model_cfg=model_cfg)
|
||||
self.num_point_features = num_point_features
|
||||
|
||||
self.grid_size = torch.tensor(grid_size).cuda()
|
||||
self.voxel_size = torch.tensor(voxel_size).cuda()
|
||||
self.point_cloud_range = torch.tensor(point_cloud_range).cuda()
|
||||
|
||||
self.voxel_x = voxel_size[0]
|
||||
self.voxel_y = voxel_size[1]
|
||||
self.voxel_z = voxel_size[2]
|
||||
self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
|
||||
self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
|
||||
self.z_offset = self.voxel_z / 2 + point_cloud_range[2]
|
||||
|
||||
self.scale_xyz = grid_size[0] * grid_size[1] * grid_size[2]
|
||||
self.scale_yz = grid_size[1] * grid_size[2]
|
||||
self.scale_z = grid_size[2]
|
||||
|
||||
def get_output_feature_dim(self):
|
||||
return self.num_point_features
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, batch_dict, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
batch_dict:
|
||||
voxels: (num_voxels, max_points_per_voxel, C)
|
||||
voxel_num_points: optional (num_voxels)
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
vfe_features: (num_voxels, C)
|
||||
"""
|
||||
batch_size = batch_dict['batch_size']
|
||||
points = batch_dict['points'] # (batch_idx, x, y, z, i, e)
|
||||
|
||||
# # debug
|
||||
point_coords = torch.floor((points[:, 1:4] - self.point_cloud_range[0:3]) / self.voxel_size).int()
|
||||
mask = ((point_coords >= 0) & (point_coords < self.grid_size)).all(dim=1)
|
||||
points = points[mask]
|
||||
point_coords = point_coords[mask]
|
||||
merge_coords = points[:, 0].int() * self.scale_xyz + \
|
||||
point_coords[:, 0] * self.scale_yz + \
|
||||
point_coords[:, 1] * self.scale_z + \
|
||||
point_coords[:, 2]
|
||||
points_data = points[:, 1:].contiguous()
|
||||
|
||||
unq_coords, unq_inv, unq_cnt = torch.unique(merge_coords, return_inverse=True, return_counts=True)
|
||||
|
||||
points_mean = torch_scatter.scatter_mean(points_data, unq_inv, dim=0)
|
||||
|
||||
unq_coords = unq_coords.int()
|
||||
voxel_coords = torch.stack((unq_coords // self.scale_xyz,
|
||||
(unq_coords % self.scale_xyz) // self.scale_yz,
|
||||
(unq_coords % self.scale_yz) // self.scale_z,
|
||||
unq_coords % self.scale_z), dim=1)
|
||||
voxel_coords = voxel_coords[:, [0, 3, 2, 1]]
|
||||
|
||||
batch_dict['voxel_features'] = points_mean.contiguous()
|
||||
batch_dict['voxel_coords'] = voxel_coords.contiguous()
|
||||
return batch_dict
|
||||
Reference in New Issue
Block a user