This commit is contained in:
2025-09-21 20:18:49 +08:00
parent ccd5e09366
commit 57375667fc

View File

@@ -0,0 +1,420 @@
from os import getgrouplist
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
from typing import Optional, List
from torch import Tensor
from torch.nn.init import xavier_uniform_, zeros_, kaiming_normal_
class PointNetfeat(nn.Module):
def __init__(self, input_dim, x=1,outchannel=512):
super(PointNetfeat, self).__init__()
if outchannel==256:
self.output_channel = 256
else:
self.output_channel = 512 * x
self.conv1 = torch.nn.Conv1d(input_dim, 64 * x, 1)
self.conv2 = torch.nn.Conv1d(64 * x, 128 * x, 1)
self.conv3 = torch.nn.Conv1d(128 * x, 256 * x, 1)
self.conv4 = torch.nn.Conv1d(256 * x, self.output_channel, 1)
self.bn1 = nn.BatchNorm1d(64 * x)
self.bn2 = nn.BatchNorm1d(128 * x)
self.bn3 = nn.BatchNorm1d(256 * x)
self.bn4 = nn.BatchNorm1d(self.output_channel)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x_ori = self.bn4(self.conv4(x))
x = torch.max(x_ori, 2, keepdim=True)[0]
x = x.view(-1, self.output_channel)
return x, x_ori
class PointNet(nn.Module):
def __init__(self, input_dim, joint_feat=False,model_cfg=None):
super(PointNet, self).__init__()
self.joint_feat = joint_feat
channels = model_cfg.TRANS_INPUT
times=1
self.feat = PointNetfeat(input_dim, 1)
self.fc1 = nn.Linear(512, 256 )
self.fc2 = nn.Linear(256, channels)
self.pre_bn = nn.BatchNorm1d(input_dim)
self.bn1 = nn.BatchNorm1d(256)
self.bn2 = nn.BatchNorm1d(channels)
self.relu = nn.ReLU()
self.fc_s1 = nn.Linear(channels*times, 256)
self.fc_s2 = nn.Linear(256, 3, bias=False)
self.fc_ce1 = nn.Linear(channels*times, 256)
self.fc_ce2 = nn.Linear(256, 3, bias=False)
self.fc_hr1 = nn.Linear(channels*times, 256)
self.fc_hr2 = nn.Linear(256, 1, bias=False)
def forward(self, x, feat=None):
if self.joint_feat:
if len(feat.shape) > 2:
feat = torch.max(feat, 2, keepdim=True)[0]
x = feat.view(-1, self.output_channel)
x = F.relu(self.bn1(self.fc1(x)))
feat = F.relu(self.bn2(self.fc2(x)))
else:
feat = feat
feat_traj = None
else:
x, feat_traj = self.feat(self.pre_bn(x))
x = F.relu(self.bn1(self.fc1(x)))
feat = F.relu(self.bn2(self.fc2(x)))
x = F.relu(self.fc_ce1(feat))
centers = self.fc_ce2(x)
x = F.relu(self.fc_s1(feat))
sizes = self.fc_s2(x)
x = F.relu(self.fc_hr1(feat))
headings = self.fc_hr2(x)
return torch.cat([centers, sizes, headings],-1),feat,feat_traj
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
kaiming_normal_(m.weight.data)
if m.bias is not None:
zeros_(m.bias)
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
class SpatialMixerBlock(nn.Module):
def __init__(self,hidden_dim,grid_size,channels,config=None,dropout=0.0):
super().__init__()
self.mixer_x = MLP(input_dim = grid_size, hidden_dim = hidden_dim, output_dim = grid_size, num_layers = 3)
self.mixer_y = MLP(input_dim = grid_size, hidden_dim = hidden_dim, output_dim = grid_size, num_layers = 3)
self.mixer_z = MLP(input_dim = grid_size, hidden_dim = hidden_dim, output_dim = grid_size, num_layers = 3)
self.norm_x = nn.LayerNorm(channels)
self.norm_y = nn.LayerNorm(channels)
self.norm_z = nn.LayerNorm(channels)
self.norm_channel = nn.LayerNorm(channels)
self.ffn = nn.Sequential(
nn.Linear(channels, 2*channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(2*channels, channels),
)
self.config = config
self.grid_size = grid_size
def forward(self, src):
src_3d = src.permute(1,2,0).contiguous().view(src.shape[1],src.shape[2],
self.grid_size,self.grid_size,self.grid_size)
src_3d = src_3d.permute(0,1,4,3,2).contiguous()
mixed_x = self.mixer_x(src_3d)
mixed_x = src_3d + mixed_x
mixed_x = self.norm_x(mixed_x.permute(0,2,3,4,1)).permute(0,4,1,2,3).contiguous()
mixed_y = self.mixer_y(mixed_x.permute(0,1,2,4,3)).permute(0,1,2,4,3).contiguous()
mixed_y = mixed_x + mixed_y
mixed_y = self.norm_y(mixed_y.permute(0,2,3,4,1)).permute(0,4,1,2,3).contiguous()
mixed_z = self.mixer_z(mixed_y.permute(0,1,4,3,2)).permute(0,1,4,3,2).contiguous()
mixed_z = mixed_y + mixed_z
mixed_z = self.norm_z(mixed_z.permute(0,2,3,4,1)).permute(0,4,1,2,3).contiguous()
src_mixer = mixed_z.view(src.shape[1],src.shape[2],-1).permute(2,0,1)
src_mixer = src_mixer + self.ffn(src_mixer)
src_mixer = self.norm_channel(src_mixer)
return src_mixer
class Transformer(nn.Module):
def __init__(self, config, d_model=512, nhead=8, num_encoder_layers=6,
dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False,
num_lidar_points=None,num_proxy_points=None, share_head=True,num_groups=None,
sequence_stride=None,num_frames=None):
super().__init__()
self.config = config
self.share_head = share_head
self.num_frames = num_frames
self.nhead = nhead
self.sequence_stride = sequence_stride
self.num_groups = num_groups
self.num_proxy_points = num_proxy_points
self.num_lidar_points = num_lidar_points
self.d_model = d_model
self.nhead = nhead
encoder_layer = [TransformerEncoderLayer(self.config, d_model, nhead, dim_feedforward,dropout, activation,
normalize_before, num_lidar_points,num_groups=num_groups) for i in range(num_encoder_layers)]
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm,self.config)
self.token = nn.Parameter(torch.zeros(self.num_groups, 1, d_model))
if self.num_frames >4:
self.group_length = self.num_frames // self.num_groups
self.fusion_all_group = MLP(input_dim = self.config.hidden_dim*self.group_length,
hidden_dim = self.config.hidden_dim, output_dim = self.config.hidden_dim, num_layers = 4)
self.fusion_norm = FFN(d_model, dim_feedforward)
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src, pos=None):
BS, N, C = src.shape
if not pos is None:
pos = pos.permute(1, 0, 2)
if self.num_frames == 16:
token_list = [self.token[i:(i+1)].repeat(BS,1,1) for i in range(self.num_groups)]
if self.sequence_stride ==1:
src_groups = src.view(src.shape[0],src.shape[1]//self.num_groups ,-1).chunk(4,dim=1)
elif self.sequence_stride ==4:
src_groups = []
for i in range(self.num_groups):
groups = []
for j in range(self.group_length):
points_index_start = (i+j*self.sequence_stride)*self.num_proxy_points
points_index_end = points_index_start + self.num_proxy_points
groups.append(src[:,points_index_start:points_index_end])
groups = torch.cat(groups,-1)
src_groups.append(groups)
else:
raise NotImplementedError
src_merge = torch.cat(src_groups,1)
src = self.fusion_norm(src[:,:self.num_groups*self.num_proxy_points],self.fusion_all_group(src_merge))
src = [torch.cat([token_list[i],src[:,i*self.num_proxy_points:(i+1)*self.num_proxy_points]],dim=1) for i in range(self.num_groups)]
src = torch.cat(src,dim=0)
else:
token_list = [self.token[i:(i+1)].repeat(BS,1,1) for i in range(self.num_groups)]
src = [torch.cat([token_list[i],src[:,i*self.num_proxy_points:(i+1)*self.num_proxy_points]],dim=1) for i in range(self.num_groups)]
src = torch.cat(src,dim=0)
src = src.permute(1, 0, 2)
memory,tokens = self.encoder(src,pos=pos)
memory = torch.cat(memory[0:1].chunk(4,dim=1),0)
return memory, tokens
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None,config=None):
super().__init__()
self.layers = nn.ModuleList(encoder_layer)
self.num_layers = num_layers
self.norm = norm
self.config = config
def forward(self, src,
pos: Optional[Tensor] = None):
token_list = []
output = src
for layer in self.layers:
output,tokens = layer(output,pos=pos)
token_list.append(tokens)
if self.norm is not None:
output = self.norm(output)
return output,token_list
class TransformerEncoderLayer(nn.Module):
count = 0
def __init__(self, config, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False,num_points=None,num_groups=None):
super().__init__()
TransformerEncoderLayer.count += 1
self.layer_count = TransformerEncoderLayer.count
self.config = config
self.num_point = num_points
self.num_groups= num_groups
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
if self.layer_count <= self.config.enc_layers-1:
self.cross_attn_layers = nn.ModuleList()
for _ in range(self.num_groups):
self.cross_attn_layers.append(nn.MultiheadAttention(d_model, nhead, dropout=dropout))
self.ffn = FFN(d_model, dim_feedforward)
self.fusion_all_groups = MLP(input_dim = d_model*4, hidden_dim = d_model, output_dim = d_model, num_layers = 4)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.mlp_mixer_3d = SpatialMixerBlock(self.config.use_mlp_mixer.hidden_dim,self.config.use_mlp_mixer.get('grid_size', 4),self.config.hidden_dim, self.config.use_mlp_mixer)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
src,
pos: Optional[Tensor] = None):
src_intra_group_fusion = self.mlp_mixer_3d(src[1:])
src = torch.cat([src[:1],src_intra_group_fusion],0)
token = src[:1]
if not pos is None:
key = self.with_pos_embed(src_intra_group_fusion, pos[1:])
else:
key = src_intra_group_fusion
src_summary = self.self_attn(token, key, value=src_intra_group_fusion)[0]
token = token + self.dropout1(src_summary)
token = self.norm1(token)
src_summary = self.linear2(self.dropout(self.activation(self.linear1(token))))
token = token + self.dropout2(src_summary)
token = self.norm2(token)
src = torch.cat([token,src[1:]],0)
if self.layer_count <= self.config.enc_layers-1:
src_all_groups = src[1:].view((src.shape[0]-1)*4,-1,src.shape[-1])
src_groups_list = src_all_groups.chunk(self.num_groups,0)
src_all_groups = torch.cat(src_groups_list,-1)
src_all_groups_fusion = self.fusion_all_groups(src_all_groups)
key = self.with_pos_embed(src_all_groups_fusion, pos[1:])
query_list = [self.with_pos_embed(query, pos[1:]) for query in src_groups_list]
inter_group_fusion_list = []
for i in range(self.num_groups):
inter_group_fusion = self.cross_attn_layers[i](query_list[i], key, value=src_all_groups_fusion)[0]
inter_group_fusion = self.ffn(src_groups_list[i],inter_group_fusion)
inter_group_fusion_list.append(inter_group_fusion)
src_inter_group_fusion = torch.cat(inter_group_fusion_list,1)
src = torch.cat([src[:1],src_inter_group_fusion],0)
return src, torch.cat(src[:1].chunk(4,1),0)
def forward_pre(self, src,
pos: Optional[Tensor] = None):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(self, src,
pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(src, pos)
return self.forward_post(src, pos)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class FFN(nn.Module):
def __init__(self, d_model, dim_feedforward=2048, dropout=0.1,dout=None,
activation="relu", normalize_before=False):
super().__init__()
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def forward(self, tgt,tgt_input):
tgt = tgt + self.dropout2(tgt_input)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def build_transformer(args):
return Transformer(
config = args,
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
normalize_before=args.pre_norm,
num_lidar_points = args.num_lidar_points,
num_proxy_points = args.num_proxy_points,
num_frames = args.num_frames,
sequence_stride = args.get('sequence_stride',1),
num_groups=args.num_groups,
)