39 lines
1.4 KiB
Python
39 lines
1.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
from pcdet.models.model_utils.basic_block_2d import BasicBlock2D
|
|
|
|
|
|
class Conv2DCollapse(nn.Module):
|
|
|
|
def __init__(self, model_cfg, grid_size):
|
|
"""
|
|
Initializes 2D convolution collapse module
|
|
Args:
|
|
model_cfg: EasyDict, Model configuration
|
|
grid_size: (X, Y, Z) Voxel grid size
|
|
"""
|
|
super().__init__()
|
|
self.model_cfg = model_cfg
|
|
self.num_heights = grid_size[-1]
|
|
self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
|
|
self.block = BasicBlock2D(in_channels=self.num_bev_features * self.num_heights,
|
|
out_channels=self.num_bev_features,
|
|
**self.model_cfg.ARGS)
|
|
|
|
def forward(self, batch_dict):
|
|
"""
|
|
Collapses voxel features to BEV via concatenation and channel reduction
|
|
Args:
|
|
batch_dict:
|
|
voxel_features: (B, C, Z, Y, X), Voxel feature representation
|
|
Returns:
|
|
batch_dict:
|
|
spatial_features: (B, C, Y, X), BEV feature representation
|
|
"""
|
|
voxel_features = batch_dict["voxel_features"]
|
|
bev_features = voxel_features.flatten(start_dim=1, end_dim=2) # (B, C, Z, Y, X) -> (B, C*Z, Y, X)
|
|
bev_features = self.block(bev_features) # (B, C*Z, Y, X) -> (B, C, Y, X)
|
|
batch_dict["spatial_features"] = bev_features
|
|
return batch_dict
|