import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from math import ceil

from pcdet.models.model_utils.dsvt_utils import get_window_coors, get_inner_win_inds_cuda, get_pooling_index, get_continous_inds
from pcdet.models.model_utils.dsvt_utils import PositionEmbeddingLearned


class DSVT(nn.Module):
    '''Dynamic Sparse Voxel Transformer Backbone.
    Args:
        INPUT_LAYER: Config of input layer, which converts the output of vfe to dsvt input.
        block_name (list[string]): Name of blocks for each stage. Length: stage_num.
        set_info (list[list[int, int]]): A list of set config for each stage. Element i contains
            [set_size, block_num], where set_size is the number of voxel in a set and block_num is the
            number of blocks for stage i. Length: stage_num.
        d_model (list[int]): Number of input channels for each stage. Length: stage_num.
        nhead (list[int]): Number of attention heads for each stage. Length: stage_num.
        dim_feedforward (list[int]): Dimensions of the feedforward network in set attention for each stage.
            Length: stage num.
        dropout (float): Drop rate of set attention.
        activation (string): Name of activation layer in set attention.
        reduction_type (string): Pooling method between stages. One of: "attention", "maxpool", "linear".
        output_shape (tuple[int, int]): Shape of output bev feature.
        conv_out_channel (int): Number of output channels.
    '''

    def __init__(self, model_cfg, **kwargs):
        super().__init__()

        self.model_cfg = model_cfg
        self.input_layer = DSVTInputLayer(self.model_cfg.INPUT_LAYER)
        block_name = self.model_cfg.block_name
        set_info = self.model_cfg.set_info
        d_model = self.model_cfg.d_model
        nhead = self.model_cfg.nhead
        dim_feedforward = self.model_cfg.dim_feedforward
        dropout = self.model_cfg.dropout
        activation = self.model_cfg.activation
        self.reduction_type = self.model_cfg.get('reduction_type', 'attention')
        # save GPU memory
        self.use_torch_ckpt = self.model_cfg.get('USE_CHECKPOINT', False)

        # Sparse Regional Attention Blocks
        # NOTE: attribute names and registration order define the state-dict layout; keep them stable.
        stage_num = len(block_name)
        for stage_id in range(stage_num):
            num_blocks_this_stage = set_info[stage_id][-1]
            dmodel_this_stage = d_model[stage_id]
            dfeed_this_stage = dim_feedforward[stage_id]
            num_head_this_stage = nhead[stage_id]
            block_module = _get_block_module(block_name[stage_id])

            block_list = []
            norm_list = []
            for _ in range(num_blocks_this_stage):
                block_list.append(
                    block_module(dmodel_this_stage, num_head_this_stage, dfeed_this_stage,
                                 dropout, activation, batch_first=True)
                )
                norm_list.append(nn.LayerNorm(dmodel_this_stage))
            setattr(self, f'stage_{stage_id}', nn.ModuleList(block_list))
            setattr(self, f'residual_norm_stage_{stage_id}', nn.ModuleList(norm_list))

            # apply pooling except the last stage
            if stage_id < stage_num - 1:
                downsample_window = self.model_cfg.INPUT_LAYER.downsample_stride[stage_id]
                dmodel_next_stage = d_model[stage_id + 1]
                # number of voxels merged into one voxel of the next stage
                pool_volume = torch.IntTensor(downsample_window).prod().item()
                if self.reduction_type == 'linear':
                    cat_feat_dim = dmodel_this_stage * pool_volume
                    reduction = Stage_Reduction_Block(cat_feat_dim, dmodel_next_stage)
                elif self.reduction_type == 'maxpool':
                    reduction = torch.nn.MaxPool1d(pool_volume)
                elif self.reduction_type == 'attention':
                    reduction = Stage_ReductionAtt_Block(dmodel_this_stage, pool_volume)
                else:
                    raise NotImplementedError
                setattr(self, f'stage_{stage_id}_reduction', reduction)

        self.num_shifts = [2] * stage_num
        self.output_shape = self.model_cfg.output_shape
        self.stage_num = stage_num
        self.set_info = set_info
        self.num_point_features = self.model_cfg.conv_out_channel

        self._reset_parameters()

    def forward(self, batch_dict):
        '''
        Args:
            batch_dict (dict):
                The dict contains the following keys
                - voxel_features (Tensor[float]): Voxel features after VFE. Shape of (N, d_model[0]),
                    where N is the number of input voxels.
                - voxel_coords (Tensor[int]): Shape of (N, 4), corresponding voxel coordinates of each voxels.
                    Each row is (batch_id, z, y, x).
                - ...

        Returns:
            batch_dict (dict):
                The dict contains the following keys
                - pillar_features (Tensor[float]):
                - voxel_coords (Tensor[int]):
                - ...
        '''
        voxel_info = self.input_layer(batch_dict)

        voxel_feat = voxel_info['voxel_feats_stage0']
        set_voxel_inds_list = [[voxel_info[f'set_voxel_inds_stage{s}_shift{i}']
                                for i in range(self.num_shifts[s])] for s in range(self.stage_num)]
        set_voxel_masks_list = [[voxel_info[f'set_voxel_mask_stage{s}_shift{i}']
                                 for i in range(self.num_shifts[s])] for s in range(self.stage_num)]
        pos_embed_list = [[[voxel_info[f'pos_embed_stage{s}_block{b}_shift{i}']
                            for i in range(self.num_shifts[s])]
                           for b in range(self.set_info[s][1])] for s in range(self.stage_num)]
        pooling_mapping_index = [voxel_info[f'pooling_mapping_index_stage{s+1}'] for s in range(self.stage_num - 1)]
        pooling_index_in_pool = [voxel_info[f'pooling_index_in_pool_stage{s+1}'] for s in range(self.stage_num - 1)]
        pooling_preholder_feats = [voxel_info[f'pooling_preholder_feats_stage{s+1}'] for s in range(self.stage_num - 1)]

        output = voxel_feat
        block_id = 0
        for stage_id in range(self.stage_num):
            block_layers = getattr(self, f'stage_{stage_id}')
            residual_norm_layers = getattr(self, f'residual_norm_stage_{stage_id}')
            stage_inds = set_voxel_inds_list[stage_id]
            stage_masks = set_voxel_masks_list[stage_id]
            for i, block in enumerate(block_layers):
                residual = output.clone()
                if self.use_torch_ckpt:
                    # trade compute for memory: activations of the block are recomputed in backward
                    output = checkpoint(block, output, stage_inds, stage_masks,
                                        pos_embed_list[stage_id][i], block_id)
                else:
                    output = block(output, stage_inds, stage_masks,
                                   pos_embed_list[stage_id][i], block_id=block_id)
                output = residual_norm_layers[i](output + residual)
                block_id += 1

            if stage_id < self.stage_num - 1:
                # pooling: scatter voxel features into the pre-allocated (N_next, pool_volume, C) buffer
                prepool_features = pooling_preholder_feats[stage_id].type_as(output)
                pooled_voxel_num = prepool_features.shape[0]
                pool_volume = prepool_features.shape[1]
                prepool_features[pooling_mapping_index[stage_id], pooling_index_in_pool[stage_id]] = output
                prepool_features = prepool_features.view(prepool_features.shape[0], -1)

                reduction = getattr(self, f'stage_{stage_id}_reduction')
                if self.reduction_type == 'linear':
                    output = reduction(prepool_features)
                elif self.reduction_type == 'maxpool':
                    prepool_features = prepool_features.view(pooled_voxel_num, pool_volume, -1).permute(0, 2, 1)
                    output = reduction(prepool_features).squeeze(-1)
                elif self.reduction_type == 'attention':
                    prepool_features = prepool_features.view(pooled_voxel_num, pool_volume, -1).permute(0, 2, 1)
                    # Nothing is masked. MultiheadAttention only accepts bool/float padding masks,
                    # so build an all-False bool mask instead of an int tensor.
                    key_padding_mask = torch.zeros((pooled_voxel_num, pool_volume), dtype=torch.bool,
                                                   device=prepool_features.device)
                    output = reduction(prepool_features, key_padding_mask)
                else:
                    raise NotImplementedError

        batch_dict['pillar_features'] = batch_dict['voxel_features'] = output
        batch_dict['voxel_coords'] = voxel_info[f'voxel_coors_stage{self.stage_num - 1}']
        return batch_dict

    def _reset_parameters(self):
        '''Xavier-initialise every parameter with more than one dim, except those named "*scaler*".'''
        for name, p in self.named_parameters():
            if p.dim() > 1 and 'scaler' not in name:
                nn.init.xavier_uniform_(p)


