"""Scatter per-pillar feature vectors onto a dense grid.

File path (taken from the patch header this module was extracted from):
pcdet/models/backbones_2d/map_to_bev/pointpillar_scatter.py
"""
import torch
import torch.nn as nn


def _resolve_batch_size(batch_dict, coords):
    """Return how many samples the dense output must contain.

    The count inferred from the largest batch index in ``coords[:, 0]`` is
    always respected.  If ``batch_dict`` additionally carries a
    ``'batch_size'`` entry, the larger of the two values is used so that
    trailing samples that own no pillars still get an (all-zero) grid.

    Args:
        batch_dict: mapping passed to ``forward``; ``'batch_size'`` is optional.
        coords: 2-D tensor whose column 0 holds the per-pillar batch index.

    Returns:
        int: number of samples (``0`` when there are no pillars and no
        declared batch size).
    """
    # ``max()`` raises on an empty tensor, so guard the zero-pillar case.
    if coords.shape[0] > 0:
        inferred = coords[:, 0].max().int().item() + 1
    else:
        inferred = 0

    declared = batch_dict.get('batch_size')
    if declared is None:
        return inferred
    return max(int(declared), inferred)


def _scatter_to_dense(pillar_features, coords, flat_indices,
                      num_channels, num_cells, batch_size):
    """Write every pillar's feature vector into its flattened grid cell.

    Args:
        pillar_features: ``(num_pillars, num_channels)`` tensor.
        coords: 2-D tensor; only column 0 (batch index) is read here.
        flat_indices: ``(num_pillars,)`` long tensor with the flattened cell
            index of each pillar inside its own sample.
        num_channels: feature channels per grid cell.
        num_cells: number of cells in one flattened grid.
        batch_size: number of samples to produce.

    Returns:
        Tensor of shape ``(batch_size, num_channels, num_cells)`` with the
        dtype and device of ``pillar_features``; cells without a pillar are 0.
    """
    per_sample = []
    for sample_idx in range(batch_size):
        canvas = torch.zeros(
            num_channels,
            num_cells,
            dtype=pillar_features.dtype,
            device=pillar_features.device)

        # Rows whose batch index matches no sample (e.g. negative values)
        # are never selected by any mask and are therefore ignored.
        sample_mask = coords[:, 0] == sample_idx
        canvas[:, flat_indices[sample_mask]] = pillar_features[sample_mask, :].t()
        per_sample.append(canvas)

    if not per_sample:
        # ``torch.stack`` rejects an empty list; return an empty batch instead.
        return pillar_features.new_zeros((0, num_channels, num_cells))
    return torch.stack(per_sample, 0)


class PointPillarScatter(nn.Module):
    """Scatter pillar features onto a single-layer (``nz == 1``) grid."""

    def __init__(self, model_cfg, grid_size, **kwargs):
        """
        Args:
            model_cfg: config object exposing ``NUM_BEV_FEATURES``.
            grid_size: iterable unpacked as ``(nx, ny, nz)``; ``nz`` must be 1.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.model_cfg = model_cfg
        self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
        self.nx, self.ny, self.nz = grid_size
        assert self.nz == 1

    def forward(self, batch_dict, **kwargs):
        """Build ``batch_dict['spatial_features']`` from the pillar features.

        Args:
            batch_dict: mapping with
                ``'pillar_features'``: ``(num_pillars, NUM_BEV_FEATURES)`` tensor,
                ``'voxel_coords'``: ``(num_pillars, 4)`` tensor whose column 0
                is the batch index and whose columns 1..3 are combined as
                ``c1 + c2 * nx + c3`` into a flat cell index (given the
                final ``(ny, nx)`` view, c2 acts as the row and c3 as the
                column; c1 is presumably always 0 because ``nz == 1``),
                ``'batch_size'`` (optional): minimum number of samples in
                the output.

        Returns:
            The same ``batch_dict`` with ``'spatial_features'`` set to a tensor
            of shape ``(batch_size, NUM_BEV_FEATURES * nz, ny, nx)``.
        """
        pillar_features = batch_dict['pillar_features']
        coords = batch_dict['voxel_coords']

        flat_indices = coords[:, 1] + coords[:, 2] * self.nx + coords[:, 3]
        flat_indices = flat_indices.type(torch.long)

        batch_size = _resolve_batch_size(batch_dict, coords)
        dense = _scatter_to_dense(
            pillar_features,
            coords,
            flat_indices,
            self.num_bev_features,
            self.nz * self.nx * self.ny,
            batch_size)

        batch_dict['spatial_features'] = dense.view(
            batch_size, self.num_bev_features * self.nz, self.ny, self.nx)
        return batch_dict


class PointPillarScatter3d(nn.Module):
    """Scatter pillar features onto an ``nz``-layer grid, folding the layer
    axis into the channel axis of the output."""

    def __init__(self, model_cfg, grid_size, **kwargs):
        """
        Args:
            model_cfg: config object exposing ``INPUT_SHAPE`` (unpacked as
                ``(nx, ny, nz)``) and ``NUM_BEV_FEATURES``.
            grid_size: accepted for signature parity; not read by this class.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.model_cfg = model_cfg
        self.nx, self.ny, self.nz = self.model_cfg.INPUT_SHAPE
        self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
        # Channels carried by each pillar before the nz layers are folded
        # into the channel axis.
        self.num_bev_features_before_compression = self.model_cfg.NUM_BEV_FEATURES // self.nz

    def forward(self, batch_dict, **kwargs):
        """Build ``batch_dict['spatial_features']`` from the pillar features.

        Args:
            batch_dict: mapping with
                ``'pillar_features'``: ``(num_pillars, NUM_BEV_FEATURES // nz)``
                tensor,
                ``'voxel_coords'``: ``(num_pillars, 4)`` tensor whose column 0
                is the batch index and whose columns 1..3 are combined as
                ``c1 * ny * nx + c2 * nx + c3`` into a flat cell index,
                ``'batch_size'`` (optional): minimum number of samples in
                the output.

        Returns:
            The same ``batch_dict`` with ``'spatial_features'`` set to a tensor
            of shape ``(batch_size, (NUM_BEV_FEATURES // nz) * nz, ny, nx)``.
        """
        pillar_features = batch_dict['pillar_features']
        coords = batch_dict['voxel_coords']

        flat_indices = (
            coords[:, 1] * self.ny * self.nx
            + coords[:, 2] * self.nx
            + coords[:, 3]
        )
        flat_indices = flat_indices.type(torch.long)

        batch_size = _resolve_batch_size(batch_dict, coords)
        dense = _scatter_to_dense(
            pillar_features,
            coords,
            flat_indices,
            self.num_bev_features_before_compression,
            self.nz * self.nx * self.ny,
            batch_size)

        batch_dict['spatial_features'] = dense.view(
            batch_size,
            self.num_bev_features_before_compression * self.nz,
            self.ny,
            self.nx)
        return batch_dict