Add File
This commit is contained in:
38
pcdet/models/backbones_2d/map_to_bev/conv2d_collapse.py
Normal file
38
pcdet/models/backbones_2d/map_to_bev/conv2d_collapse.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from pcdet.models.model_utils.basic_block_2d import BasicBlock2D
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2DCollapse(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, model_cfg, grid_size):
|
||||||
|
"""
|
||||||
|
Initializes 2D convolution collapse module
|
||||||
|
Args:
|
||||||
|
model_cfg: EasyDict, Model configuration
|
||||||
|
grid_size: (X, Y, Z) Voxel grid size
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.model_cfg = model_cfg
|
||||||
|
self.num_heights = grid_size[-1]
|
||||||
|
self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
|
||||||
|
self.block = BasicBlock2D(in_channels=self.num_bev_features * self.num_heights,
|
||||||
|
out_channels=self.num_bev_features,
|
||||||
|
**self.model_cfg.ARGS)
|
||||||
|
|
||||||
|
def forward(self, batch_dict):
|
||||||
|
"""
|
||||||
|
Collapses voxel features to BEV via concatenation and channel reduction
|
||||||
|
Args:
|
||||||
|
batch_dict:
|
||||||
|
voxel_features: (B, C, Z, Y, X), Voxel feature representation
|
||||||
|
Returns:
|
||||||
|
batch_dict:
|
||||||
|
spatial_features: (B, C, Y, X), BEV feature representation
|
||||||
|
"""
|
||||||
|
voxel_features = batch_dict["voxel_features"]
|
||||||
|
bev_features = voxel_features.flatten(start_dim=1, end_dim=2) # (B, C, Z, Y, X) -> (B, C*Z, Y, X)
|
||||||
|
bev_features = self.block(bev_features) # (B, C*Z, Y, X) -> (B, C, Y, X)
|
||||||
|
batch_dict["spatial_features"] = bev_features
|
||||||
|
return batch_dict
|
||||||
Reference in New Issue
Block a user