This commit is contained in:
2025-09-21 20:18:41 +08:00
parent 6a6602a14c
commit df3e845e55

View File

@@ -0,0 +1,26 @@
import torch.nn as nn
class HeightCompression(nn.Module):
def __init__(self, model_cfg, **kwargs):
super().__init__()
self.model_cfg = model_cfg
self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
def forward(self, batch_dict):
"""
Args:
batch_dict:
encoded_spconv_tensor: sparse tensor
Returns:
batch_dict:
spatial_features:
"""
encoded_spconv_tensor = batch_dict['encoded_spconv_tensor']
spatial_features = encoded_spconv_tensor.dense()
N, C, D, H, W = spatial_features.shape
spatial_features = spatial_features.view(N, C * D, H, W)
batch_dict['spatial_features'] = spatial_features
batch_dict['spatial_features_stride'] = batch_dict['encoded_spconv_tensor_stride']
return batch_dict