From 3849ebad3919bd8d183e8b8ebf9ed9ebe611f985 Mon Sep 17 00:00:00 2001 From: inter Date: Sun, 21 Sep 2025 20:19:09 +0800 Subject: [PATCH] Add File --- .../pointnet2_stack/voxel_pool_modules.py | 131 ++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 pcdet/ops/pointnet2/pointnet2_stack/voxel_pool_modules.py diff --git a/pcdet/ops/pointnet2/pointnet2_stack/voxel_pool_modules.py b/pcdet/ops/pointnet2/pointnet2_stack/voxel_pool_modules.py new file mode 100644 index 0000000..033b5f1 --- /dev/null +++ b/pcdet/ops/pointnet2/pointnet2_stack/voxel_pool_modules.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from . import voxel_query_utils +from typing import List + + +class NeighborVoxelSAModuleMSG(nn.Module): + + def __init__(self, *, query_ranges: List[List[int]], radii: List[float], + nsamples: List[int], mlps: List[List[int]], use_xyz: bool = True, pool_method='max_pool'): + """ + Args: + query_ranges: list of int, list of neighbor ranges to group with + nsamples: list of int, number of samples in each ball query + mlps: list of list of int, spec of the pointnet before the global pooling for each scale + use_xyz: + pool_method: max_pool / avg_pool + """ + super().__init__() + + assert len(query_ranges) == len(nsamples) == len(mlps) + + self.groupers = nn.ModuleList() + self.mlps_in = nn.ModuleList() + self.mlps_pos = nn.ModuleList() + self.mlps_out = nn.ModuleList() + for i in range(len(query_ranges)): + max_range = query_ranges[i] + nsample = nsamples[i] + radius = radii[i] + self.groupers.append(voxel_query_utils.VoxelQueryAndGrouping(max_range, radius, nsample)) + mlp_spec = mlps[i] + + cur_mlp_in = nn.Sequential( + nn.Conv1d(mlp_spec[0], mlp_spec[1], kernel_size=1, bias=False), + nn.BatchNorm1d(mlp_spec[1]) + ) + + cur_mlp_pos = nn.Sequential( + nn.Conv2d(3, mlp_spec[1], kernel_size=1, bias=False), + nn.BatchNorm2d(mlp_spec[1]) + ) + + cur_mlp_out = nn.Sequential( + nn.Conv1d(mlp_spec[1], mlp_spec[2], kernel_size=1, bias=False), + nn.BatchNorm1d(mlp_spec[2]), + nn.ReLU() + ) + + self.mlps_in.append(cur_mlp_in) + self.mlps_pos.append(cur_mlp_pos) + self.mlps_out.append(cur_mlp_out) + + self.relu = nn.ReLU() + self.pool_method = pool_method + + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0) + + def forward(self, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, \ + new_coords, features, voxel2point_indices): + """ + :param xyz: (N1 + N2 ..., 3) tensor of the xyz coordinates of the features + :param xyz_batch_cnt: (batch_size), [N1, N2, ...] + :param new_xyz: (M1 + M2 ..., 3) + :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + :param features: (N1 + N2 ..., C) tensor of the descriptors of the the features + :param point_indices: (B, Z, Y, X) tensor of point indices + :return: + new_xyz: (M1 + M2 ..., 3) tensor of the new features' xyz + new_features: (M1 + M2 ..., \sum_k(mlps[k][-1])) tensor of the new_features descriptors + """ + # change the order to [batch_idx, z, y, x] + new_coords = new_coords[:, [0, 3, 2, 1]].contiguous() + new_features_list = [] + for k in range(len(self.groupers)): + # features_in: (1, C, M1+M2) + features_in = features.permute(1, 0).unsqueeze(0) + features_in = self.mlps_in[k](features_in) + # features_in: (1, M1+M2, C) + features_in = features_in.permute(0, 2, 1).contiguous() + # features_in: (M1+M2, C) + features_in = features_in.view(-1, features_in.shape[-1]) + # grouped_features: (M1+M2, C, nsample) + # grouped_xyz: (M1+M2, 3, nsample) + grouped_features, grouped_xyz, empty_ball_mask = self.groupers[k]( + new_coords, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features_in, voxel2point_indices + ) + grouped_features[empty_ball_mask] = 0 + + # grouped_features: (1, C, M1+M2, nsample) + grouped_features = grouped_features.permute(1, 0, 2).unsqueeze(dim=0) + # grouped_xyz: (M1+M2, 3, nsample) + grouped_xyz = grouped_xyz - new_xyz.unsqueeze(-1) + grouped_xyz[empty_ball_mask] = 0 + # grouped_xyz: (1, 3, M1+M2, nsample) + grouped_xyz = grouped_xyz.permute(1, 0, 2).unsqueeze(0) + # grouped_xyz: (1, C, M1+M2, nsample) + position_features = self.mlps_pos[k](grouped_xyz) + new_features = grouped_features + position_features + new_features = self.relu(new_features) + + if self.pool_method == 'max_pool': + new_features = F.max_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ).squeeze(dim=-1) # (1, C, M1 + M2 ...) + elif self.pool_method == 'avg_pool': + new_features = F.avg_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ).squeeze(dim=-1) # (1, C, M1 + M2 ...) + else: + raise NotImplementedError + + new_features = self.mlps_out[k](new_features) + new_features = new_features.squeeze(dim=0).permute(1, 0) # (M1 + M2 ..., C) + new_features_list.append(new_features) + + # (M1 + M2 ..., C) + new_features = torch.cat(new_features_list, dim=1) + return new_features +