import math

import numpy as np
import torch
import torch.nn as nn

from ....ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_stack_modules
from ....ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_stack_utils
from ....utils import common_utils


def bilinear_interpolate_torch(im, x, y):
    """
    Bilinearly interpolate `im` at fractional (x, y) locations.

    The four neighbouring integer corners are clamped into the image bounds
    *before* the weights are computed, so the weights are derived from the
    clamped corner indices.

    Args:
        im: (H, W, C) [y, x]
        x: (N)
        y: (N)

    Returns:
        ans: (N, C) weighted sum of the four corner values
    """
    x0 = torch.floor(x).long()
    x1 = x0 + 1

    y0 = torch.floor(y).long()
    y1 = y0 + 1

    x0 = torch.clamp(x0, 0, im.shape[1] - 1)
    x1 = torch.clamp(x1, 0, im.shape[1] - 1)
    y0 = torch.clamp(y0, 0, im.shape[0] - 1)
    y1 = torch.clamp(y1, 0, im.shape[0] - 1)

    Ia = im[y0, x0]
    Ib = im[y1, x0]
    Ic = im[y0, x1]
    Id = im[y1, x1]

    wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
    wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
    wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
    wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
    # Transpose so the per-point weight broadcasts over the channel axis.
    ans = torch.t((torch.t(Ia) * wa)) + torch.t(torch.t(Ib) * wb) + torch.t(torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd)
    return ans


def _roi_proximity_mask(rois, points, sample_radius_with_roi):
    """
    Flag the points that lie close enough to their nearest RoI center.

    A point is kept when its distance to the nearest RoI center is smaller
    than that RoI's half-diagonal (norm of dims / 2) plus `sample_radius_with_roi`.

    Args:
        rois: (M, 7 + C), columns 0:3 are read as the center and 3:6 as the dims
        points: (N, 3)
        sample_radius_with_roi: extra margin added to the RoI half-diagonal

    Returns:
        point_mask: (N) bool
    """
    distance = (points[:, None, :] - rois[None, :, 0:3]).norm(dim=-1)
    min_dis, min_dis_roi_idx = distance.min(dim=-1)
    roi_max_dim = (rois[min_dis_roi_idx, 3:6] / 2).norm(dim=-1)
    return min_dis < roi_max_dim + sample_radius_with_roi


def sample_points_with_roi(rois, points, sample_radius_with_roi, num_max_points_of_part=200000):
    """
    Keep only the points that are near at least one RoI.

    Args:
        rois: (M, 7 + C)
        points: (N, 3)
        sample_radius_with_roi: margin added to each RoI's half-diagonal
        num_max_points_of_part: points are processed in chunks of this size once
            N reaches it, which bounds the size of the (chunk, M) distance matrix

    Returns:
        sampled_points: (N_out, 3); when no point passes the test, the first
            point is returned as a placeholder so the result is never empty
        point_mask: (N) bool mask of the points that passed the test (all False
            in the placeholder case)
    """
    if points.shape[0] < num_max_points_of_part:
        point_mask = _roi_proximity_mask(rois, points, sample_radius_with_roi)
    else:
        start_idx = 0
        point_mask_list = []
        while start_idx < points.shape[0]:
            cur_points = points[start_idx:start_idx + num_max_points_of_part]
            point_mask_list.append(_roi_proximity_mask(rois, cur_points, sample_radius_with_roi))
            start_idx += num_max_points_of_part
        point_mask = torch.cat(point_mask_list, dim=0)

    sampled_points = points[:1] if point_mask.sum() == 0 else points[point_mask, :]

    return sampled_points, point_mask


