From 12612b236a978246ca326449d476300c6e7b6342 Mon Sep 17 00:00:00 2001 From: inter Date: Sun, 21 Sep 2025 20:18:56 +0800 Subject: [PATCH] Add File --- .../ffn/ddn_loss/balancer.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn_loss/balancer.py diff --git a/pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn_loss/balancer.py b/pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn_loss/balancer.py new file mode 100644 index 0000000..47bf8d4 --- /dev/null +++ b/pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/ddn_loss/balancer.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn + +from pcdet.utils import loss_utils + + +class Balancer(nn.Module): + def __init__(self, fg_weight, bg_weight, downsample_factor=1): + """ + Initialize fixed foreground/background loss balancer + Args: + fg_weight: float, Foreground loss weight + bg_weight: float, Background loss weight + downsample_factor: int, Depth map downsample factor + """ + super().__init__() + self.fg_weight = fg_weight + self.bg_weight = bg_weight + self.downsample_factor = downsample_factor + + def forward(self, loss, gt_boxes2d): + """ + Forward pass + Args: + loss: (B, H, W), Pixel-wise loss + gt_boxes2d: (B, N, 4), 2D box labels for foreground/background balancing + Returns: + loss: (1), Total loss after foreground/background balancing + tb_dict: dict[float], All losses to log in tensorboard + """ + # Compute masks + fg_mask = loss_utils.compute_fg_mask(gt_boxes2d=gt_boxes2d, + shape=loss.shape, + downsample_factor=self.downsample_factor, + device=loss.device) + bg_mask = ~fg_mask + + # Compute balancing weights + weights = self.fg_weight * fg_mask + self.bg_weight * bg_mask + num_pixels = fg_mask.sum() + bg_mask.sum() + + # Compute losses + loss *= weights + fg_loss = loss[fg_mask].sum() / num_pixels + bg_loss = loss[bg_mask].sum() / num_pixels + + # Get total loss + loss = fg_loss + bg_loss + tb_dict = {"balancer_loss": loss.item(), "fg_loss": fg_loss.item(), "bg_loss": bg_loss.item()} + return loss, tb_dict