import torch
import torch.nn as nn

from ...ops.pointnet2.pointnet2_batch import pointnet2_modules
from ...ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_modules_stack
from ...ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_utils_stack


def _pad_sampled_indices(sampled_idxs, num_valid):
    """Fill the tail of a 1-D sampled-index tensor by cycling its valid prefix.

    Used when a sample holds fewer points than the number requested from the
    sampler.

    NOTE(review): this assumes the first ``num_valid`` entries returned by the
    sampler are the usable ones and the remainder are filler; that matches the
    intent of the padding code this replaces, but confirm against the sampler.

    Args:
        sampled_idxs: 1-D integer tensor of sampled point indices. Modified
            in place.
        num_valid: number of points that were available to sample from.

    Returns:
        ``sampled_idxs``, with positions ``num_valid:`` overwritten by the
        leading ``num_valid`` entries repeated cyclically. Returned untouched
        when no padding is needed or when there is nothing valid to copy.
    """
    num_samples = sampled_idxs.shape[0]
    if num_valid <= 0 or num_valid >= num_samples:
        return sampled_idxs

    num_missing = num_samples - num_valid
    # Advanced indexing yields a copy, so source and destination never alias
    # even when more than half of the entries have to be padded.
    cycle = torch.arange(num_missing, device=sampled_idxs.device) % num_valid
    sampled_idxs[num_valid:] = sampled_idxs[:num_valid][cycle]
    return sampled_idxs