class DSVTBlock(nn.Module):
    ''' Consist of two encoder layer, shift and shift back.
    '''

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", batch_first=True):
        super().__init__()

        encoder_1 = DSVT_EncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                      activation, batch_first)
        encoder_2 = DSVT_EncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                      activation, batch_first)
        self.encoder_list = nn.ModuleList([encoder_1, encoder_2])

    def forward(
            self,
            src,
            set_voxel_inds_list,
            set_voxel_masks_list,
            pos_embed_list,
            block_id,
    ):
        '''Run the two encoder layers, one per set partition of the shift chosen by block_id.

        Args:
            src (Tensor[float]): Voxel features.
            set_voxel_inds_list (list[Tensor]): Per-shift set indices; each entry is indexed by
                partition id along its first dim.
            set_voxel_masks_list (list[Tensor]): Per-shift key masks, same layout as the indices.
            pos_embed_list (list[Tensor]): Position embeddings of this block, one per shift.
            block_id (int): Global block counter; its parity selects the shift.
        '''
        num_shifts = 2
        output = src
        # TODO: bug to be fixed, mismatch of pos_embed
        # (pos_embed is picked by the loop index while inds/masks are picked by block_id parity;
        #  left untouched on purpose because changing it would alter trained-model behaviour)
        for i in range(num_shifts):
            set_id = i
            shift_id = block_id % 2
            pos_embed_id = i
            set_voxel_inds = set_voxel_inds_list[shift_id][set_id]
            set_voxel_masks = set_voxel_masks_list[shift_id][set_id]
            pos_embed = pos_embed_list[pos_embed_id]
            layer = self.encoder_list[i]
            output = layer(output, set_voxel_inds, set_voxel_masks, pos_embed)

        return output


