From 5ecd5bb51757ba1d2e5127db87eed50179546665 Mon Sep 17 00:00:00 2001 From: inter Date: Sun, 21 Sep 2025 20:19:00 +0800 Subject: [PATCH] Add File --- .../backbones_3d/vfe/dynamic_mean_vfe.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 pcdet/models/backbones_3d/vfe/dynamic_mean_vfe.py diff --git a/pcdet/models/backbones_3d/vfe/dynamic_mean_vfe.py b/pcdet/models/backbones_3d/vfe/dynamic_mean_vfe.py new file mode 100644 index 0000000..b4c5b06 --- /dev/null +++ b/pcdet/models/backbones_3d/vfe/dynamic_mean_vfe.py @@ -0,0 +1,76 @@ +import torch + +from .vfe_template import VFETemplate + +try: + import torch_scatter +except Exception as e: + # Incase someone doesn't want to use dynamic pillar vfe and hasn't installed torch_scatter + pass + +from .vfe_template import VFETemplate + + +class DynamicMeanVFE(VFETemplate): + def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_cloud_range, **kwargs): + super().__init__(model_cfg=model_cfg) + self.num_point_features = num_point_features + + self.grid_size = torch.tensor(grid_size).cuda() + self.voxel_size = torch.tensor(voxel_size).cuda() + self.point_cloud_range = torch.tensor(point_cloud_range).cuda() + + self.voxel_x = voxel_size[0] + self.voxel_y = voxel_size[1] + self.voxel_z = voxel_size[2] + self.x_offset = self.voxel_x / 2 + point_cloud_range[0] + self.y_offset = self.voxel_y / 2 + point_cloud_range[1] + self.z_offset = self.voxel_z / 2 + point_cloud_range[2] + + self.scale_xyz = grid_size[0] * grid_size[1] * grid_size[2] + self.scale_yz = grid_size[1] * grid_size[2] + self.scale_z = grid_size[2] + + def get_output_feature_dim(self): + return self.num_point_features + + @torch.no_grad() + def forward(self, batch_dict, **kwargs): + """ + Args: + batch_dict: + voxels: (num_voxels, max_points_per_voxel, C) + voxel_num_points: optional (num_voxels) + **kwargs: + + Returns: + vfe_features: (num_voxels, C) + """ + batch_size = batch_dict['batch_size'] + points = batch_dict['points'] # (batch_idx, x, y, z, i, e) + + # # debug + point_coords = torch.floor((points[:, 1:4] - self.point_cloud_range[0:3]) / self.voxel_size).int() + mask = ((point_coords >= 0) & (point_coords < self.grid_size)).all(dim=1) + points = points[mask] + point_coords = point_coords[mask] + merge_coords = points[:, 0].int() * self.scale_xyz + \ + point_coords[:, 0] * self.scale_yz + \ + point_coords[:, 1] * self.scale_z + \ + point_coords[:, 2] + points_data = points[:, 1:].contiguous() + + unq_coords, unq_inv, unq_cnt = torch.unique(merge_coords, return_inverse=True, return_counts=True) + + points_mean = torch_scatter.scatter_mean(points_data, unq_inv, dim=0) + + unq_coords = unq_coords.int() + voxel_coords = torch.stack((unq_coords // self.scale_xyz, + (unq_coords % self.scale_xyz) // self.scale_yz, + (unq_coords % self.scale_yz) // self.scale_z, + unq_coords % self.scale_z), dim=1) + voxel_coords = voxel_coords[:, [0, 3, 2, 1]] + + batch_dict['voxel_features'] = points_mean.contiguous() + batch_dict['voxel_coords'] = voxel_coords.contiguous() + return batch_dict