This commit is contained in:
2025-09-21 20:18:57 +08:00
parent c9e2842324
commit 77e95b3ac1

View File

@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
from .frustum_grid_generator import FrustumGridGenerator
from .sampler import Sampler
class FrustumToVoxel(nn.Module):
def __init__(self, model_cfg, grid_size, pc_range, disc_cfg):
"""
Initializes module to transform frustum features to voxel features via 3D transformation and sampling
Args:
model_cfg: EasyDict, Module configuration
grid_size: [X, Y, Z], Voxel grid size
pc_range: [x_min, y_min, z_min, x_max, y_max, z_max], Voxelization point cloud range (m)
disc_cfg: EasyDict, Depth discretiziation configuration
"""
super().__init__()
self.model_cfg = model_cfg
self.grid_size = grid_size
self.pc_range = pc_range
self.disc_cfg = disc_cfg
self.grid_generator = FrustumGridGenerator(grid_size=grid_size,
pc_range=pc_range,
disc_cfg=disc_cfg)
self.sampler = Sampler(**model_cfg.SAMPLER)
def forward(self, batch_dict):
"""
Generates voxel features via 3D transformation and sampling
Args:
batch_dict:
frustum_features: (B, C, D, H_image, W_image), Image frustum features
lidar_to_cam: (B, 4, 4), LiDAR to camera frame transformation
cam_to_img: (B, 3, 4), Camera projection matrix
image_shape: (B, 2), Image shape [H, W]
Returns:
batch_dict:
voxel_features: (B, C, Z, Y, X), Image voxel features
"""
# Generate sampling grid for frustum volume
grid = self.grid_generator(lidar_to_cam=batch_dict["trans_lidar_to_cam"],
cam_to_img=batch_dict["trans_cam_to_img"],
image_shape=batch_dict["image_shape"]) # (B, X, Y, Z, 3)
# Sample frustum volume to generate voxel volume
voxel_features = self.sampler(input_features=batch_dict["frustum_features"],
grid=grid) # (B, C, X, Y, Z)
# (B, C, X, Y, Z) -> (B, C, Z, Y, X)
voxel_features = voxel_features.permute(0, 1, 4, 3, 2)
batch_dict["voxel_features"] = voxel_features
return batch_dict