This commit is contained in:
2025-09-21 20:18:41 +08:00
parent df3e845e55
commit aa97613265

View File

@@ -0,0 +1,33 @@
import torch
from torch import nn
class ConvFuser(nn.Module):
def __init__(self,model_cfg) -> None:
super().__init__()
self.model_cfg = model_cfg
in_channel = self.model_cfg.IN_CHANNEL
out_channel = self.model_cfg.OUT_CHANNEL
self.conv = nn.Sequential(
nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(True)
)
def forward(self,batch_dict):
"""
Args:
batch_dict:
spatial_features_img (tensor): Bev features from image modality
spatial_features (tensor): Bev features from lidar modality
Returns:
batch_dict:
spatial_features (tensor): Bev features after muli-modal fusion
"""
img_bev = batch_dict['spatial_features_img']
lidar_bev = batch_dict['spatial_features']
cat_bev = torch.cat([img_bev,lidar_bev],dim=1)
mm_bev = self.conv(cat_bev)
batch_dict['spatial_features'] = mm_bev
return batch_dict