diff --git a/pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py b/pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py new file mode 100644 index 0000000..0f3b8ae --- /dev/null +++ b/pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py @@ -0,0 +1,411 @@ +import math +import numpy as np +import torch +import torch.nn as nn + +from ....ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_stack_modules +from ....ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_stack_utils +from ....utils import common_utils + + +def bilinear_interpolate_torch(im, x, y): + """ + Args: + im: (H, W, C) [y, x] + x: (N) + y: (N) + + Returns: + + """ + x0 = torch.floor(x).long() + x1 = x0 + 1 + + y0 = torch.floor(y).long() + y1 = y0 + 1 + + x0 = torch.clamp(x0, 0, im.shape[1] - 1) + x1 = torch.clamp(x1, 0, im.shape[1] - 1) + y0 = torch.clamp(y0, 0, im.shape[0] - 1) + y1 = torch.clamp(y1, 0, im.shape[0] - 1) + + Ia = im[y0, x0] + Ib = im[y1, x0] + Ic = im[y0, x1] + Id = im[y1, x1] + + wa = (x1.type_as(x) - x) * (y1.type_as(y) - y) + wb = (x1.type_as(x) - x) * (y - y0.type_as(y)) + wc = (x - x0.type_as(x)) * (y1.type_as(y) - y) + wd = (x - x0.type_as(x)) * (y - y0.type_as(y)) + ans = torch.t((torch.t(Ia) * wa)) + torch.t(torch.t(Ib) * wb) + torch.t(torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd) + return ans + + +def sample_points_with_roi(rois, points, sample_radius_with_roi, num_max_points_of_part=200000): + """ + Args: + rois: (M, 7 + C) + points: (N, 3) + sample_radius_with_roi: + num_max_points_of_part: + + Returns: + sampled_points: (N_out, 3) + """ + if points.shape[0] < num_max_points_of_part: + distance = (points[:, None, :] - rois[None, :, 0:3]).norm(dim=-1) + min_dis, min_dis_roi_idx = distance.min(dim=-1) + roi_max_dim = (rois[min_dis_roi_idx, 3:6] / 2).norm(dim=-1) + point_mask = min_dis < roi_max_dim + sample_radius_with_roi + else: + start_idx = 0 + point_mask_list = [] + while start_idx < points.shape[0]: + distance = (points[start_idx:start_idx + num_max_points_of_part, None, :] - rois[None, :, 0:3]).norm(dim=-1) + min_dis, min_dis_roi_idx = distance.min(dim=-1) + roi_max_dim = (rois[min_dis_roi_idx, 3:6] / 2).norm(dim=-1) + cur_point_mask = min_dis < roi_max_dim + sample_radius_with_roi + point_mask_list.append(cur_point_mask) + start_idx += num_max_points_of_part + point_mask = torch.cat(point_mask_list, dim=0) + + sampled_points = points[:1] if point_mask.sum() == 0 else points[point_mask, :] + + return sampled_points, point_mask + + +def sector_fps(points, num_sampled_points, num_sectors): + """ + Args: + points: (N, 3) + num_sampled_points: int + num_sectors: int + + Returns: + sampled_points: (N_out, 3) + """ + sector_size = np.pi * 2 / num_sectors + point_angles = torch.atan2(points[:, 1], points[:, 0]) + np.pi + sector_idx = (point_angles / sector_size).floor().clamp(min=0, max=num_sectors) + xyz_points_list = [] + xyz_batch_cnt = [] + num_sampled_points_list = [] + for k in range(num_sectors): + mask = (sector_idx == k) + cur_num_points = mask.sum().item() + if cur_num_points > 0: + xyz_points_list.append(points[mask]) + xyz_batch_cnt.append(cur_num_points) + ratio = cur_num_points / points.shape[0] + num_sampled_points_list.append( + min(cur_num_points, math.ceil(ratio * num_sampled_points)) + ) + + if len(xyz_batch_cnt) == 0: + xyz_points_list.append(points) + xyz_batch_cnt.append(len(points)) + num_sampled_points_list.append(num_sampled_points) + print(f'Warning: empty sector points detected in SectorFPS: points.shape={points.shape}') + + xyz = torch.cat(xyz_points_list, dim=0) + xyz_batch_cnt = torch.tensor(xyz_batch_cnt, device=points.device).int() + sampled_points_batch_cnt = torch.tensor(num_sampled_points_list, device=points.device).int() + + sampled_pt_idxs = pointnet2_stack_utils.stack_farthest_point_sample( + xyz.contiguous(), xyz_batch_cnt, sampled_points_batch_cnt + ).long() + + sampled_points = xyz[sampled_pt_idxs] + + return sampled_points + + +class VoxelSetAbstraction(nn.Module): + def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=None, + num_rawpoint_features=None, **kwargs): + super().__init__() + self.model_cfg = model_cfg + self.voxel_size = voxel_size + self.point_cloud_range = point_cloud_range + + SA_cfg = self.model_cfg.SA_LAYER + + self.SA_layers = nn.ModuleList() + self.SA_layer_names = [] + self.downsample_times_map = {} + c_in = 0 + for src_name in self.model_cfg.FEATURES_SOURCE: + if src_name in ['bev', 'raw_points']: + continue + self.downsample_times_map[src_name] = SA_cfg[src_name].DOWNSAMPLE_FACTOR + + if SA_cfg[src_name].get('INPUT_CHANNELS', None) is None: + input_channels = SA_cfg[src_name].MLPS[0][0] \ + if isinstance(SA_cfg[src_name].MLPS[0], list) else SA_cfg[src_name].MLPS[0] + else: + input_channels = SA_cfg[src_name]['INPUT_CHANNELS'] + + cur_layer, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module( + input_channels=input_channels, config=SA_cfg[src_name] + ) + self.SA_layers.append(cur_layer) + self.SA_layer_names.append(src_name) + + c_in += cur_num_c_out + + if 'bev' in self.model_cfg.FEATURES_SOURCE: + c_bev = num_bev_features + c_in += c_bev + + if 'raw_points' in self.model_cfg.FEATURES_SOURCE: + self.SA_rawpoints, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module( + input_channels=num_rawpoint_features - 3, config=SA_cfg['raw_points'] + ) + + c_in += cur_num_c_out + + self.vsa_point_feature_fusion = nn.Sequential( + nn.Linear(c_in, self.model_cfg.NUM_OUTPUT_FEATURES, bias=False), + nn.BatchNorm1d(self.model_cfg.NUM_OUTPUT_FEATURES), + nn.ReLU(), + ) + self.num_point_features = self.model_cfg.NUM_OUTPUT_FEATURES + self.num_point_features_before_fusion = c_in + + def interpolate_from_bev_features(self, keypoints, bev_features, batch_size, bev_stride): + """ + Args: + keypoints: (N1 + N2 + ..., 4) + bev_features: (B, C, H, W) + batch_size: + bev_stride: + + Returns: + point_bev_features: (N1 + N2 + ..., C) + """ + x_idxs = (keypoints[:, 1] - self.point_cloud_range[0]) / self.voxel_size[0] + y_idxs = (keypoints[:, 2] - self.point_cloud_range[1]) / self.voxel_size[1] + + x_idxs = x_idxs / bev_stride + y_idxs = y_idxs / bev_stride + + point_bev_features_list = [] + for k in range(batch_size): + bs_mask = (keypoints[:, 0] == k) + + cur_x_idxs = x_idxs[bs_mask] + cur_y_idxs = y_idxs[bs_mask] + cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C) + point_bev_features = bilinear_interpolate_torch(cur_bev_features, cur_x_idxs, cur_y_idxs) + point_bev_features_list.append(point_bev_features) + + point_bev_features = torch.cat(point_bev_features_list, dim=0) # (N1 + N2 + ..., C) + return point_bev_features + + def sectorized_proposal_centric_sampling(self, roi_boxes, points): + """ + Args: + roi_boxes: (M, 7 + C) + points: (N, 3) + + Returns: + sampled_points: (N_out, 3) + """ + + sampled_points, _ = sample_points_with_roi( + rois=roi_boxes, points=points, + sample_radius_with_roi=self.model_cfg.SPC_SAMPLING.SAMPLE_RADIUS_WITH_ROI, + num_max_points_of_part=self.model_cfg.SPC_SAMPLING.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000) + ) + sampled_points = sector_fps( + points=sampled_points, num_sampled_points=self.model_cfg.NUM_KEYPOINTS, + num_sectors=self.model_cfg.SPC_SAMPLING.NUM_SECTORS + ) + return sampled_points + + def get_sampled_points(self, batch_dict): + """ + Args: + batch_dict: + + Returns: + keypoints: (N1 + N2 + ..., 4), where 4 indicates [bs_idx, x, y, z] + """ + batch_size = batch_dict['batch_size'] + if self.model_cfg.POINT_SOURCE == 'raw_points': + src_points = batch_dict['points'][:, 1:4] + batch_indices = batch_dict['points'][:, 0].long() + elif self.model_cfg.POINT_SOURCE == 'voxel_centers': + src_points = common_utils.get_voxel_centers( + batch_dict['voxel_coords'][:, 1:4], + downsample_times=1, + voxel_size=self.voxel_size, + point_cloud_range=self.point_cloud_range + ) + batch_indices = batch_dict['voxel_coords'][:, 0].long() + else: + raise NotImplementedError + keypoints_list = [] + for bs_idx in range(batch_size): + bs_mask = (batch_indices == bs_idx) + sampled_points = src_points[bs_mask].unsqueeze(dim=0) # (1, N, 3) + if self.model_cfg.SAMPLE_METHOD == 'FPS': + cur_pt_idxs = pointnet2_stack_utils.farthest_point_sample( + sampled_points[:, :, 0:3].contiguous(), self.model_cfg.NUM_KEYPOINTS + ).long() + + if sampled_points.shape[1] < self.model_cfg.NUM_KEYPOINTS: + times = int(self.model_cfg.NUM_KEYPOINTS / sampled_points.shape[1]) + 1 + non_empty = cur_pt_idxs[0, :sampled_points.shape[1]] + cur_pt_idxs[0] = non_empty.repeat(times)[:self.model_cfg.NUM_KEYPOINTS] + + keypoints = sampled_points[0][cur_pt_idxs[0]].unsqueeze(dim=0) + + elif self.model_cfg.SAMPLE_METHOD == 'SPC': + cur_keypoints = self.sectorized_proposal_centric_sampling( + roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points[0] + ) + bs_idxs = cur_keypoints.new_ones(cur_keypoints.shape[0]) * bs_idx + keypoints = torch.cat((bs_idxs[:, None], cur_keypoints), dim=1) + else: + raise NotImplementedError + + keypoints_list.append(keypoints) + + keypoints = torch.cat(keypoints_list, dim=0) # (B, M, 3) or (N1 + N2 + ..., 4) + if len(keypoints.shape) == 3: + batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1, 1) + keypoints = torch.cat((batch_idx.float(), keypoints.view(-1, 3)), dim=1) + + return keypoints + + @staticmethod + def aggregate_keypoint_features_from_one_source( + batch_size, aggregate_func, xyz, xyz_features, xyz_bs_idxs, new_xyz, new_xyz_batch_cnt, + filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=200000, rois=None + ): + """ + + Args: + aggregate_func: + xyz: (N, 3) + xyz_features: (N, C) + xyz_bs_idxs: (N) + new_xyz: (M, 3) + new_xyz_batch_cnt: (batch_size), [N1, N2, ...] + + filter_neighbors_with_roi: True/False + radius_of_neighbor: float + num_max_points_of_part: int + rois: (batch_size, num_rois, 7 + C) + Returns: + + """ + xyz_batch_cnt = xyz.new_zeros(batch_size).int() + if filter_neighbors_with_roi: + point_features = torch.cat((xyz, xyz_features), dim=-1) if xyz_features is not None else xyz + point_features_list = [] + for bs_idx in range(batch_size): + bs_mask = (xyz_bs_idxs == bs_idx) + _, valid_mask = sample_points_with_roi( + rois=rois[bs_idx], points=xyz[bs_mask], + sample_radius_with_roi=radius_of_neighbor, num_max_points_of_part=num_max_points_of_part, + ) + point_features_list.append(point_features[bs_mask][valid_mask]) + xyz_batch_cnt[bs_idx] = valid_mask.sum() + + valid_point_features = torch.cat(point_features_list, dim=0) + xyz = valid_point_features[:, 0:3] + xyz_features = valid_point_features[:, 3:] if xyz_features is not None else None + else: + for bs_idx in range(batch_size): + xyz_batch_cnt[bs_idx] = (xyz_bs_idxs == bs_idx).sum() + + pooled_points, pooled_features = aggregate_func( + xyz=xyz.contiguous(), + xyz_batch_cnt=xyz_batch_cnt, + new_xyz=new_xyz, + new_xyz_batch_cnt=new_xyz_batch_cnt, + features=xyz_features.contiguous(), + ) + return pooled_features + + def forward(self, batch_dict): + """ + Args: + batch_dict: + batch_size: + keypoints: (B, num_keypoints, 3) + multi_scale_3d_features: { + 'x_conv4': ... + } + points: optional (N, 1 + 3 + C) [bs_idx, x, y, z, ...] + spatial_features: optional + spatial_features_stride: optional + + Returns: + point_features: (N, C) + point_coords: (N, 4) + + """ + keypoints = self.get_sampled_points(batch_dict) + + point_features_list = [] + if 'bev' in self.model_cfg.FEATURES_SOURCE: + point_bev_features = self.interpolate_from_bev_features( + keypoints, batch_dict['spatial_features'], batch_dict['batch_size'], + bev_stride=batch_dict['spatial_features_stride'] + ) + point_features_list.append(point_bev_features) + + batch_size = batch_dict['batch_size'] + + new_xyz = keypoints[:, 1:4].contiguous() + new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int() + for k in range(batch_size): + new_xyz_batch_cnt[k] = (keypoints[:, 0] == k).sum() + + if 'raw_points' in self.model_cfg.FEATURES_SOURCE: + raw_points = batch_dict['points'] + + pooled_features = self.aggregate_keypoint_features_from_one_source( + batch_size=batch_size, aggregate_func=self.SA_rawpoints, + xyz=raw_points[:, 1:4], + xyz_features=raw_points[:, 4:].contiguous() if raw_points.shape[1] > 4 else None, + xyz_bs_idxs=raw_points[:, 0], + new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt, + filter_neighbors_with_roi=self.model_cfg.SA_LAYER['raw_points'].get('FILTER_NEIGHBOR_WITH_ROI', False), + radius_of_neighbor=self.model_cfg.SA_LAYER['raw_points'].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None), + rois=batch_dict.get('rois', None) + ) + point_features_list.append(pooled_features) + + for k, src_name in enumerate(self.SA_layer_names): + cur_coords = batch_dict['multi_scale_3d_features'][src_name].indices + cur_features = batch_dict['multi_scale_3d_features'][src_name].features.contiguous() + + xyz = common_utils.get_voxel_centers( + cur_coords[:, 1:4], downsample_times=self.downsample_times_map[src_name], + voxel_size=self.voxel_size, point_cloud_range=self.point_cloud_range + ) + + pooled_features = self.aggregate_keypoint_features_from_one_source( + batch_size=batch_size, aggregate_func=self.SA_layers[k], + xyz=xyz.contiguous(), xyz_features=cur_features, xyz_bs_idxs=cur_coords[:, 0], + new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt, + filter_neighbors_with_roi=self.model_cfg.SA_LAYER[src_name].get('FILTER_NEIGHBOR_WITH_ROI', False), + radius_of_neighbor=self.model_cfg.SA_LAYER[src_name].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None), + rois=batch_dict.get('rois', None) + ) + + point_features_list.append(pooled_features) + + point_features = torch.cat(point_features_list, dim=-1) + + batch_dict['point_features_before_fusion'] = point_features.view(-1, point_features.shape[-1]) + point_features = self.vsa_point_feature_fusion(point_features.view(-1, point_features.shape[-1])) + + batch_dict['point_features'] = point_features # (BxN, C) + batch_dict['point_coords'] = keypoints # (BxN, 4) + return batch_dict