Add File
This commit is contained in:
33
pcdet/models/backbones_2d/fuser/convfuser.py
Normal file
33
pcdet/models/backbones_2d/fuser/convfuser.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ConvFuser(nn.Module):
|
||||
def __init__(self,model_cfg) -> None:
|
||||
super().__init__()
|
||||
self.model_cfg = model_cfg
|
||||
in_channel = self.model_cfg.IN_CHANNEL
|
||||
out_channel = self.model_cfg.OUT_CHANNEL
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channel),
|
||||
nn.ReLU(True)
|
||||
)
|
||||
|
||||
def forward(self,batch_dict):
|
||||
"""
|
||||
Args:
|
||||
batch_dict:
|
||||
spatial_features_img (tensor): Bev features from image modality
|
||||
spatial_features (tensor): Bev features from lidar modality
|
||||
|
||||
Returns:
|
||||
batch_dict:
|
||||
spatial_features (tensor): Bev features after muli-modal fusion
|
||||
"""
|
||||
img_bev = batch_dict['spatial_features_img']
|
||||
lidar_bev = batch_dict['spatial_features']
|
||||
cat_bev = torch.cat([img_bev,lidar_bev],dim=1)
|
||||
mm_bev = self.conv(cat_bev)
|
||||
batch_dict['spatial_features'] = mm_bev
|
||||
return batch_dict
|
||||
Reference in New Issue
Block a user