class PointNet2MSG(nn.Module):
    """Set-abstraction / feature-propagation backbone built on the batched ops.

    Builds one ``PointnetSAModuleMSG`` per entry of ``SA_CONFIG.NPOINTS`` and
    one ``PointnetFPModule`` per entry of ``FP_MLPS``. ``forward`` requires
    every sample in the batch to contain the same number of points.
    """

    def __init__(self, model_cfg, input_channels, **kwargs):
        """
        Args:
            model_cfg: config object exposing ``SA_CONFIG`` (``NPOINTS``,
                ``RADIUS``, ``NSAMPLE``, ``MLPS`` and optional ``USE_XYZ``)
                and ``FP_MLPS``.
            input_channels: number of per-point channels including the three
                xyz coordinates.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.model_cfg = model_cfg
        sa_cfg = self.model_cfg.SA_CONFIG

        self.SA_modules = nn.ModuleList()
        # xyz is passed to the modules separately, so it is not a feature channel.
        channel_in = input_channels - 3

        self.num_points_each_layer = []
        skip_channel_list = [input_channels - 3]
        for k in range(len(sa_cfg.NPOINTS)):
            # Build new lists so the configuration itself is never mutated.
            mlps = [[channel_in] + mlp for mlp in sa_cfg.MLPS[k]]
            # The per-scale last-layer widths are summed to get this layer's
            # output width.
            channel_out = sum(mlp[-1] for mlp in mlps)

            self.SA_modules.append(
                pointnet2_modules.PointnetSAModuleMSG(
                    npoint=sa_cfg.NPOINTS[k],
                    radii=sa_cfg.RADIUS[k],
                    nsamples=sa_cfg.NSAMPLE[k],
                    mlps=mlps,
                    use_xyz=sa_cfg.get('USE_XYZ', True),
                )
            )
            skip_channel_list.append(channel_out)
            channel_in = channel_out

        self.FP_modules = nn.ModuleList()
        fp_mlps = self.model_cfg.FP_MLPS
        for k in range(len(fp_mlps)):
            # Input to FP module k: output of FP module k + 1 (or of the last
            # SA layer for the deepest one) plus the skip features of level k.
            pre_channel = fp_mlps[k + 1][-1] if k + 1 < len(fp_mlps) else channel_out
            self.FP_modules.append(
                pointnet2_modules.PointnetFPModule(
                    mlp=[pre_channel + skip_channel_list[k]] + fp_mlps[k]
                )
            )

        self.num_point_features = fp_mlps[0][-1]

    def break_up_pc(self, pc):
        """Split a packed point tensor into batch index, xyz and features.

        Args:
            pc: (num_points, 4 + C) tensor laid out as [batch_idx, x, y, z, ...].

        Returns:
            Tuple ``(batch_idx, xyz, features)``; ``features`` is ``None``
            when ``pc`` has no columns beyond xyz.
        """
        batch_idx = pc[:, 0]
        xyz = pc[:, 1:4].contiguous()
        features = pc[:, 4:].contiguous() if pc.size(-1) > 4 else None
        return batch_idx, xyz, features

    def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                batch_size: int
                points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
        Returns:
            batch_dict:
                point_features: (N, C)
                point_coords: (N, 4), [batch_idx, x, y, z]
        """
        batch_size = batch_dict['batch_size']
        points = batch_dict['points']
        batch_idx, xyz, features = self.break_up_pc(points)

        xyz_batch_cnt = xyz.new_zeros(batch_size).int()
        for bs_idx in range(batch_size):
            xyz_batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()

        # The reshapes below are only valid for equally sized samples.
        assert xyz_batch_cnt.min() == xyz_batch_cnt.max(), \
            'PointNet2MSG requires the same number of points in every sample'
        xyz = xyz.view(batch_size, -1, 3)
        if features is not None:
            # (B * N, C) -> (B, C, N)
            features = features.view(batch_size, -1, features.shape[-1]).permute(0, 2, 1).contiguous()

        l_xyz, l_features = [xyz], [features]
        for i, sa_module in enumerate(self.SA_modules):
            li_xyz, li_features = sa_module(l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        # Walk back from the deepest level, overwriting each level's features
        # with the propagated ones.
        for i in range(-1, -(len(self.FP_modules) + 1), -1):
            l_features[i - 1] = self.FP_modules[i](
                l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
            )  # (B, C, N)

        point_features = l_features[0].permute(0, 2, 1).contiguous()  # (B, N, C)
        batch_dict['point_features'] = point_features.view(-1, point_features.shape[-1])
        batch_dict['point_coords'] = torch.cat((batch_idx[:, None].float(), l_xyz[0].view(-1, 3)), dim=1)
        return batch_dict


class PointNet2Backbone(nn.Module):
    """
    DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723

    Variant of the backbone built on the stacked ops
    (``StackSAModuleMSG`` / ``StackPointnetFPModule``), which take per-sample
    point counts instead of a dense batch dimension.
    """

    def __init__(self, model_cfg, input_channels, **kwargs):
        """
        Args:
            model_cfg: config object exposing ``SA_CONFIG`` (``NPOINTS``,
                ``RADIUS``, ``NSAMPLE``, ``MLPS`` and optional ``USE_XYZ``)
                and ``FP_MLPS``.
            input_channels: number of per-point channels including the three
                xyz coordinates.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: always; the class is deliberately disabled.
        """
        # Deliberate guard left in place: the class stays disabled until it
        # has been validated.
        assert False, 'DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723'
        super().__init__()
        self.model_cfg = model_cfg
        sa_cfg = self.model_cfg.SA_CONFIG

        self.SA_modules = nn.ModuleList()
        channel_in = input_channels - 3

        self.num_points_each_layer = []
        # forward() feeds xyz + features as the level-0 skip features, hence
        # the full input width here.
        skip_channel_list = [input_channels]
        for k in range(len(sa_cfg.NPOINTS)):
            self.num_points_each_layer.append(sa_cfg.NPOINTS[k])
            # Build new lists so the configuration itself is never mutated.
            mlps = [[channel_in] + mlp for mlp in sa_cfg.MLPS[k]]
            channel_out = sum(mlp[-1] for mlp in mlps)

            self.SA_modules.append(
                pointnet2_modules_stack.StackSAModuleMSG(
                    radii=sa_cfg.RADIUS[k],
                    nsamples=sa_cfg.NSAMPLE[k],
                    mlps=mlps,
                    use_xyz=sa_cfg.get('USE_XYZ', True),
                )
            )
            skip_channel_list.append(channel_out)
            channel_in = channel_out

        self.FP_modules = nn.ModuleList()
        fp_mlps = self.model_cfg.FP_MLPS
        for k in range(len(fp_mlps)):
            pre_channel = fp_mlps[k + 1][-1] if k + 1 < len(fp_mlps) else channel_out
            self.FP_modules.append(
                pointnet2_modules_stack.StackPointnetFPModule(
                    mlp=[pre_channel + skip_channel_list[k]] + fp_mlps[k]
                )
            )

        self.num_point_features = fp_mlps[0][-1]

    def break_up_pc(self, pc):
        """Split a packed point tensor into batch index, xyz and features.

        Args:
            pc: (num_points, 4 + C) tensor laid out as [batch_idx, x, y, z, ...].

        Returns:
            Tuple ``(batch_idx, xyz, features)``; ``features`` is ``None``
            when ``pc`` has no columns beyond xyz.
        """
        batch_idx = pc[:, 0]
        xyz = pc[:, 1:4].contiguous()
        features = pc[:, 4:].contiguous() if pc.size(-1) > 4 else None
        return batch_idx, xyz, features

    def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                batch_size: int
                points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
        Returns:
            batch_dict:
                point_features: (N, C)
                point_coords: (N, 4), [batch_idx, x, y, z]
        """
        batch_size = batch_dict['batch_size']
        points = batch_dict['points']
        batch_idx, xyz, features = self.break_up_pc(points)

        xyz_batch_cnt = xyz.new_zeros(batch_size).int()
        for bs_idx in range(batch_size):
            xyz_batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()

        l_xyz, l_features, l_batch_cnt = [xyz], [features], [xyz_batch_cnt]
        for i in range(len(self.SA_modules)):
            num_samples = self.num_points_each_layer[i]
            new_xyz_list = []
            for k in range(batch_size):
                if len(l_xyz) == 1:
                    # First level: samples may differ in size, select by mask.
                    cur_xyz = l_xyz[0][batch_idx == k]
                else:
                    # Deeper levels: every sample holds exactly the previous
                    # layer's number of points, stored contiguously.
                    last_num_points = self.num_points_each_layer[i - 1]
                    cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points]
                cur_pt_idxs = pointnet2_utils_stack.farthest_point_sample(
                    cur_xyz[None, :, :].contiguous(), num_samples
                ).long()[0]
                # cur_pt_idxs is 1-D here and the point count is axis 0 of
                # cur_xyz (axis 1 is the xyz coordinate axis).
                cur_pt_idxs = _pad_sampled_indices(cur_pt_idxs, cur_xyz.shape[0])
                new_xyz_list.append(cur_xyz[cur_pt_idxs])
            new_xyz = torch.cat(new_xyz_list, dim=0)

            new_xyz_batch_cnt = xyz.new_zeros(batch_size).int().fill_(num_samples)
            li_xyz, li_features = self.SA_modules[i](
                xyz=l_xyz[i], features=l_features[i], xyz_batch_cnt=l_batch_cnt[i],
                new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
            )

            l_xyz.append(li_xyz)
            l_features.append(li_features)
            l_batch_cnt.append(new_xyz_batch_cnt)

        # The level-0 skip features are xyz together with the raw features.
        l_features[0] = points[:, 1:]
        for i in range(-1, -(len(self.FP_modules) + 1), -1):
            l_features[i - 1] = self.FP_modules[i](
                unknown=l_xyz[i - 1], unknown_batch_cnt=l_batch_cnt[i - 1],
                known=l_xyz[i], known_batch_cnt=l_batch_cnt[i],
                unknown_feats=l_features[i - 1], known_feats=l_features[i]
            )

        batch_dict['point_features'] = l_features[0]
        batch_dict['point_coords'] = torch.cat((batch_idx[:, None].float(), l_xyz[0]), dim=1)
        return batch_dict