This commit is contained in:
2025-09-21 20:18:53 +08:00
parent 67d096d580
commit 05830e05a6

View File

@@ -0,0 +1,77 @@
import torch
import torch.nn as nn
from .basic_blocks import BasicBlock2D
from .sem_deeplabv3 import SemDeepLabV3
class PyramidFeat2D(nn.Module):
def __init__(self, optimize, model_cfg):
"""
Initialize 2D feature network via pretrained model
Args:
model_cfg: EasyDict, Dense classification network config
"""
super().__init__()
self.model_cfg = model_cfg
self.is_optimize = optimize
# Create modules
self.ifn = SemDeepLabV3(
num_classes=model_cfg.num_class,
backbone_name=model_cfg.backbone,
**model_cfg.args
)
self.reduce_blocks = torch.nn.ModuleList()
self.out_channels = {}
for _idx, _channel in enumerate(model_cfg.channel_reduce["in_channels"]):
_channel_out = model_cfg.channel_reduce["out_channels"][_idx]
self.out_channels[model_cfg.args['feat_extract_layer'][_idx]] = _channel_out
block_cfg = {"in_channels": _channel,
"out_channels": _channel_out,
"kernel_size": model_cfg.channel_reduce["kernel_size"][_idx],
"stride": model_cfg.channel_reduce["stride"][_idx],
"bias": model_cfg.channel_reduce["bias"][_idx]}
self.reduce_blocks.append(BasicBlock2D(**block_cfg))
def get_output_feature_dim(self):
return self.out_channels
def forward(self, images):
"""
Predicts depths and creates image depth feature volume using depth distributions
Args:
images: (N, 3, H_in, W_in), Input images
Returns:
batch_dict:
frustum_features: (N, C, D, H_out, W_out), Image depth features
"""
# Pixel-wise depth classification
batch_dict = {}
ifn_result = self.ifn(images)
for _idx, _layer in enumerate(self.model_cfg.args['feat_extract_layer']):
image_features = ifn_result[_layer]
# Channel reduce
if self.reduce_blocks[_idx] is not None:
image_features = self.reduce_blocks[_idx](image_features)
batch_dict[_layer+"_feat2d"] = image_features
if self.training:
# detach feature from graph if not optimize
if "logits" in ifn_result:
ifn_result["logits"].detach_()
if not self.is_optimize:
image_features.detach_()
return batch_dict
def get_loss(self):
"""
Gets loss
Args:
Returns:
loss: (1), Network loss
tb_dict: dict[float], All losses to log in tensorboard
"""
return None, None