From 423e6f0ccfcb6366835b325378f678fb787cff1d Mon Sep 17 00:00:00 2001
From: inter
Date: Sun, 21 Sep 2025 20:18:55 +0800
Subject: [PATCH] Add File

---
 .../image_vfe_modules/ffn/ddn/ddn_template.py | 162 ++++++++++++++++++
 1 file changed, 162 insertions(+)
 create mode 100644 pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn/ddn_template.py

diff --git a/pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn/ddn_template.py b/pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn/ddn_template.py
new file mode 100644
index 0000000..be110d3
--- /dev/null
+++ b/pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn/ddn_template.py
@@ -0,0 +1,162 @@
+from collections import OrderedDict
+from pathlib import Path
+from torch import hub
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from kornia.enhance.normalize import normalize
+except ImportError:
+    pass
+    # print('Warning: kornia is not installed. This package is only required by CaDDN')
+
+
+class DDNTemplate(nn.Module):
+
+    def __init__(self, constructor, feat_extract_layer, num_classes, pretrained_path=None, aux_loss=None):
+        """
+        Initializes depth distribution network.
+        Args:
+            constructor: function, Model constructor
+            feat_extract_layer: string, Layer to extract features from
+            num_classes: int, Number of classes
+            pretrained_path: string, (Optional) Path of the model to load weights from
+            aux_loss: bool, Flag to include auxiliary loss
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.pretrained_path = pretrained_path
+        self.pretrained = pretrained_path is not None
+        self.aux_loss = aux_loss
+
+        if self.pretrained:
+            # Preprocess Module
+            self.norm_mean = torch.Tensor([0.485, 0.456, 0.406])
+            self.norm_std = torch.Tensor([0.229, 0.224, 0.225])
+
+        # Model
+        self.model = self.get_model(constructor=constructor)
+        self.feat_extract_layer = feat_extract_layer
+        self.model.backbone.return_layers = {
+            feat_extract_layer: 'features',
+            **self.model.backbone.return_layers
+        }
+
+    def get_model(self, constructor):
+        """
+        Get model
+        Args:
+            constructor: function, Model constructor
+        Returns:
+            model: nn.Module, Model
+        """
+        # Get model
+        model = constructor(pretrained=False,
+                            pretrained_backbone=False,
+                            num_classes=self.num_classes,
+                            aux_loss=self.aux_loss)
+
+        # Update weights
+        if self.pretrained_path is not None:
+            model_dict = model.state_dict()
+
+            # Download pretrained model if not available yet
+            checkpoint_path = Path(self.pretrained_path)
+            if not checkpoint_path.exists():
+                checkpoint = checkpoint_path.name
+                save_dir = checkpoint_path.parent
+                save_dir.mkdir(parents=True, exist_ok=True)
+                url = f'https://download.pytorch.org/models/{checkpoint}'
+                hub.load_state_dict_from_url(url, save_dir)
+
+            # Get pretrained state dict
+            pretrained_dict = torch.load(self.pretrained_path)
+            pretrained_dict = self.filter_pretrained_dict(model_dict=model_dict,
+                                                          pretrained_dict=pretrained_dict)
+
+            # Update current model state dict
+            model_dict.update(pretrained_dict)
+            model.load_state_dict(model_dict)
+
+        return model
+
+    def filter_pretrained_dict(self, model_dict, pretrained_dict):
+        """
+        Removes layers from the pretrained state dict that are not used or changed in the model
+        Args:
+            model_dict: dict, Default model state dictionary
+            pretrained_dict: dict, Pretrained model state dictionary
+        Returns:
+            pretrained_dict: dict, Pretrained model state dictionary with removed weights
+        """
+        # Remove aux classifier weights if not used
+        if "aux_classifier.0.weight" in pretrained_dict and "aux_classifier.0.weight" not in model_dict:
+            pretrained_dict = {key: value for key, value in pretrained_dict.items()
+                               if "aux_classifier" not in key}
+
+        # Remove final conv layer from weights if the number of classes differs
+        model_num_classes = model_dict["classifier.4.weight"].shape[0]
+        pretrained_num_classes = pretrained_dict["classifier.4.weight"].shape[0]
+        if model_num_classes != pretrained_num_classes:
+            pretrained_dict.pop("classifier.4.weight")
+            pretrained_dict.pop("classifier.4.bias")
+
+        return pretrained_dict
+
+    def forward(self, images):
+        """
+        Forward pass
+        Args:
+            images: (N, 3, H_in, W_in), Input images
+        Returns:
+            result: dict[torch.Tensor], Depth distribution result
+                features: (N, C, H_out, W_out), Image features
+                logits: (N, num_classes, H_out, W_out), Classification logits
+                aux: (N, num_classes, H_out, W_out), Auxiliary classification logits
+        """
+        # Preprocess images
+        x = self.preprocess(images)
+
+        # Extract features
+        result = OrderedDict()
+        features = self.model.backbone(x)
+        result['features'] = features['features']
+        feat_shape = features['features'].shape[-2:]
+
+        # Predict classification logits
+        x = features["out"]
+        x = self.model.classifier(x)
+        x = F.interpolate(x, size=feat_shape, mode='bilinear', align_corners=False)
+        result["logits"] = x
+
+        # Predict auxiliary classification logits
+        if self.model.aux_classifier is not None:
+            x = features["aux"]
+            x = self.model.aux_classifier(x)
+            x = F.interpolate(x, size=feat_shape, mode='bilinear', align_corners=False)
+            result["aux"] = x
+
+        return result
+
+    def preprocess(self, images):
+        """
+        Preprocess images
+        Args:
+            images: (N, 3, H, W), Input images
+        Returns:
+            x: (N, 3, H, W), Preprocessed images
+        """
+        x = images
+        if self.pretrained:
+            # Create a mask for padded pixels
+            mask = (x == 0)
+
+            # Match ResNet pretrained preprocessing
+            x = normalize(x, mean=self.norm_mean, std=self.norm_std)
+
+            # Set padded pixels back to 0
+            x[mask] = 0
+
+        return x
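
For context, a concrete DDN only has to hand DDNTemplate a segmentation-model constructor whose signature matches the call in get_model() above. A minimal sketch, assuming torchvision's DeepLabV3 builders (deeplabv3_resnet50 / deeplabv3_resnet101 accept the pretrained, pretrained_backbone, num_classes and aux_loss keywords used there); the class name DDNDeepLabV3 is illustrative:

    # Illustrative subclass sketch: supplies a torchvision DeepLabV3
    # constructor to DDNTemplate; not part of this patch.
    import torchvision

    from .ddn_template import DDNTemplate


    class DDNDeepLabV3(DDNTemplate):

        def __init__(self, backbone_name, **kwargs):
            if backbone_name == "ResNet50":
                constructor = torchvision.models.segmentation.deeplabv3_resnet50
            elif backbone_name == "ResNet101":
                constructor = torchvision.models.segmentation.deeplabv3_resnet101
            else:
                raise NotImplementedError(f"Unsupported backbone: {backbone_name}")
            super().__init__(constructor=constructor, **kwargs)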
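
And a quick usage sketch of the forward() contract, with illustrative values. Note that when pretrained_path is set, get_model() derives the download URL from the file's basename, so the name must match a checkpoint hosted at https://download.pytorch.org/models/ (e.g. deeplabv3_resnet101_coco-586e9e4e.pth):

    import torch

    # feat_extract_layer must be a layer name the backbone's
    # IntermediateLayerGetter knows about ("layer1" for ResNet);
    # num_classes here is an arbitrary illustrative value.
    ddn = DDNDeepLabV3(backbone_name="ResNet101",
                       feat_extract_layer="layer1",
                       num_classes=81,
                       pretrained_path=None,
                       aux_loss=False)

    images = torch.rand(2, 3, 376, 1248)  # e.g. a padded, KITTI-sized batch
    result = ddn(images)
    print(result["features"].shape)       # stride-4 "layer1" features: (2, 256, 94, 312)
    print(result["logits"].shape)         # logits resized to match:    (2, 81, 94, 312)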