Add File
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
from .balancer import Balancer
|
||||
from pcdet.utils import transform_utils
|
||||
|
||||
try:
|
||||
from kornia.losses.focal import FocalLoss
|
||||
except:
|
||||
pass
|
||||
# print('Warning: kornia is not installed. This package is only required by CaDDN')
|
||||
|
||||
|
||||
class DDNLoss(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
weight,
|
||||
alpha,
|
||||
gamma,
|
||||
disc_cfg,
|
||||
fg_weight,
|
||||
bg_weight,
|
||||
downsample_factor):
|
||||
"""
|
||||
Initializes DDNLoss module
|
||||
Args:
|
||||
weight: float, Loss function weight
|
||||
alpha: float, Alpha value for Focal Loss
|
||||
gamma: float, Gamma value for Focal Loss
|
||||
disc_cfg: dict, Depth discretiziation configuration
|
||||
fg_weight: float, Foreground loss weight
|
||||
bg_weight: float, Background loss weight
|
||||
downsample_factor: int, Depth map downsample factor
|
||||
"""
|
||||
super().__init__()
|
||||
self.device = torch.cuda.current_device()
|
||||
self.disc_cfg = disc_cfg
|
||||
self.balancer = Balancer(downsample_factor=downsample_factor,
|
||||
fg_weight=fg_weight,
|
||||
bg_weight=bg_weight)
|
||||
|
||||
# Set loss function
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
self.loss_func = FocalLoss(alpha=self.alpha, gamma=self.gamma, reduction="none")
|
||||
self.weight = weight
|
||||
|
||||
def forward(self, depth_logits, depth_maps, gt_boxes2d):
|
||||
"""
|
||||
Gets DDN loss
|
||||
Args:
|
||||
depth_logits: (B, D+1, H, W), Predicted depth logits
|
||||
depth_maps: (B, H, W), Depth map [m]
|
||||
gt_boxes2d: torch.Tensor (B, N, 4), 2D box labels for foreground/background balancing
|
||||
Returns:
|
||||
loss: (1), Depth distribution network loss
|
||||
tb_dict: dict[float], All losses to log in tensorboard
|
||||
"""
|
||||
tb_dict = {}
|
||||
|
||||
# Bin depth map to create target
|
||||
depth_target = transform_utils.bin_depths(depth_maps, **self.disc_cfg, target=True)
|
||||
|
||||
# Compute loss
|
||||
loss = self.loss_func(depth_logits, depth_target)
|
||||
|
||||
# Compute foreground/background balancing
|
||||
loss, tb_dict = self.balancer(loss=loss, gt_boxes2d=gt_boxes2d)
|
||||
|
||||
# Final loss
|
||||
loss *= self.weight
|
||||
tb_dict.update({"ddn_loss": loss.item()})
|
||||
|
||||
return loss, tb_dict
|
||||
Reference in New Issue
Block a user