class DSVT_EncoderLayer(nn.Module):
    '''Set attention followed by a residual connection and LayerNorm.'''

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", batch_first=True, mlp_dropout=0):
        super().__init__()
        self.win_attn = SetAttention(d_model, nhead, dropout, dim_feedforward, activation, batch_first, mlp_dropout)
        self.norm = nn.LayerNorm(d_model)
        self.d_model = d_model

    def forward(self, src, set_voxel_inds, set_voxel_masks, pos=None):
        identity = src
        src = self.win_attn(src, pos, set_voxel_masks, set_voxel_inds)
        src = src + identity
        src = self.norm(src)

        return src


class SetAttention(nn.Module):
    '''Multi-head self-attention inside each voxel set, followed by a feed-forward network.'''

    def __init__(self, d_model, nhead, dropout, dim_feedforward=2048, activation="relu", batch_first=True, mlp_dropout=0):
        super().__init__()
        self.nhead = nhead
        if batch_first:
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        else:
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(mlp_dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.d_model = d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Identity()
        self.dropout2 = nn.Identity()

        self.activation = _get_activation_fn(activation)

    def forward(self, src, pos=None, key_padding_mask=None, voxel_inds=None):
        '''
        Args:
            src (Tensor[float]): Voxel features with shape (N, C), where N is the number of voxels.
            pos (Tensor[float]): Position embedding vectors with shape (N, C). If None, attention
                is computed on the raw set features.
            key_padding_mask (Tensor[bool]): Mask for redundant voxels within set. Shape of (set_num, set_size).
            voxel_inds (Tensor[int]): Voxel indexes for each set. Shape of (set_num, set_size).
        Returns:
            src (Tensor[float]): Voxel features.
        '''
        set_features = src[voxel_inds]
        if pos is not None:
            # position information is added to query/key only, never to value
            query = key = set_features + pos[voxel_inds]
        else:
            # previously query/key were left unbound here and the call crashed
            query = key = set_features
        value = set_features

        if key_padding_mask is not None:
            src2 = self.self_attn(query, key, value, key_padding_mask)[0]
        else:
            src2 = self.self_attn(query, key, value)[0]

        # map voxel features from set space to voxel space: (set_num, set_size, C) --> (N, C)
        # A voxel may appear several times in the sets; scattering the flipped positions is meant
        # to keep, for every unique voxel index, the position of its first occurrence.
        flatten_inds = voxel_inds.reshape(-1)
        unique_flatten_inds, inverse = torch.unique(flatten_inds, return_inverse=True)
        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
        inverse, perm = inverse.flip([0]), perm.flip([0])
        perm = inverse.new_empty(unique_flatten_inds.size(0)).scatter_(0, inverse, perm)
        src2 = src2.reshape(-1, self.d_model)[perm]

        # FFN layer
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        return src


class Stage_Reduction_Block(nn.Module):
    '''Linear reduction between stages: bias-free projection followed by LayerNorm.'''

    def __init__(self, input_channel, output_channel):
        super().__init__()
        self.linear1 = nn.Linear(input_channel, output_channel, bias=False)
        self.norm = nn.LayerNorm(output_channel)

    def forward(self, x):
        return self.norm(self.linear1(x))


class Stage_ReductionAtt_Block(nn.Module):
    '''Attention reduction between stages: the max-pooled feature queries all voxels of its pool.'''

    def __init__(self, input_channel, pool_volume):
        super().__init__()
        self.pool_volume = pool_volume
        self.query_func = torch.nn.MaxPool1d(pool_volume)
        self.norm = nn.LayerNorm(input_channel)
        self.self_attn = nn.MultiheadAttention(input_channel, 8, batch_first=True)
        self.pos_embedding = nn.Parameter(torch.randn(pool_volume, input_channel))
        nn.init.normal_(self.pos_embedding, std=.01)

    def forward(self, x, key_padding_mask):
        '''
        Args:
            x (Tensor[float]): Shape of (voxel_num, c_dim, pool_volume).
            key_padding_mask (Tensor[bool]): Shape of (voxel_num, pool_volume).
        Returns:
            Tensor[float]: Pooled features with shape (voxel_num, c_dim).
        '''
        src = self.query_func(x).permute(0, 2, 1)  # voxel_num, 1, c_dim
        key = value = x.permute(0, 2, 1)
        # (1, pool_volume, c_dim) broadcasts over the voxel dimension
        key = key + self.pos_embedding.unsqueeze(0)
        query = src.clone()
        output = self.self_attn(query, key, value, key_padding_mask)[0]
        src = self.norm(output + src).squeeze(1)
        return src


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return torch.nn.functional.relu
    if activation == "gelu":
        return torch.nn.functional.gelu
    if activation == "glu":
        return torch.nn.functional.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def _get_block_module(name):
    """Return a block module given a string"""
    if name == "DSVTBlock":
        return DSVTBlock
    raise RuntimeError("This Block not exist.")


class DSVTInputLayer(nn.Module):
    '''
    This class converts the output of vfe to dsvt input.
    We do in this class:
    1. Window partition: partition voxels to non-overlapping windows.
    2. Set partition: generate non-overlapped and size-equivalent local sets within each window.
    3. Pre-compute the downsample information between two consecutive stages.
    4. Pre-compute the position embedding vectors.

    Args:
        sparse_shape (tuple[int, int, int]): Shape of input space (xdim, ydim, zdim).
        window_shape (list[list[int, int, int]]): Window shapes (winx, winy, winz) in different stages.
            Length: stage_num.
        downsample_stride (list[list[int, int, int]]): Downsample strides between two consecutive stages.
            Element i is [ds_x, ds_y, ds_z], which is used between stage_i and stage_{i+1}. Length: stage_num - 1.
        d_model (list[int]): Number of input channels for each stage. Length: stage_num.
        set_info (list[list[int, int]]): A list of set config for each stage. Element i contains
            [set_size, block_num], where set_size is the number of voxel in a set and block_num is the
            number of blocks for stage i. Length: stage_num.
        hybrid_factor (list[int, int, int]): Control the window shape in different blocks.
            e.g. for block_{0} and block_{1} in stage_0, window shapes are [win_x, win_y, win_z] and
            [win_x * h[0], win_y * h[1], win_z * h[2]] respectively.
        shifts_list (list): Shift window. Length: stage_num.
        normalize_pos (bool): Whether to normalize coordinates in position embedding.
    '''

    def __init__(self, model_cfg):
        super().__init__()

        self.model_cfg = model_cfg
        self.sparse_shape = self.model_cfg.sparse_shape
        self.window_shape = self.model_cfg.window_shape
        self.downsample_stride = self.model_cfg.downsample_stride
        self.d_model = self.model_cfg.d_model
        self.set_info = self.model_cfg.set_info
        self.stage_num = len(self.d_model)

        self.hybrid_factor = self.model_cfg.hybrid_factor
        # per stage: [base window, base window scaled by hybrid_factor] -- one entry per shift
        self.window_shape = [
            [self.window_shape[s_id],
             [self.window_shape[s_id][coord_id] * self.hybrid_factor[coord_id] for coord_id in range(3)]]
            for s_id in range(self.stage_num)
        ]
        self.shift_list = self.model_cfg.shifts_list
        self.normalize_pos = self.model_cfg.normalize_pos

        self.num_shifts = [2, ] * len(self.window_shape)

        self.sparse_shape_list = [self.sparse_shape]
        # compute sparse shapes for each stage
        for ds_stride in self.downsample_stride:
            last_sparse_shape = self.sparse_shape_list[-1]
            self.sparse_shape_list.append((ceil(last_sparse_shape[0] / ds_stride[0]),
                                           ceil(last_sparse_shape[1] / ds_stride[1]),
                                           ceil(last_sparse_shape[2] / ds_stride[2])))

        # position embedding layers, indexed as [stage][block][shift]
        self.posembed_layers = nn.ModuleList()
        for i in range(len(self.set_info)):
            # the last axis of the sparse shape collapsing to 1 means a 2D embedding is enough
            input_dim = 3 if self.sparse_shape_list[i][-1] > 1 else 2
            stage_posembed_layers = nn.ModuleList()
            for j in range(self.set_info[i][1]):
                block_posembed_layers = nn.ModuleList()
                for s in range(self.num_shifts[i]):
                    block_posembed_layers.append(PositionEmbeddingLearned(input_dim, self.d_model[i]))
                stage_posembed_layers.append(block_posembed_layers)
            self.posembed_layers.append(stage_posembed_layers)

    def forward(self, batch_dict):
        '''
        Args:
            batch_dict (dict):
                The dict contains the following keys
                - voxel_features (Tensor[float]): Voxel features after VFE with shape (N, d_model[0]),
                    where N is the number of input voxels.
                - voxel_coords (Tensor[int]): Shape of (N, 4), corresponding voxel coordinates of each voxels.
                    Each row is (batch_id, z, y, x).
                - ...

        Returns:
            voxel_info (dict):
                The dict contains the following keys
                - voxel_coors_stage{i} (Tensor[int]): Shape of (N_i, 4). N is the number of voxels in stage_i.
                    Each row is (batch_id, z, y, x).
                - set_voxel_inds_stage{i}_shift{j} (Tensor[int]): Set partition index with shape (2, set_num, set_info[i][0]).
                    2 indicates x-axis partition and y-axis partition.
                - set_voxel_mask_stage{i}_shift{j} (Tensor[bool]): Key mask used in set attention with shape (2, set_num, set_info[i][0]).
                - pos_embed_stage{i}_block{b}_shift{j} (Tensor[float]): Position embedding vectors with shape (N_i, d_model[i]). N_i is the
                    number of remain voxels in stage_i;
                - pooling_mapping_index_stage{i} (Tensor[int]): Pooling region index used in pooling operation between stage_{i-1} and stage_{i}
                    with shape (N_{i-1}).
                - pooling_index_in_pool_stage{i} (Tensor[int]): Index inner region with shape (N_{i-1}). Combined with pooling_mapping_index_stage{i},
                    we can map each voxel in stage_{i-1} to pooling_preholder_feats_stage{i}, which are input of downsample operation.
                - pooling_preholder_feats_stage{i} (Tensor[int]): Preholder features initial with value 0.
                    Shape of (N_{i}, downsample_stride[i-1].prod(), d_model[i-1]), where prod() returns the product of all elements.
                - ...
        '''
        voxel_feats = batch_dict['voxel_features']
        voxel_coors = batch_dict['voxel_coords'].long()

        voxel_info = {}
        voxel_info['voxel_feats_stage0'] = voxel_feats.clone()
        voxel_info['voxel_coors_stage0'] = voxel_coors.clone()

        for stage_id in range(self.stage_num):
            # window partition of corresponding stage-map
            voxel_info = self.window_partition(voxel_info, stage_id)
            # generate set id of corresponding stage-map
            voxel_info = self.get_set(voxel_info, stage_id)
            for block_id in range(self.set_info[stage_id][1]):
                for shift_id in range(self.num_shifts[stage_id]):
                    voxel_info[f'pos_embed_stage{stage_id}_block{block_id}_shift{shift_id}'] = \
                        self.get_pos_embed(voxel_info[f'coors_in_win_stage{stage_id}_shift{shift_id}'],
                                           stage_id, block_id, shift_id)

            # compute pooling information
            if stage_id < self.stage_num - 1:
                voxel_info = self.subm_pooling(voxel_info, stage_id)

        return voxel_info

    @torch.no_grad()
    def subm_pooling(self, voxel_info, stage_id):
        '''Pre-compute the indices, placeholder buffer and coordinates used to pool stage_id into stage_id + 1.'''
        # x,y,z stride
        cur_stage_downsample = self.downsample_stride[stage_id]
        # batch_win_coords is from 1 of x, y
        batch_win_inds, _, index_in_win, batch_win_coors = get_pooling_index(
            voxel_info[f'voxel_coors_stage{stage_id}'], self.sparse_shape_list[stage_id], cur_stage_downsample)
        # compute pooling mapping index
        unique_batch_win_inds, contiguous_batch_win_inds = torch.unique(batch_win_inds, return_inverse=True)
        voxel_info[f'pooling_mapping_index_stage{stage_id+1}'] = contiguous_batch_win_inds

        # generate empty placeholder features
        placeholder_prepool_feats = voxel_info['voxel_feats_stage0'].new_zeros(
            (len(unique_batch_win_inds),
             torch.prod(torch.IntTensor(cur_stage_downsample)).item(),
             self.d_model[stage_id]))
        voxel_info[f'pooling_index_in_pool_stage{stage_id+1}'] = index_in_win
        voxel_info[f'pooling_preholder_feats_stage{stage_id+1}'] = placeholder_prepool_feats

        # compute pooling coordinates: pick one source voxel per pooled region
        # (flip + scatter is meant to keep the first occurrence of each region)
        unique, inverse = unique_batch_win_inds.clone(), contiguous_batch_win_inds.clone()
        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
        inverse, perm = inverse.flip([0]), perm.flip([0])
        perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
        pool_coors = batch_win_coors[perm]

        voxel_info[f'voxel_coors_stage{stage_id+1}'] = pool_coors

        return voxel_info

    def get_set(self, voxel_info, stage_id):
        '''
        This is one of the core operation of DSVT.
        Given voxels' window ids and relative-coords inner window, we partition them into window-bounded
        and size-equivalent local sets. Both shifts are processed the same way.

        Args:
            voxel_info (dict):
                The dict contains the following keys
                - batch_win_inds_stage{i}_shift{j} (Tensor[float]): Windows indexes of each voxel with shape (N), computed by 'window_partition'.
                - coors_in_win_stage{i}_shift{j} (Tensor[int]): Relative-coords inner window of each voxel with shape (N, 3), computed by 'window_partition'.
                    Each row is (z, y, x).
                - ...

        Returns:
            See from 'forward' function.
        '''
        for shift_id in (0, 1):
            batch_win_inds = voxel_info[f'batch_win_inds_stage{stage_id}_shift{shift_id}']
            coors_in_win = voxel_info[f'coors_in_win_stage{stage_id}_shift{shift_id}']
            set_voxel_inds = self.get_set_single_shift(batch_win_inds, stage_id, shift_id=shift_id,
                                                       coors_in_win=coors_in_win)
            voxel_info[f'set_voxel_inds_stage{stage_id}_shift{shift_id}'] = set_voxel_inds

            # compute key masks, voxel duplication must happen continuously:
            # an entry is masked when it repeats its left neighbour inside the set
            prefix_set_voxel_inds = torch.roll(set_voxel_inds, shifts=1, dims=-1)
            prefix_set_voxel_inds[:, :, 0] = -1
            voxel_info[f'set_voxel_mask_stage{stage_id}_shift{shift_id}'] = \
                (set_voxel_inds == prefix_set_voxel_inds)

        return voxel_info

    def get_set_single_shift(self, batch_win_inds, stage_id, shift_id=None, coors_in_win=None):
        '''Build the y-axis and x-axis set partitions of one shift.

        Returns:
            Tensor[int]: Stacked (y partition, x partition) set indices, shape (2, set_num, set_size).
        '''
        device = batch_win_inds.device
        # the number of voxels assigned to a set
        voxel_num_set = self.set_info[stage_id][0]
        # max number of voxels in a window
        win_shape = self.window_shape[stage_id][shift_id]
        max_voxel = win_shape[0] * win_shape[1] * win_shape[2]
        # get unique set indexes
        contiguous_win_inds = torch.unique(batch_win_inds, return_inverse=True)[1]
        voxelnum_per_win = torch.bincount(contiguous_win_inds)
        win_num = voxelnum_per_win.shape[0]
        setnum_per_win_float = voxelnum_per_win / voxel_num_set
        setnum_per_win = torch.ceil(setnum_per_win_float).long()
        set_win_inds, set_inds_in_win = get_continous_inds(setnum_per_win)

        # computation of Eq.3 in 'DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets' - https://arxiv.org/abs/2301.06051,
        # for each window, we can get voxel indexes belong to different sets.
        offset_idx = set_inds_in_win[:, None].repeat(1, voxel_num_set) * voxel_num_set
        base_idx = torch.arange(0, voxel_num_set, 1, device=device)
        base_select_idx = offset_idx + base_idx
        base_select_idx = base_select_idx * voxelnum_per_win[set_win_inds][:, None]
        base_select_idx = base_select_idx.double() / (setnum_per_win[set_win_inds] * voxel_num_set)[:, None].double()
        base_select_idx = torch.floor(base_select_idx)
        # obtain unique indexes in whole space
        select_idx = (base_select_idx + set_win_inds.view(-1, 1) * max_voxel).long()

        # this function will return unordered inner window indexes of each voxel
        inner_voxel_inds = get_inner_win_inds_cuda(contiguous_win_inds)
        global_voxel_inds = contiguous_win_inds * max_voxel + inner_voxel_inds
        _, order1 = torch.sort(global_voxel_inds)
        sorted_inner_voxel_inds = inner_voxel_inds[order1]

        def partition_by(global_sort_key):
            # re-assign the inner-window slots following the order given by global_sort_key,
            # then look up which voxel sits in every selected slot (-1 for empty slots)
            _, order2 = torch.sort(global_sort_key)
            inner_sorted = -torch.ones_like(inner_voxel_inds)
            inner_sorted.scatter_(dim=0, index=order2, src=sorted_inner_voxel_inds)
            voxel_inds_in_batch = inner_sorted + max_voxel * contiguous_win_inds
            voxel_inds_padding = -1 * torch.ones((win_num * max_voxel), dtype=torch.long, device=device)
            voxel_inds_padding[voxel_inds_in_batch] = torch.arange(
                0, voxel_inds_in_batch.shape[0], dtype=torch.long, device=device)
            return voxel_inds_padding[select_idx]

        # get y-axis partition results (coors_in_win columns are z, y, x)
        global_voxel_inds_sorty = contiguous_win_inds * max_voxel + \
            coors_in_win[:, 1] * win_shape[0] * win_shape[2] + \
            coors_in_win[:, 2] * win_shape[2] + \
            coors_in_win[:, 0]
        set_voxel_inds_sorty = partition_by(global_voxel_inds_sorty)

        # get x-axis partition results
        global_voxel_inds_sortx = contiguous_win_inds * max_voxel + \
            coors_in_win[:, 2] * win_shape[1] * win_shape[2] + \
            coors_in_win[:, 1] * win_shape[2] + \
            coors_in_win[:, 0]
        set_voxel_inds_sortx = partition_by(global_voxel_inds_sortx)

        all_set_voxel_inds = torch.stack((set_voxel_inds_sorty, set_voxel_inds_sortx), dim=0)
        return all_set_voxel_inds

    @torch.no_grad()
    def window_partition(self, voxel_info, stage_id):
        '''Store the window index and in-window coordinates of every voxel for both shifts.'''
        for i in range(2):
            batch_win_inds, coors_in_win = get_window_coors(
                voxel_info[f'voxel_coors_stage{stage_id}'],
                self.sparse_shape_list[stage_id], self.window_shape[stage_id][i],
                i == 1, self.shift_list[stage_id][i])

            voxel_info[f'batch_win_inds_stage{stage_id}_shift{i}'] = batch_win_inds
            voxel_info[f'coors_in_win_stage{stage_id}_shift{i}'] = coors_in_win

        return voxel_info

    def get_pos_embed(self, coors_in_win, stage_id, block_id, shift_id):
        '''
        Args:
            coors_in_win: shape=[N, 3], order: z, y, x
        Returns:
            Position embedding produced by the [stage_id][block_id][shift_id] embedding layer.
        '''
        # [N,]
        window_shape = self.window_shape[stage_id][shift_id]
        embed_layer = self.posembed_layers[stage_id][block_id][shift_id]
        if len(window_shape) == 2:
            ndim = 2
            win_x, win_y = window_shape
        elif window_shape[-1] == 1:
            ndim = 2
            win_x, win_y = window_shape[:2]
        else:
            win_x, win_y, win_z = window_shape
            ndim = 3

        assert coors_in_win.size(1) == 3
        # centre the coordinates on the window
        y = coors_in_win[:, 1] - win_y / 2
        x = coors_in_win[:, 2] - win_x / 2

        if self.normalize_pos:
            x = x / win_x * 2 * 3.1415  # [-pi, pi]
            y = y / win_y * 2 * 3.1415  # [-pi, pi]

        if ndim == 2:
            # z is not part of a 2D embedding; skipping it also avoids dividing by a zero win_z
            location = torch.stack((x, y), dim=-1)
        else:
            z = coors_in_win[:, 0] - win_z / 2
            if self.normalize_pos:
                z = z / win_z * 2 * 3.1415  # [-pi, pi]
            location = torch.stack((x, y, z), dim=-1)
        pos_embed = embed_layer(location)

        return pos_embed