Files
OpenPCDet/pcdet/models/backbones_3d/vfe/dynamic_mean_vfe.py

77 lines
2.9 KiB
Python
Raw Normal View History

2025-09-21 20:19:00 +08:00
import torch
from .vfe_template import VFETemplate
try:
import torch_scatter
except Exception as e:
# In case someone doesn't want to use the dynamic VFE and hasn't installed torch_scatter
pass
from .vfe_template import VFETemplate
class DynamicMeanVFE(VFETemplate):
    """Voxel feature encoder that averages all points falling into each voxel.

    Points are assigned to voxels on the fly from their coordinates, and the
    per-voxel mean of the point columns is computed with
    ``torch_scatter.scatter_mean``.

    Args:
        model_cfg: Configuration object forwarded to ``VFETemplate``.
        num_point_features: Number of feature channels reported by
            ``get_output_feature_dim``.
        voxel_size: Sequence of three voxel edge lengths, ordered (x, y, z).
        grid_size: Sequence of three voxel counts, ordered (x, y, z).
        point_cloud_range: Sequence whose first three entries are the minimum
            (x, y, z) corner used as the voxel grid origin.
        **kwargs: Accepted for interface compatibility; unused here.
    """

    def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_cloud_range, **kwargs):
        super().__init__(model_cfg=model_cfg)
        self.num_point_features = num_point_features

        # NOTE: these tensors are created directly on the current CUDA device,
        # so constructing this module requires CUDA to be available.
        self.grid_size = torch.tensor(grid_size).cuda()
        self.voxel_size = torch.tensor(voxel_size).cuda()
        self.point_cloud_range = torch.tensor(point_cloud_range).cuda()

        self.voxel_x = voxel_size[0]
        self.voxel_y = voxel_size[1]
        self.voxel_z = voxel_size[2]
        # Centre of the first voxel along each axis.
        self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
        self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
        self.z_offset = self.voxel_z / 2 + point_cloud_range[2]

        # Strides used to flatten (batch_idx, x, y, z) into a single integer key.
        self.scale_xyz = grid_size[0] * grid_size[1] * grid_size[2]
        self.scale_yz = grid_size[1] * grid_size[2]
        self.scale_z = grid_size[2]

    def get_output_feature_dim(self):
        """Return the number of channels configured at construction time."""
        return self.num_point_features

    @torch.no_grad()
    def forward(self, batch_dict, **kwargs):
        """Group points into voxels and average them.

        Args:
            batch_dict:
                points: (N, 1 + C) tensor whose columns are
                    (batch_idx, x, y, z, ...).
            **kwargs: Unused.

        Returns:
            batch_dict: The same dict, with two entries added:
                voxel_features: (M, C) per-voxel mean of ``points[:, 1:]``.
                voxel_coords: (M, 4) int32 tensor ordered
                    (batch_idx, z, y, x).
        """
        points = batch_dict['points']  # (batch_idx, x, y, z, i, e)

        # Integer voxel index of every point along x, y, z.
        point_coords = torch.floor(
            (points[:, 1:4] - self.point_cloud_range[0:3]) / self.voxel_size
        ).int()

        # Drop points that fall outside the voxel grid.
        mask = ((point_coords >= 0) & (point_coords < self.grid_size)).all(dim=1)
        points = points[mask]

        # Build the flattened key in int64: batch_idx * (X * Y * Z) exceeds the
        # int32 range for large grids combined with larger batch indices, and a
        # wrapped key decodes to wrong voxel coordinates below.
        point_coords = point_coords[mask].long()
        merge_coords = points[:, 0].long() * self.scale_xyz + \
            point_coords[:, 0] * self.scale_yz + \
            point_coords[:, 1] * self.scale_z + \
            point_coords[:, 2]

        points_data = points[:, 1:].contiguous()

        # unq_inv maps each point to the row of its voxel in unq_coords.
        unq_coords, unq_inv = torch.unique(merge_coords, return_inverse=True)
        points_mean = torch_scatter.scatter_mean(points_data, unq_inv, dim=0)

        # Decode the flattened key back into (batch_idx, x, y, z).
        voxel_coords = torch.stack((unq_coords // self.scale_xyz,
                                    (unq_coords % self.scale_xyz) // self.scale_yz,
                                    (unq_coords % self.scale_yz) // self.scale_z,
                                    unq_coords % self.scale_z), dim=1)
        # Reorder to (batch_idx, z, y, x) and return int32 coordinates, the
        # dtype this method produced before the key was widened.
        voxel_coords = voxel_coords[:, [0, 3, 2, 1]].int()

        batch_dict['voxel_features'] = points_mean.contiguous()
        batch_dict['voxel_coords'] = voxel_coords.contiguous()
        return batch_dict