diff --git a/pcdet/models/backbones_3d/vfe/image_vfe.py b/pcdet/models/backbones_3d/vfe/image_vfe.py new file mode 100644 index 0000000..f626986 --- /dev/null +++ b/pcdet/models/backbones_3d/vfe/image_vfe.py @@ -0,0 +1,85 @@ +import torch + +from .vfe_template import VFETemplate +from .image_vfe_modules import ffn, f2v + + +class ImageVFE(VFETemplate): + def __init__(self, model_cfg, grid_size, point_cloud_range, depth_downsample_factor, **kwargs): + super().__init__(model_cfg=model_cfg) + self.grid_size = grid_size + self.pc_range = point_cloud_range + self.downsample_factor = depth_downsample_factor + self.module_topology = [ + 'ffn', 'f2v' + ] + self.build_modules() + + def build_modules(self): + """ + Builds modules + """ + for module_name in self.module_topology: + module = getattr(self, 'build_%s' % module_name)() + self.add_module(module_name, module) + + def build_ffn(self): + """ + Builds frustum feature network + Returns: + ffn_module: nn.Module, Frustum feature network + """ + ffn_module = ffn.__all__[self.model_cfg.FFN.NAME]( + model_cfg=self.model_cfg.FFN, + downsample_factor=self.downsample_factor + ) + self.disc_cfg = ffn_module.disc_cfg + return ffn_module + + def build_f2v(self): + """ + Builds frustum to voxel transformation + Returns: + f2v_module: nn.Module, Frustum to voxel transformation + """ + f2v_module = f2v.__all__[self.model_cfg.F2V.NAME]( + model_cfg=self.model_cfg.F2V, + grid_size=self.grid_size, + pc_range=self.pc_range, + disc_cfg=self.disc_cfg + ) + return f2v_module + + def get_output_feature_dim(self): + """ + Gets number of output channels + Returns: + out_feature_dim: int, Number of output channels + """ + out_feature_dim = self.ffn.get_output_feature_dim() + return out_feature_dim + + def forward(self, batch_dict, **kwargs): + """ + Args: + batch_dict: + images: (N, 3, H_in, W_in), Input images + **kwargs: + Returns: + batch_dict: + voxel_features: (B, C, Z, Y, X), Image voxel features + """ + batch_dict = self.ffn(batch_dict) + batch_dict = self.f2v(batch_dict) + return batch_dict + + def get_loss(self): + """ + Gets DDN loss + Returns: + loss: (1), Depth distribution network loss + tb_dict: dict[float], All losses to log in tensorboard + """ + + loss, tb_dict = self.ffn.get_loss() + return loss, tb_dict