77 lines
2.9 KiB
Python
77 lines
2.9 KiB
Python
import torch
|
|
|
|
from .vfe_template import VFETemplate
|
|
|
|
try:
|
|
import torch_scatter
|
|
except Exception as e:
|
|
# In case someone doesn't want to use dynamic pillar vfe and hasn't installed torch_scatter
|
|
pass
|
|
|
|
from .vfe_template import VFETemplate
|
|
|
|
|
|
class DynamicMeanVFE(VFETemplate):
    """Mean voxel feature encoder that voxelizes raw points on the fly.

    Points are binned into a regular grid and every non-empty voxel is
    represented by the mean of the points that fall inside it.
    """

    def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_cloud_range, **kwargs):
        """
        Args:
            model_cfg: forwarded unchanged to ``VFETemplate.__init__``.
            num_point_features: value reported by ``get_output_feature_dim``.
            voxel_size: indexable with three entries, the voxel edge length
                along x, y, z.
            grid_size: indexable with three entries, the number of voxels
                along x, y, z.
            point_cloud_range: indexable; its first three entries are the
                lower (x, y, z) corner of the grid.
            **kwargs: accepted and ignored.
        """
        super().__init__(model_cfg=model_cfg)
        self.num_point_features = num_point_features

        # NOTE: these tensors are created directly on the current CUDA device,
        # so constructing this module requires CUDA to be available.
        self.grid_size = torch.tensor(grid_size).cuda()
        self.voxel_size = torch.tensor(voxel_size).cuda()
        self.point_cloud_range = torch.tensor(point_cloud_range).cuda()

        self.voxel_x = voxel_size[0]
        self.voxel_y = voxel_size[1]
        self.voxel_z = voxel_size[2]
        # Centre of the first voxel along each axis.
        self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
        self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
        self.z_offset = self.voxel_z / 2 + point_cloud_range[2]

        # Strides used to flatten (batch_idx, x, y, z) into one integer key.
        self.scale_xyz = grid_size[0] * grid_size[1] * grid_size[2]
        self.scale_yz = grid_size[1] * grid_size[2]
        self.scale_z = grid_size[2]

    def get_output_feature_dim(self):
        """Return the number of point features given at construction."""
        return self.num_point_features

    @torch.no_grad()
    def forward(self, batch_dict, **kwargs):
        """
        Args:
            batch_dict:
                points: (N, 1 + C) tensor laid out as
                    (batch_idx, x, y, z, ...extra per-point features)
            **kwargs: accepted and ignored.

        Returns:
            batch_dict, with two entries added:
                voxel_features: (num_voxels, C) mean of ``points[:, 1:]`` over
                    the points of each non-empty voxel
                voxel_coords: (num_voxels, 4) int32 tensor laid out as
                    (batch_idx, z, y, x)
        """
        points = batch_dict['points']  # (batch_idx, x, y, z, i, e)

        # Integer voxel index of every point along x, y, z. int64 is used so
        # the flattened key below cannot wrap around for large grids/batches.
        point_coords = torch.floor(
            (points[:, 1:4] - self.point_cloud_range[0:3]) / self.voxel_size
        ).long()

        # Drop points that fall outside the grid.
        mask = ((point_coords >= 0) & (point_coords < self.grid_size)).all(dim=1)
        points = points[mask]
        point_coords = point_coords[mask]

        # One unique integer key per (batch_idx, x, y, z) voxel.
        merge_coords = (
            points[:, 0].long() * self.scale_xyz
            + point_coords[:, 0] * self.scale_yz
            + point_coords[:, 1] * self.scale_z
            + point_coords[:, 2]
        )
        points_data = points[:, 1:].contiguous()

        unq_coords, unq_inv = torch.unique(merge_coords, return_inverse=True)

        # Average all points that share the same voxel key.
        points_mean = torch_scatter.scatter_mean(points_data, unq_inv, dim=0)

        # Decode each key back into (batch_idx, x, y, z) ...
        voxel_coords = torch.stack(
            (
                unq_coords // self.scale_xyz,
                (unq_coords % self.scale_xyz) // self.scale_yz,
                (unq_coords % self.scale_yz) // self.scale_z,
                unq_coords % self.scale_z,
            ),
            dim=1,
        )
        # ... then reorder to (batch_idx, z, y, x). Each component fits in
        # int32, so the output dtype stays what it was before.
        voxel_coords = voxel_coords[:, [0, 3, 2, 1]].int()

        batch_dict['voxel_features'] = points_mean.contiguous()
        batch_dict['voxel_coords'] = voxel_coords.contiguous()
        return batch_dict
|