diff --git a/pcdet/models/model_utils/basic_block_2d.py b/pcdet/models/model_utils/basic_block_2d.py new file mode 100644 index 0000000..f285eb5 --- /dev/null +++ b/pcdet/models/model_utils/basic_block_2d.py @@ -0,0 +1,34 @@ +import torch.nn as nn + + +class BasicBlock2D(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + """ + Initializes convolutional block + Args: + in_channels: int, Number of input channels + out_channels: int, Number of output channels + **kwargs: Dict, Extra arguments for nn.Conv2d + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.conv = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + **kwargs) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, features): + """ + Applies convolutional block + Args: + features: (B, C_in, H, W), Input features + Returns: + x: (B, C_out, H, W), Output features + """ + x = self.conv(features) + x = self.bn(x) + x = self.relu(x) + return x