From 34650b3730f75336d7f4cae0a92fe79414c88674 Mon Sep 17 00:00:00 2001
From: inter
Date: Sun, 21 Sep 2025 20:18:52 +0800
Subject: [PATCH] Add File

---
 .../models/backbones_3d/pointnet2_backbone.py | 206 ++++++++++++++++++
 1 file changed, 206 insertions(+)
 create mode 100644 pcdet/models/backbones_3d/pointnet2_backbone.py

diff --git a/pcdet/models/backbones_3d/pointnet2_backbone.py b/pcdet/models/backbones_3d/pointnet2_backbone.py
new file mode 100644
index 0000000..fdd0d02
--- /dev/null
+++ b/pcdet/models/backbones_3d/pointnet2_backbone.py
@@ -0,0 +1,206 @@
+import torch
+import torch.nn as nn
+
+from ...ops.pointnet2.pointnet2_batch import pointnet2_modules
+from ...ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_modules_stack
+from ...ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_utils_stack
+
+
+class PointNet2MSG(nn.Module):
+    def __init__(self, model_cfg, input_channels, **kwargs):
+        super().__init__()
+        self.model_cfg = model_cfg
+
+        self.SA_modules = nn.ModuleList()
+        channel_in = input_channels - 3
+
+        self.num_points_each_layer = []
+        skip_channel_list = [input_channels - 3]
+        for k in range(self.model_cfg.SA_CONFIG.NPOINTS.__len__()):
+            mlps = self.model_cfg.SA_CONFIG.MLPS[k].copy()
+            channel_out = 0
+            for idx in range(mlps.__len__()):
+                mlps[idx] = [channel_in] + mlps[idx]
+                channel_out += mlps[idx][-1]
+
+            self.SA_modules.append(
+                pointnet2_modules.PointnetSAModuleMSG(
+                    npoint=self.model_cfg.SA_CONFIG.NPOINTS[k],
+                    radii=self.model_cfg.SA_CONFIG.RADIUS[k],
+                    nsamples=self.model_cfg.SA_CONFIG.NSAMPLE[k],
+                    mlps=mlps,
+                    use_xyz=self.model_cfg.SA_CONFIG.get('USE_XYZ', True),
+                )
+            )
+            skip_channel_list.append(channel_out)
+            channel_in = channel_out
+
+        self.FP_modules = nn.ModuleList()
+
+        for k in range(self.model_cfg.FP_MLPS.__len__()):
+            pre_channel = self.model_cfg.FP_MLPS[k + 1][-1] if k + 1 < len(self.model_cfg.FP_MLPS) else channel_out
+            self.FP_modules.append(
+                pointnet2_modules.PointnetFPModule(
+                    mlp=[pre_channel + skip_channel_list[k]] + self.model_cfg.FP_MLPS[k]
+                )
+            )
+
+        self.num_point_features = self.model_cfg.FP_MLPS[0][-1]
+
+    def break_up_pc(self, pc):
+        batch_idx = pc[:, 0]
+        xyz = pc[:, 1:4].contiguous()
+        features = (pc[:, 4:].contiguous() if pc.size(-1) > 4 else None)
+        return batch_idx, xyz, features
+
+    def forward(self, batch_dict):
+        """
+        Args:
+            batch_dict:
+                batch_size: int
+                vfe_features: (num_voxels, C)
+                points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
+        Returns:
+            batch_dict:
+                encoded_spconv_tensor: sparse tensor
+                point_features: (N, C)
+        """
+        batch_size = batch_dict['batch_size']
+        points = batch_dict['points']
+        batch_idx, xyz, features = self.break_up_pc(points)
+
+        xyz_batch_cnt = xyz.new_zeros(batch_size).int()
+        for bs_idx in range(batch_size):
+            xyz_batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
+
+        assert xyz_batch_cnt.min() == xyz_batch_cnt.max()
+        xyz = xyz.view(batch_size, -1, 3)
+        features = features.view(batch_size, -1, features.shape[-1]).permute(0, 2, 1).contiguous() if features is not None else None
+
+        l_xyz, l_features = [xyz], [features]
+        for i in range(len(self.SA_modules)):
+            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
+            l_xyz.append(li_xyz)
+            l_features.append(li_features)
+
+        for i in range(-1, -(len(self.FP_modules) + 1), -1):
+            l_features[i - 1] = self.FP_modules[i](
+                l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
+            )  # (B, C, N)
+
+        point_features = l_features[0].permute(0, 2, 1).contiguous()  # (B, N, C)
+        batch_dict['point_features'] = point_features.view(-1, point_features.shape[-1])
+        batch_dict['point_coords'] = torch.cat((batch_idx[:, None].float(), l_xyz[0].view(-1, 3)), dim=1)
+        return batch_dict
+
+
+class PointNet2Backbone(nn.Module):
+    """
+    DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723
+    """
+    def __init__(self, model_cfg, input_channels, **kwargs):
+        assert False, 'DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723'
+        super().__init__()
+        self.model_cfg = model_cfg
+
+        self.SA_modules = nn.ModuleList()
+        channel_in = input_channels - 3
+
+        self.num_points_each_layer = []
+        skip_channel_list = [input_channels]
+        for k in range(self.model_cfg.SA_CONFIG.NPOINTS.__len__()):
+            self.num_points_each_layer.append(self.model_cfg.SA_CONFIG.NPOINTS[k])
+            mlps = self.model_cfg.SA_CONFIG.MLPS[k].copy()
+            channel_out = 0
+            for idx in range(mlps.__len__()):
+                mlps[idx] = [channel_in] + mlps[idx]
+                channel_out += mlps[idx][-1]
+
+            self.SA_modules.append(
+                pointnet2_modules_stack.StackSAModuleMSG(
+                    radii=self.model_cfg.SA_CONFIG.RADIUS[k],
+                    nsamples=self.model_cfg.SA_CONFIG.NSAMPLE[k],
+                    mlps=mlps,
+                    use_xyz=self.model_cfg.SA_CONFIG.get('USE_XYZ', True),
+                )
+            )
+            skip_channel_list.append(channel_out)
+            channel_in = channel_out
+
+        self.FP_modules = nn.ModuleList()
+
+        for k in range(self.model_cfg.FP_MLPS.__len__()):
+            pre_channel = self.model_cfg.FP_MLPS[k + 1][-1] if k + 1 < len(self.model_cfg.FP_MLPS) else channel_out
+            self.FP_modules.append(
+                pointnet2_modules_stack.StackPointnetFPModule(
+                    mlp=[pre_channel + skip_channel_list[k]] + self.model_cfg.FP_MLPS[k]
+                )
+            )
+
+        self.num_point_features = self.model_cfg.FP_MLPS[0][-1]
+
+    def break_up_pc(self, pc):
+        batch_idx = pc[:, 0]
+        xyz = pc[:, 1:4].contiguous()
+        features = (pc[:, 4:].contiguous() if pc.size(-1) > 4 else None)
+        return batch_idx, xyz, features
+
+    def forward(self, batch_dict):
+        """
+        Args:
+            batch_dict:
+                batch_size: int
+                vfe_features: (num_voxels, C)
+                points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
+        Returns:
+            batch_dict:
+                encoded_spconv_tensor: sparse tensor
+                point_features: (N, C)
+        """
+        batch_size = batch_dict['batch_size']
+        points = batch_dict['points']
+        batch_idx, xyz, features = self.break_up_pc(points)
+
+        xyz_batch_cnt = xyz.new_zeros(batch_size).int()
+        for bs_idx in range(batch_size):
+            xyz_batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
+
+        l_xyz, l_features, l_batch_cnt = [xyz], [features], [xyz_batch_cnt]
+        for i in range(len(self.SA_modules)):
+            new_xyz_list = []
+            for k in range(batch_size):
+                if len(l_xyz) == 1:
+                    cur_xyz = l_xyz[0][batch_idx == k]
+                else:
+                    last_num_points = self.num_points_each_layer[i - 1]
+                    cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points]
+                cur_pt_idxs = pointnet2_utils_stack.farthest_point_sample(
+                    cur_xyz[None, :, :].contiguous(), self.num_points_each_layer[i]
+                ).long()[0]
+                if cur_xyz.shape[0] < self.num_points_each_layer[i]:
+                    empty_num = self.num_points_each_layer[i] - cur_xyz.shape[0]
+                    cur_pt_idxs[-empty_num:] = cur_pt_idxs[:empty_num]
+                new_xyz_list.append(cur_xyz[cur_pt_idxs])
+            new_xyz = torch.cat(new_xyz_list, dim=0)
+
+            new_xyz_batch_cnt = xyz.new_zeros(batch_size).int().fill_(self.num_points_each_layer[i])
+            li_xyz, li_features = self.SA_modules[i](
+                xyz=l_xyz[i], features=l_features[i], xyz_batch_cnt=l_batch_cnt[i],
+                new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
+            )
+
+            l_xyz.append(li_xyz)
+            l_features.append(li_features)
+            l_batch_cnt.append(new_xyz_batch_cnt)
+
+        l_features[0] = points[:, 1:]
+        for i in range(-1, -(len(self.FP_modules) + 1), -1):
+            l_features[i - 1] = self.FP_modules[i](
+                unknown=l_xyz[i - 1], unknown_batch_cnt=l_batch_cnt[i - 1],
+                known=l_xyz[i], known_batch_cnt=l_batch_cnt[i],
+                unknown_feats=l_features[i - 1], known_feats=l_features[i]
+            )
+
+        batch_dict['point_features'] = l_features[0]
+        batch_dict['point_coords'] = torch.cat((batch_idx[:, None].float(), l_xyz[0]), dim=1)
+        return batch_dict
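Usage sketch (not part of the patch above): the snippet below shows how PointNet2MSG is typically configured and called. The NPOINTS/RADIUS/NSAMPLE/MLPS/FP_MLPS values and the 4-channel (x, y, z, intensity) input are illustrative assumptions in the style of an OpenPCDet PointRCNN config, not values defined by this patch, and running it requires the compiled pointnet2 CUDA ops shipped with pcdet.

import torch
from easydict import EasyDict

from pcdet.models.backbones_3d.pointnet2_backbone import PointNet2MSG

# Illustrative multi-scale-grouping config; real values come from the dataset/model yaml.
model_cfg = EasyDict({
    'SA_CONFIG': {
        'NPOINTS': [4096, 1024, 256, 64],
        'RADIUS': [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]],
        'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32]],
        'MLPS': [[[16, 16, 32], [32, 32, 64]],
                 [[64, 64, 128], [64, 96, 128]],
                 [[128, 196, 256], [128, 196, 256]],
                 [[256, 256, 512], [256, 384, 512]]],
    },
    'FP_MLPS': [[128, 128], [256, 256], [512, 512], [512, 512]],
})

# input_channels counts per-point features including xyz, e.g. 4 for KITTI-style (x, y, z, intensity).
backbone = PointNet2MSG(model_cfg=model_cfg, input_channels=4).cuda()

# points: (N, 1 + 4) = [batch_idx, x, y, z, intensity]. Every sample must contribute the same
# number of points, since forward() asserts equal counts and reshapes to (batch_size, -1, 3).
batch_size, num_points = 2, 16384
points = torch.rand(batch_size * num_points, 5).cuda()
points[:, 0] = torch.arange(batch_size).repeat_interleave(num_points).float().cuda()

batch_dict = backbone({'batch_size': batch_size, 'points': points})
# batch_dict['point_features']: (batch_size * num_points, 128)  -- 128 = FP_MLPS[0][-1]
# batch_dict['point_coords']:   (batch_size * num_points, 4)    -- [batch_idx, x, y, z]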