Add File
This commit is contained in:
33
pcdet/models/backbones_2d/fuser/convfuser.py
Normal file
33
pcdet/models/backbones_2d/fuser/convfuser.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class ConvFuser(nn.Module):
|
||||||
|
def __init__(self,model_cfg) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.model_cfg = model_cfg
|
||||||
|
in_channel = self.model_cfg.IN_CHANNEL
|
||||||
|
out_channel = self.model_cfg.OUT_CHANNEL
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(out_channel),
|
||||||
|
nn.ReLU(True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self,batch_dict):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
batch_dict:
|
||||||
|
spatial_features_img (tensor): Bev features from image modality
|
||||||
|
spatial_features (tensor): Bev features from lidar modality
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
batch_dict:
|
||||||
|
spatial_features (tensor): Bev features after muli-modal fusion
|
||||||
|
"""
|
||||||
|
img_bev = batch_dict['spatial_features_img']
|
||||||
|
lidar_bev = batch_dict['spatial_features']
|
||||||
|
cat_bev = torch.cat([img_bev,lidar_bev],dim=1)
|
||||||
|
mm_bev = self.conv(cat_bev)
|
||||||
|
batch_dict['spatial_features'] = mm_bev
|
||||||
|
return batch_dict
|
||||||
Reference in New Issue
Block a user