diff --git a/pcdet/models/model_utils/transfusion_utils.py b/pcdet/models/model_utils/transfusion_utils.py
new file mode 100644
index 0000000..677827c
--- /dev/null
+++ b/pcdet/models/model_utils/transfusion_utils.py
@@ -0,0 +1,102 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+def clip_sigmoid(x, eps=1e-4):
+    y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
+    return y
+
+
+class PositionEmbeddingLearned(nn.Module):
+    """
+    Absolute pos embedding, learned.
+    """
+
+    def __init__(self, input_channel, num_pos_feats=288):
+        super().__init__()
+        self.position_embedding_head = nn.Sequential(
+            nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
+            nn.BatchNorm1d(num_pos_feats),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))
+
+    def forward(self, xyz):
+        xyz = xyz.transpose(1, 2).contiguous()
+        position_embedding = self.position_embedding_head(xyz)
+        return position_embedding
+
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
+                 self_posembed=None, cross_posembed=None, cross_only=False):
+        super().__init__()
+        self.cross_only = cross_only
+        if not self.cross_only:
+            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        def _get_activation_fn(activation):
+            """Return an activation function given a string"""
+            if activation == "relu":
+                return F.relu
+            if activation == "gelu":
+                return F.gelu
+            if activation == "glu":
+                return F.glu
+            raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+        self.activation = _get_activation_fn(activation)
+
+        self.self_posembed = self_posembed
+        self.cross_posembed = cross_posembed
+
+    def with_pos_embed(self, tensor, pos_embed):
+        return tensor if pos_embed is None else tensor + pos_embed
+
+    def forward(self, query, key, query_pos, key_pos, key_padding_mask=None, attn_mask=None):
+        # NxCxP to PxNxC
+        if self.self_posembed is not None:
+            query_pos_embed = self.self_posembed(query_pos).permute(2, 0, 1)
+        else:
+            query_pos_embed = None
+        if self.cross_posembed is not None:
+            key_pos_embed = self.cross_posembed(key_pos).permute(2, 0, 1)
+        else:
+            key_pos_embed = None
+
+        query = query.permute(2, 0, 1)
+        key = key.permute(2, 0, 1)
+
+        if not self.cross_only:
+            q = k = v = self.with_pos_embed(query, query_pos_embed)
+            query2 = self.self_attn(q, k, value=v)[0]
+            query = query + self.dropout1(query2)
+            query = self.norm1(query)
+
+        query2 = self.multihead_attn(query=self.with_pos_embed(query, query_pos_embed),
+                                     key=self.with_pos_embed(key, key_pos_embed),
+                                     value=self.with_pos_embed(key, key_pos_embed),
+                                     key_padding_mask=key_padding_mask, attn_mask=attn_mask)[0]
+
+        query = query + self.dropout2(query2)
+        query = self.norm2(query)
+
+        query2 = self.linear2(self.dropout(self.activation(self.linear1(query))))
+        query = query + self.dropout3(query2)
+        query = self.norm3(query)
+
+        # PxNxC back to NxCxP
+        query = query.permute(1, 2, 0)
+        return query
+
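For reference, a minimal usage sketch of the new TransformerDecoderLayer, assuming the N x C x P tensor layout that forward expects; the 2-channel (x, y) position inputs and all concrete sizes are illustrative assumptions, not values taken from the patch.

import torch
from pcdet.models.model_utils.transfusion_utils import PositionEmbeddingLearned, TransformerDecoderLayer

d_model, nhead = 128, 8
num_proposals, num_keys, batch_size = 200, 1024, 2

layer = TransformerDecoderLayer(
    d_model, nhead, dim_feedforward=256, dropout=0.1,
    # assumption: query/key positions are 2-channel BEV (x, y) coordinates
    self_posembed=PositionEmbeddingLearned(2, d_model),
    cross_posembed=PositionEmbeddingLearned(2, d_model),
)

query = torch.randn(batch_size, d_model, num_proposals)  # N x C x P (proposal features)
key = torch.randn(batch_size, d_model, num_keys)         # N x C x P_key (flattened BEV features)
query_pos = torch.rand(batch_size, num_proposals, 2)     # N x P x 2
key_pos = torch.rand(batch_size, num_keys, 2)            # N x P_key x 2

out = layer(query, key, query_pos, key_pos)
print(out.shape)  # torch.Size([2, 128, 200]), same N x C x P layout as the input query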