diff --git a/pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn/ddn_deeplabv3.py b/pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn/ddn_deeplabv3.py new file mode 100644 index 0000000..76be8ca --- /dev/null +++ b/pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn/ddn_deeplabv3.py @@ -0,0 +1,24 @@ +from .ddn_template import DDNTemplate + +try: + import torchvision +except: + pass + + +class DDNDeepLabV3(DDNTemplate): + + def __init__(self, backbone_name, **kwargs): + """ + Initializes DDNDeepLabV3 model + Args: + backbone_name: string, ResNet Backbone Name [ResNet50/ResNet101] + """ + if backbone_name == "ResNet50": + constructor = torchvision.models.segmentation.deeplabv3_resnet50 + elif backbone_name == "ResNet101": + constructor = torchvision.models.segmentation.deeplabv3_resnet101 + else: + raise NotImplementedError + + super().__init__(constructor=constructor, **kwargs)