From 57375667fc34f1dab432c1a577459cb65a0189cb Mon Sep 17 00:00:00 2001
From: inter
Date: Sun, 21 Sep 2025 20:18:49 +0800
Subject: [PATCH] Add File

---
 pcdet/models/model_utils/mppnet_utils.py | 420 +++++++++++++++++++++++
 1 file changed, 420 insertions(+)
 create mode 100644 pcdet/models/model_utils/mppnet_utils.py

diff --git a/pcdet/models/model_utils/mppnet_utils.py b/pcdet/models/model_utils/mppnet_utils.py
new file mode 100644
index 0000000..10641ad
--- /dev/null
+++ b/pcdet/models/model_utils/mppnet_utils.py
@@ -0,0 +1,420 @@

from os import getgrouplist
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
from typing import Optional, List
from torch import Tensor
from torch.nn.init import xavier_uniform_, zeros_, kaiming_normal_


class PointNetfeat(nn.Module):
    def __init__(self, input_dim, x=1, outchannel=512):
        super(PointNetfeat, self).__init__()
        if outchannel == 256:
            self.output_channel = 256
        else:
            self.output_channel = 512 * x
        self.conv1 = torch.nn.Conv1d(input_dim, 64 * x, 1)
        self.conv2 = torch.nn.Conv1d(64 * x, 128 * x, 1)
        self.conv3 = torch.nn.Conv1d(128 * x, 256 * x, 1)
        self.conv4 = torch.nn.Conv1d(256 * x, self.output_channel, 1)
        self.bn1 = nn.BatchNorm1d(64 * x)
        self.bn2 = nn.BatchNorm1d(128 * x)
        self.bn3 = nn.BatchNorm1d(256 * x)
        self.bn4 = nn.BatchNorm1d(self.output_channel)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x_ori = self.bn4(self.conv4(x))

        x = torch.max(x_ori, 2, keepdim=True)[0]

        x = x.view(-1, self.output_channel)
        return x, x_ori


class PointNet(nn.Module):
    def __init__(self, input_dim, joint_feat=False, model_cfg=None):
        super(PointNet, self).__init__()
        self.joint_feat = joint_feat
        channels = model_cfg.TRANS_INPUT

        times = 1
        self.feat = PointNetfeat(input_dim, 1)

        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, channels)

        self.pre_bn = nn.BatchNorm1d(input_dim)
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

        self.fc_s1 = nn.Linear(channels * times, 256)
        self.fc_s2 = nn.Linear(256, 3, bias=False)
        self.fc_ce1 = nn.Linear(channels * times, 256)
        self.fc_ce2 = nn.Linear(256, 3, bias=False)
        self.fc_hr1 = nn.Linear(channels * times, 256)
        self.fc_hr2 = nn.Linear(256, 1, bias=False)

    def forward(self, x, feat=None):

        if self.joint_feat:
            if len(feat.shape) > 2:
                feat = torch.max(feat, 2, keepdim=True)[0]
                x = feat.view(-1, self.output_channel)
                x = F.relu(self.bn1(self.fc1(x)))
                feat = F.relu(self.bn2(self.fc2(x)))
            else:
                feat = feat
            feat_traj = None
        else:
            x, feat_traj = self.feat(self.pre_bn(x))
            x = F.relu(self.bn1(self.fc1(x)))
            feat = F.relu(self.bn2(self.fc2(x)))

        x = F.relu(self.fc_ce1(feat))
        centers = self.fc_ce2(x)

        x = F.relu(self.fc_s1(feat))
        sizes = self.fc_s2(x)

        x = F.relu(self.fc_hr1(feat))
        headings = self.fc_hr2(x)

        return torch.cat([centers, sizes, headings], -1), feat, feat_traj

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    zeros_(m.bias)
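
# --- Editorial example, not part of the original patch ---
# A minimal shape-check sketch for PointNetfeat. The batch size, point count and
# input channel count below are illustrative assumptions, not values taken from
# the MPPNet configuration.
def _demo_pointnetfeat_shapes():
    feat_net = PointNetfeat(input_dim=9, x=1).eval()  # eval() so BatchNorm1d uses running stats
    points = torch.randn(2, 9, 128)                   # (batch, channels, num_points)
    pooled, per_point = feat_net(points)
    # pooled is the max-pooled global descriptor; per_point keeps one feature per point
    assert pooled.shape == (2, 512) and per_point.shape == (2, 512, 128)
    return pooled, per_point
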

class MLP(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class SpatialMixerBlock(nn.Module):

    def __init__(self, hidden_dim, grid_size, channels, config=None, dropout=0.0):
        super().__init__()

        self.mixer_x = MLP(input_dim=grid_size, hidden_dim=hidden_dim, output_dim=grid_size, num_layers=3)
        self.mixer_y = MLP(input_dim=grid_size, hidden_dim=hidden_dim, output_dim=grid_size, num_layers=3)
        self.mixer_z = MLP(input_dim=grid_size, hidden_dim=hidden_dim, output_dim=grid_size, num_layers=3)
        self.norm_x = nn.LayerNorm(channels)
        self.norm_y = nn.LayerNorm(channels)
        self.norm_z = nn.LayerNorm(channels)
        self.norm_channel = nn.LayerNorm(channels)
        self.ffn = nn.Sequential(
            nn.Linear(channels, 2 * channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2 * channels, channels),
        )
        self.config = config
        self.grid_size = grid_size

    def forward(self, src):

        src_3d = src.permute(1, 2, 0).contiguous().view(src.shape[1], src.shape[2],
                                                        self.grid_size, self.grid_size, self.grid_size)
        src_3d = src_3d.permute(0, 1, 4, 3, 2).contiguous()
        mixed_x = self.mixer_x(src_3d)
        mixed_x = src_3d + mixed_x
        mixed_x = self.norm_x(mixed_x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3).contiguous()

        mixed_y = self.mixer_y(mixed_x.permute(0, 1, 2, 4, 3)).permute(0, 1, 2, 4, 3).contiguous()
        mixed_y = mixed_x + mixed_y
        mixed_y = self.norm_y(mixed_y.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3).contiguous()

        mixed_z = self.mixer_z(mixed_y.permute(0, 1, 4, 3, 2)).permute(0, 1, 4, 3, 2).contiguous()

        mixed_z = mixed_y + mixed_z
        mixed_z = self.norm_z(mixed_z.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3).contiguous()

        src_mixer = mixed_z.view(src.shape[1], src.shape[2], -1).permute(2, 0, 1)
        src_mixer = src_mixer + self.ffn(src_mixer)
        src_mixer = self.norm_channel(src_mixer)

        return src_mixer
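
# --- Editorial example, not part of the original patch ---
# A minimal sketch of the SpatialMixerBlock token layout, assuming grid_size=4 and
# 128 channels (the real values come from the model config). The block expects a
# sequence-first tensor of grid_size**3 proxy tokens and returns the same shape.
def _demo_spatial_mixer_shapes():
    grid_size, channels, batch = 4, 128, 2
    mixer = SpatialMixerBlock(hidden_dim=16, grid_size=grid_size, channels=channels)
    tokens = torch.randn(grid_size ** 3, batch, channels)  # (num_tokens, batch, channels)
    out = mixer(tokens)
    assert out.shape == tokens.shape
    return out
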

class Transformer(nn.Module):

    def __init__(self, config, d_model=512, nhead=8, num_encoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False,
                 num_lidar_points=None, num_proxy_points=None, share_head=True, num_groups=None,
                 sequence_stride=None, num_frames=None):
        super().__init__()

        self.config = config
        self.share_head = share_head
        self.num_frames = num_frames
        self.nhead = nhead
        self.sequence_stride = sequence_stride
        self.num_groups = num_groups
        self.num_proxy_points = num_proxy_points
        self.num_lidar_points = num_lidar_points
        self.d_model = d_model
        self.nhead = nhead
        encoder_layer = [TransformerEncoderLayer(self.config, d_model, nhead, dim_feedforward, dropout, activation,
                                                 normalize_before, num_lidar_points, num_groups=num_groups)
                         for i in range(num_encoder_layers)]

        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm, self.config)

        self.token = nn.Parameter(torch.zeros(self.num_groups, 1, d_model))

        if self.num_frames > 4:

            self.group_length = self.num_frames // self.num_groups
            self.fusion_all_group = MLP(input_dim=self.config.hidden_dim * self.group_length,
                                        hidden_dim=self.config.hidden_dim, output_dim=self.config.hidden_dim, num_layers=4)

            self.fusion_norm = FFN(d_model, dim_feedforward)

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, pos=None):

        BS, N, C = src.shape
        if not pos is None:
            pos = pos.permute(1, 0, 2)

        if self.num_frames == 16:
            token_list = [self.token[i:(i + 1)].repeat(BS, 1, 1) for i in range(self.num_groups)]
            if self.sequence_stride == 1:
                src_groups = src.view(src.shape[0], src.shape[1] // self.num_groups, -1).chunk(4, dim=1)

            elif self.sequence_stride == 4:
                src_groups = []

                for i in range(self.num_groups):
                    groups = []
                    for j in range(self.group_length):
                        points_index_start = (i + j * self.sequence_stride) * self.num_proxy_points
                        points_index_end = points_index_start + self.num_proxy_points
                        groups.append(src[:, points_index_start:points_index_end])

                    groups = torch.cat(groups, -1)
                    src_groups.append(groups)

            else:
                raise NotImplementedError

            src_merge = torch.cat(src_groups, 1)
            src = self.fusion_norm(src[:, :self.num_groups * self.num_proxy_points], self.fusion_all_group(src_merge))
            src = [torch.cat([token_list[i], src[:, i * self.num_proxy_points:(i + 1) * self.num_proxy_points]], dim=1)
                   for i in range(self.num_groups)]
            src = torch.cat(src, dim=0)

        else:
            token_list = [self.token[i:(i + 1)].repeat(BS, 1, 1) for i in range(self.num_groups)]
            src = [torch.cat([token_list[i], src[:, i * self.num_proxy_points:(i + 1) * self.num_proxy_points]], dim=1)
                   for i in range(self.num_groups)]
            src = torch.cat(src, dim=0)

        src = src.permute(1, 0, 2)
        memory, tokens = self.encoder(src, pos=pos)

        memory = torch.cat(memory[0:1].chunk(4, dim=1), 0)
        return memory, tokens


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None, config=None):
        super().__init__()
        self.layers = nn.ModuleList(encoder_layer)
        self.num_layers = num_layers
        self.norm = norm
        self.config = config

    def forward(self, src,
                pos: Optional[Tensor] = None):

        token_list = []
        output = src
        for layer in self.layers:
            output, tokens = layer(output, pos=pos)
            token_list.append(tokens)
        if self.norm is not None:
            output = self.norm(output)

        return output, token_list
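
# --- Editorial note, not part of the original patch ---
# A small sketch of the frame-grouping arithmetic used in Transformer.forward when
# num_frames == 16 and sequence_stride == 4: group i gathers the proxy points of
# frames i, i + 4, i + 8, i + 12, so the four groups interleave the 16-frame trajectory.
def _demo_group_indices(num_groups=4, group_length=4, sequence_stride=4):
    groups = []
    for i in range(num_groups):
        groups.append([i + j * sequence_stride for j in range(group_length)])
    # e.g. groups == [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
    return groups
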

class TransformerEncoderLayer(nn.Module):
    count = 0

    def __init__(self, config, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, num_points=None, num_groups=None):
        super().__init__()
        TransformerEncoderLayer.count += 1
        self.layer_count = TransformerEncoderLayer.count
        self.config = config
        self.num_point = num_points
        self.num_groups = num_groups
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if self.layer_count <= self.config.enc_layers - 1:
            self.cross_attn_layers = nn.ModuleList()
            for _ in range(self.num_groups):
                self.cross_attn_layers.append(nn.MultiheadAttention(d_model, nhead, dropout=dropout))

            self.ffn = FFN(d_model, dim_feedforward)
            self.fusion_all_groups = MLP(input_dim=d_model * 4, hidden_dim=d_model, output_dim=d_model, num_layers=4)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self.mlp_mixer_3d = SpatialMixerBlock(self.config.use_mlp_mixer.hidden_dim,
                                              self.config.use_mlp_mixer.get('grid_size', 4),
                                              self.config.hidden_dim, self.config.use_mlp_mixer)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     pos: Optional[Tensor] = None):

        src_intra_group_fusion = self.mlp_mixer_3d(src[1:])
        src = torch.cat([src[:1], src_intra_group_fusion], 0)

        token = src[:1]

        if not pos is None:
            key = self.with_pos_embed(src_intra_group_fusion, pos[1:])
        else:
            key = src_intra_group_fusion

        src_summary = self.self_attn(token, key, value=src_intra_group_fusion)[0]
        token = token + self.dropout1(src_summary)
        token = self.norm1(token)
        src_summary = self.linear2(self.dropout(self.activation(self.linear1(token))))
        token = token + self.dropout2(src_summary)
        token = self.norm2(token)
        src = torch.cat([token, src[1:]], 0)

        if self.layer_count <= self.config.enc_layers - 1:

            src_all_groups = src[1:].view((src.shape[0] - 1) * 4, -1, src.shape[-1])
            src_groups_list = src_all_groups.chunk(self.num_groups, 0)

            src_all_groups = torch.cat(src_groups_list, -1)
            src_all_groups_fusion = self.fusion_all_groups(src_all_groups)

            key = self.with_pos_embed(src_all_groups_fusion, pos[1:])
            query_list = [self.with_pos_embed(query, pos[1:]) for query in src_groups_list]

            inter_group_fusion_list = []
            for i in range(self.num_groups):
                inter_group_fusion = self.cross_attn_layers[i](query_list[i], key, value=src_all_groups_fusion)[0]
                inter_group_fusion = self.ffn(src_groups_list[i], inter_group_fusion)
                inter_group_fusion_list.append(inter_group_fusion)

            src_inter_group_fusion = torch.cat(inter_group_fusion_list, 1)

            src = torch.cat([src[:1], src_inter_group_fusion], 0)

        return src, torch.cat(src[:1].chunk(4, 1), 0)

    def forward_pre(self, src,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                pos: Optional[Tensor] = None):

        if self.normalize_before:
            return self.forward_pre(src, pos)
        return self.forward_post(src, pos)


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


class FFN(nn.Module):
    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1, dout=None,
                 activation="relu", normalize_before=False):
        super().__init__()

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def forward(self, tgt, tgt_input):
        tgt = tgt + self.dropout2(tgt_input)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt
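
# --- Editorial example, not part of the original patch ---
# A minimal sketch of how FFN is used throughout this file: it fuses an update
# tensor into a residual stream (add + LayerNorm, then a feed-forward sub-layer).
# The shapes below are illustrative; the two inputs only need matching d_model.
def _demo_ffn_fusion():
    d_model = 256
    fusion = FFN(d_model, dim_feedforward=512).eval()  # eval() disables dropout for a deterministic check
    residual = torch.randn(10, 2, d_model)             # (sequence, batch, d_model)
    update = torch.randn(10, 2, d_model)
    fused = fusion(residual, update)
    assert fused.shape == residual.shape
    return fused
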

def build_transformer(args):
    return Transformer(
        config=args,
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        normalize_before=args.pre_norm,
        num_lidar_points=args.num_lidar_points,
        num_proxy_points=args.num_proxy_points,
        num_frames=args.num_frames,
        sequence_stride=args.get('sequence_stride', 1),
        num_groups=args.num_groups,
    )
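
# --- Editorial example, not part of the original patch ---
# A hedged sketch of how build_transformer might be called. MPPNet normally reads
# these values from its YAML model config; the EasyDict below and every number in
# it are illustrative assumptions, not the reference configuration. Note that
# TransformerEncoderLayer keeps a class-level instance counter, so building an
# extra stack like this affects which layers of later builds get cross-attention.
def _demo_build_transformer():
    from easydict import EasyDict  # assumes easydict is available, as in OpenPCDet configs
    cfg = EasyDict(dict(
        hidden_dim=256, dropout=0.1, nheads=4, dim_feedforward=512,
        enc_layers=3, pre_norm=False,
        num_lidar_points=128, num_proxy_points=64, num_frames=16,
        num_groups=4, sequence_stride=4,
        use_mlp_mixer=EasyDict(dict(hidden_dim=16, grid_size=4)),
    ))
    return build_transformer(cfg)
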