def sector_fps(points, num_sampled_points, num_sectors):
    """
    Farthest point sampling performed independently inside angular sectors.

    Points are binned by their azimuth (atan2(y, x)) into `num_sectors` equal
    sectors; each non-empty sector receives a share of `num_sampled_points`
    proportional to its point count (rounded up, capped at the sector size).

    Args:
        points: (N, 3)
        num_sampled_points: int
        num_sectors: int

    Returns:
        sampled_points: (N_out, 3)
    """
    sector_size = np.pi * 2 / num_sectors
    point_angles = torch.atan2(points[:, 1], points[:, 0]) + np.pi
    # The shifted angle can reach 2*pi (atan2 returns +pi for y == 0, x < 0), which
    # maps to index `num_sectors`. Clamp to the last valid sector so that those
    # points are sampled instead of being silently dropped by the loop below.
    sector_idx = (point_angles / sector_size).floor().clamp(min=0, max=num_sectors - 1)
    xyz_points_list = []
    xyz_batch_cnt = []
    num_sampled_points_list = []
    for k in range(num_sectors):
        mask = (sector_idx == k)
        cur_num_points = mask.sum().item()
        if cur_num_points > 0:
            xyz_points_list.append(points[mask])
            xyz_batch_cnt.append(cur_num_points)
            ratio = cur_num_points / points.shape[0]
            num_sampled_points_list.append(
                min(cur_num_points, math.ceil(ratio * num_sampled_points))
            )

    if len(xyz_batch_cnt) == 0:
        # No sector received any point: fall back to treating all points as one group.
        xyz_points_list.append(points)
        xyz_batch_cnt.append(len(points))
        num_sampled_points_list.append(num_sampled_points)
        print(f'Warning: empty sector points detected in SectorFPS: points.shape={points.shape}')

    xyz = torch.cat(xyz_points_list, dim=0)
    xyz_batch_cnt = torch.tensor(xyz_batch_cnt, device=points.device).int()
    sampled_points_batch_cnt = torch.tensor(num_sampled_points_list, device=points.device).int()

    # Each sector is passed as one "batch" of the stacked FPS op.
    sampled_pt_idxs = pointnet2_stack_utils.stack_farthest_point_sample(
        xyz.contiguous(), xyz_batch_cnt, sampled_points_batch_cnt
    ).long()

    sampled_points = xyz[sampled_pt_idxs]

    return sampled_points


class VoxelSetAbstraction(nn.Module):
    """
    Samples keypoints from the scene and gathers features for them from the
    configured sources (BEV map, raw points, multi-scale 3D voxel features),
    then fuses the concatenated features with a Linear-BatchNorm-ReLU layer.
    """

    def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=None,
                 num_rawpoint_features=None, **kwargs):
        """
        Args:
            model_cfg: config object; reads SA_LAYER, FEATURES_SOURCE and NUM_OUTPUT_FEATURES here
            voxel_size: per-axis voxel size, indexed [0] and [1] for x / y
            point_cloud_range: range of the point cloud, indexed [0] and [1] for x_min / y_min
            num_bev_features: channel count of the BEV map (needed when 'bev' is a feature source)
            num_rawpoint_features: channel count of the raw points including xyz
                (needed when 'raw_points' is a feature source)
        """
        super().__init__()
        self.model_cfg = model_cfg
        self.voxel_size = voxel_size
        self.point_cloud_range = point_cloud_range

        SA_cfg = self.model_cfg.SA_LAYER

        self.SA_layers = nn.ModuleList()
        self.SA_layer_names = []
        self.downsample_times_map = {}
        c_in = 0  # running total of channels fed into the fusion layer
        for src_name in self.model_cfg.FEATURES_SOURCE:
            # 'bev' and 'raw_points' are handled separately below.
            if src_name in ['bev', 'raw_points']:
                continue
            self.downsample_times_map[src_name] = SA_cfg[src_name].DOWNSAMPLE_FACTOR

            if SA_cfg[src_name].get('INPUT_CHANNELS', None) is None:
                # Fall back to the first entry of MLPS (first element of the first
                # list when MLPS is a list of lists).
                input_channels = SA_cfg[src_name].MLPS[0][0] \
                    if isinstance(SA_cfg[src_name].MLPS[0], list) else SA_cfg[src_name].MLPS[0]
            else:
                input_channels = SA_cfg[src_name]['INPUT_CHANNELS']

            cur_layer, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
                input_channels=input_channels, config=SA_cfg[src_name]
            )
            self.SA_layers.append(cur_layer)
            self.SA_layer_names.append(src_name)

            c_in += cur_num_c_out

        if 'bev' in self.model_cfg.FEATURES_SOURCE:
            c_in += num_bev_features

        if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
            # The xyz columns are passed as coordinates, so only the remaining
            # columns count as input feature channels.
            self.SA_rawpoints, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
                input_channels=num_rawpoint_features - 3, config=SA_cfg['raw_points']
            )

            c_in += cur_num_c_out

        self.vsa_point_feature_fusion = nn.Sequential(
            nn.Linear(c_in, self.model_cfg.NUM_OUTPUT_FEATURES, bias=False),
            nn.BatchNorm1d(self.model_cfg.NUM_OUTPUT_FEATURES),
            nn.ReLU(),
        )
        self.num_point_features = self.model_cfg.NUM_OUTPUT_FEATURES
        self.num_point_features_before_fusion = c_in

    def interpolate_from_bev_features(self, keypoints, bev_features, batch_size, bev_stride):
        """
        Bilinearly sample the BEV feature map at each keypoint's (x, y) location.

        Args:
            keypoints: (N1 + N2 + ..., 4), columns are [bs_idx, x, y, z]
            bev_features: (B, C, H, W)
            batch_size:
            bev_stride: downsample factor of the BEV map relative to the voxel grid

        Returns:
            point_bev_features: (N1 + N2 + ..., C)
        """
        # Metric coordinates -> voxel-grid coordinates -> BEV-map coordinates.
        x_idxs = (keypoints[:, 1] - self.point_cloud_range[0]) / self.voxel_size[0]
        y_idxs = (keypoints[:, 2] - self.point_cloud_range[1]) / self.voxel_size[1]

        x_idxs = x_idxs / bev_stride
        y_idxs = y_idxs / bev_stride

        point_bev_features_list = []
        for k in range(batch_size):
            bs_mask = (keypoints[:, 0] == k)

            cur_x_idxs = x_idxs[bs_mask]
            cur_y_idxs = y_idxs[bs_mask]
            cur_bev_features = bev_features[k].permute(1, 2, 0)  # (H, W, C)
            point_bev_features = bilinear_interpolate_torch(cur_bev_features, cur_x_idxs, cur_y_idxs)
            point_bev_features_list.append(point_bev_features)

        point_bev_features = torch.cat(point_bev_features_list, dim=0)  # (N1 + N2 + ..., C)
        return point_bev_features

    def sectorized_proposal_centric_sampling(self, roi_boxes, points):
        """
        Keep the points near the RoIs, then run sector-wise FPS on them.

        Args:
            roi_boxes: (M, 7 + C)
            points: (N, 3)

        Returns:
            sampled_points: (N_out, 3)
        """
        sampled_points, _ = sample_points_with_roi(
            rois=roi_boxes, points=points,
            sample_radius_with_roi=self.model_cfg.SPC_SAMPLING.SAMPLE_RADIUS_WITH_ROI,
            num_max_points_of_part=self.model_cfg.SPC_SAMPLING.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000)
        )
        sampled_points = sector_fps(
            points=sampled_points, num_sampled_points=self.model_cfg.NUM_KEYPOINTS,
            num_sectors=self.model_cfg.SPC_SAMPLING.NUM_SECTORS
        )
        return sampled_points

    def get_sampled_points(self, batch_dict):
        """
        Sample keypoints for every sample in the batch.

        The source points come from the raw points or from the voxel centers
        (POINT_SOURCE); the sampler is plain FPS or sectorized proposal-centric
        sampling (SAMPLE_METHOD).

        Args:
            batch_dict: reads 'batch_size', plus 'points' or 'voxel_coords'
                depending on POINT_SOURCE, and 'rois' when SAMPLE_METHOD is 'SPC'

        Returns:
            keypoints: (N1 + N2 + ..., 4), where 4 indicates [bs_idx, x, y, z]

        Raises:
            NotImplementedError: for an unknown POINT_SOURCE or SAMPLE_METHOD
        """
        batch_size = batch_dict['batch_size']
        if self.model_cfg.POINT_SOURCE == 'raw_points':
            src_points = batch_dict['points'][:, 1:4]
            batch_indices = batch_dict['points'][:, 0].long()
        elif self.model_cfg.POINT_SOURCE == 'voxel_centers':
            src_points = common_utils.get_voxel_centers(
                batch_dict['voxel_coords'][:, 1:4],
                downsample_times=1,
                voxel_size=self.voxel_size,
                point_cloud_range=self.point_cloud_range
            )
            batch_indices = batch_dict['voxel_coords'][:, 0].long()
        else:
            raise NotImplementedError
        keypoints_list = []
        for bs_idx in range(batch_size):
            bs_mask = (batch_indices == bs_idx)
            sampled_points = src_points[bs_mask].unsqueeze(dim=0)  # (1, N, 3)
            if self.model_cfg.SAMPLE_METHOD == 'FPS':
                cur_pt_idxs = pointnet2_stack_utils.farthest_point_sample(
                    sampled_points[:, :, 0:3].contiguous(), self.model_cfg.NUM_KEYPOINTS
                ).long()

                if sampled_points.shape[1] < self.model_cfg.NUM_KEYPOINTS:
                    # Fewer source points than requested keypoints: tile the valid
                    # indices until NUM_KEYPOINTS entries are filled.
                    times = int(self.model_cfg.NUM_KEYPOINTS / sampled_points.shape[1]) + 1
                    non_empty = cur_pt_idxs[0, :sampled_points.shape[1]]
                    cur_pt_idxs[0] = non_empty.repeat(times)[:self.model_cfg.NUM_KEYPOINTS]

                keypoints = sampled_points[0][cur_pt_idxs[0]].unsqueeze(dim=0)

            elif self.model_cfg.SAMPLE_METHOD == 'SPC':
                cur_keypoints = self.sectorized_proposal_centric_sampling(
                    roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points[0]
                )
                bs_idxs = cur_keypoints.new_ones(cur_keypoints.shape[0]) * bs_idx
                keypoints = torch.cat((bs_idxs[:, None], cur_keypoints), dim=1)
            else:
                raise NotImplementedError

            keypoints_list.append(keypoints)

        keypoints = torch.cat(keypoints_list, dim=0)  # (B, M, 3) or (N1 + N2 + ..., 4)
        if len(keypoints.shape) == 3:
            # FPS output has no batch-index column yet: flatten and prepend it.
            batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1, 1)
            keypoints = torch.cat((batch_idx.float(), keypoints.view(-1, 3)), dim=1)

        return keypoints

    @staticmethod
    def aggregate_keypoint_features_from_one_source(
            batch_size, aggregate_func, xyz, xyz_features, xyz_bs_idxs, new_xyz, new_xyz_batch_cnt,
            filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=200000, rois=None
    ):
        """
        Pool features from one source point set onto the keypoints.

        Args:
            batch_size:
            aggregate_func: callable invoked with keyword arguments xyz, xyz_batch_cnt,
                new_xyz, new_xyz_batch_cnt and features; returns (points, features)
            xyz: (N, 3)
            xyz_features: (N, C), or None when the source carries no extra features
            xyz_bs_idxs: (N)
            new_xyz: (M, 3)
            new_xyz_batch_cnt: (batch_size), [N1, N2, ...]

            filter_neighbors_with_roi: True/False, drop source points far from every RoI first
            radius_of_neighbor: float, margin used by the RoI filter
            num_max_points_of_part: int, chunk size used by the RoI filter
            rois: (batch_size, num_rois, 7 + C), required when filtering is enabled

        Returns:
            pooled_features: second element of the tuple returned by `aggregate_func`
        """
        xyz_batch_cnt = xyz.new_zeros(batch_size).int()
        if filter_neighbors_with_roi:
            # Carry xyz and features together so a single mask filters both.
            point_features = torch.cat((xyz, xyz_features), dim=-1) if xyz_features is not None else xyz
            point_features_list = []
            for bs_idx in range(batch_size):
                bs_mask = (xyz_bs_idxs == bs_idx)
                _, valid_mask = sample_points_with_roi(
                    rois=rois[bs_idx], points=xyz[bs_mask],
                    sample_radius_with_roi=radius_of_neighbor, num_max_points_of_part=num_max_points_of_part,
                )
                point_features_list.append(point_features[bs_mask][valid_mask])
                xyz_batch_cnt[bs_idx] = valid_mask.sum()

            valid_point_features = torch.cat(point_features_list, dim=0)
            xyz = valid_point_features[:, 0:3]
            xyz_features = valid_point_features[:, 3:] if xyz_features is not None else None
        else:
            for bs_idx in range(batch_size):
                xyz_batch_cnt[bs_idx] = (xyz_bs_idxs == bs_idx).sum()

        # xyz_features is optional everywhere above, so it must not be dereferenced
        # unconditionally here (doing so raised AttributeError on None).
        # NOTE(review): assumes aggregate_func accepts features=None - confirm.
        _, pooled_features = aggregate_func(
            xyz=xyz.contiguous(),
            xyz_batch_cnt=xyz_batch_cnt,
            new_xyz=new_xyz,
            new_xyz_batch_cnt=new_xyz_batch_cnt,
            features=xyz_features.contiguous() if xyz_features is not None else None,
        )
        return pooled_features

    def forward(self, batch_dict):
        """
        Sample keypoints, gather their features from every configured source and fuse them.

        Args:
            batch_dict:
                batch_size:
                keypoints: (B, num_keypoints, 3)
                multi_scale_3d_features: {
                        'x_conv4': ...
                    }
                points: optional (N, 1 + 3 + C) [bs_idx, x, y, z, ...]
                spatial_features: optional
                spatial_features_stride: optional
                rois: optional, used when RoI-based neighbor filtering or SPC sampling is enabled

        Returns:
            batch_dict, with the following entries added:
                point_features_before_fusion: (N, C_concat)
                point_features: (N, C)
                point_coords: (N, 4)
        """
        keypoints = self.get_sampled_points(batch_dict)

        point_features_list = []
        if 'bev' in self.model_cfg.FEATURES_SOURCE:
            point_bev_features = self.interpolate_from_bev_features(
                keypoints, batch_dict['spatial_features'], batch_dict['batch_size'],
                bev_stride=batch_dict['spatial_features_stride']
            )
            point_features_list.append(point_bev_features)

        batch_size = batch_dict['batch_size']

        new_xyz = keypoints[:, 1:4].contiguous()
        new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int()
        for k in range(batch_size):
            new_xyz_batch_cnt[k] = (keypoints[:, 0] == k).sum()

        if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
            raw_points = batch_dict['points']

            pooled_features = self.aggregate_keypoint_features_from_one_source(
                batch_size=batch_size, aggregate_func=self.SA_rawpoints,
                xyz=raw_points[:, 1:4],
                # Columns after [bs_idx, x, y, z] are features; None when there are none.
                xyz_features=raw_points[:, 4:].contiguous() if raw_points.shape[1] > 4 else None,
                xyz_bs_idxs=raw_points[:, 0],
                new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
                filter_neighbors_with_roi=self.model_cfg.SA_LAYER['raw_points'].get('FILTER_NEIGHBOR_WITH_ROI', False),
                radius_of_neighbor=self.model_cfg.SA_LAYER['raw_points'].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
                rois=batch_dict.get('rois', None)
            )
            point_features_list.append(pooled_features)

        for k, src_name in enumerate(self.SA_layer_names):
            cur_coords = batch_dict['multi_scale_3d_features'][src_name].indices
            cur_features = batch_dict['multi_scale_3d_features'][src_name].features.contiguous()

            xyz = common_utils.get_voxel_centers(
                cur_coords[:, 1:4], downsample_times=self.downsample_times_map[src_name],
                voxel_size=self.voxel_size, point_cloud_range=self.point_cloud_range
            )

            pooled_features = self.aggregate_keypoint_features_from_one_source(
                batch_size=batch_size, aggregate_func=self.SA_layers[k],
                xyz=xyz.contiguous(), xyz_features=cur_features, xyz_bs_idxs=cur_coords[:, 0],
                new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
                filter_neighbors_with_roi=self.model_cfg.SA_LAYER[src_name].get('FILTER_NEIGHBOR_WITH_ROI', False),
                radius_of_neighbor=self.model_cfg.SA_LAYER[src_name].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
                rois=batch_dict.get('rois', None)
            )

            point_features_list.append(pooled_features)

        point_features = torch.cat(point_features_list, dim=-1)

        batch_dict['point_features_before_fusion'] = point_features.view(-1, point_features.shape[-1])
        point_features = self.vsa_point_feature_fusion(point_features.view(-1, point_features.shape[-1]))

        batch_dict['point_features'] = point_features  # (BxN, C)
        batch_dict['point_coords'] = keypoints  # (BxN, 4)
        return batch_dict