Add File
This commit is contained in:
@@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from .basic_blocks import BasicBlock2D
|
||||||
|
from .sem_deeplabv3 import SemDeepLabV3
|
||||||
|
|
||||||
|
class PyramidFeat2D(nn.Module):
    """2D image feature extractor built on a SemDeepLabV3 network.

    Selected intermediate outputs of the image network are each passed
    through a channel-reducing ``BasicBlock2D`` and returned in a dict keyed
    by ``"<layer>_feat2d"``.
    """

    def __init__(self, optimize, model_cfg):
        """
        Initialize 2D feature network via pretrained model
        Args:
            optimize: bool, when falsy the features returned by ``forward``
                are detached from the autograd graph during training
            model_cfg: EasyDict, Dense classification network config. Reads
                ``num_class``, ``backbone``, ``args`` (forwarded as keyword
                arguments to SemDeepLabV3 and must contain
                ``feat_extract_layer``) and ``channel_reduce`` (per-layer
                lists ``in_channels``, ``out_channels``, ``kernel_size``,
                ``stride``, ``bias``)
        """
        super().__init__()
        self.model_cfg = model_cfg
        self.is_optimize = optimize

        # Create modules
        self.ifn = SemDeepLabV3(
            num_classes=model_cfg.num_class,
            backbone_name=model_cfg.backbone,
            **model_cfg.args
        )

        reduce_cfg = model_cfg.channel_reduce
        layer_names = model_cfg.args['feat_extract_layer']

        # One reduce block per extracted layer, stored in the same order as
        # ``feat_extract_layer`` so ``forward`` can pair them by index.
        self.reduce_blocks = torch.nn.ModuleList()
        # Maps extracted layer name -> channel count after reduction.
        self.out_channels = {}
        for _idx, _channel in enumerate(reduce_cfg["in_channels"]):
            _channel_out = reduce_cfg["out_channels"][_idx]
            self.out_channels[layer_names[_idx]] = _channel_out
            block_cfg = {
                "in_channels": _channel,
                "out_channels": _channel_out,
                "kernel_size": reduce_cfg["kernel_size"][_idx],
                "stride": reduce_cfg["stride"][_idx],
                "bias": reduce_cfg["bias"][_idx],
            }
            self.reduce_blocks.append(BasicBlock2D(**block_cfg))

    def get_output_feature_dim(self):
        """Return the dict mapping each extracted layer name to its reduced channel count."""
        return self.out_channels

    def forward(self, images):
        """
        Extracts channel-reduced 2D image features for every configured layer
        Args:
            images: Input images, passed unchanged to the image network
                (originally documented as (N, 3, H_in, W_in))
        Returns:
            batch_dict: dict with one entry ``"<layer>_feat2d"`` per name in
                ``model_cfg.args['feat_extract_layer']``, holding that
                layer's feature map after channel reduction
        """
        batch_dict = {}
        ifn_result = self.ifn(images)

        # Features are cut from the graph only while training a model that
        # is not supposed to optimize the image network.
        detach_features = self.training and not self.is_optimize

        for _idx, _layer in enumerate(self.model_cfg.args['feat_extract_layer']):
            image_features = ifn_result[_layer]

            # Channel reduce
            reduce_block = self.reduce_blocks[_idx]
            if reduce_block is not None:
                image_features = reduce_block(image_features)

            # Detach every extracted layer, not just the last one visited;
            # doing it here also avoids touching an undefined variable when
            # no layers are configured.
            if detach_features:
                image_features.detach_()

            batch_dict[_layer + "_feat2d"] = image_features

        # The logits output is always detached in training mode, regardless of
        # the optimize flag.
        if self.training and "logits" in ifn_result:
            ifn_result["logits"].detach_()

        return batch_dict

    def get_loss(self):
        """
        Gets loss
        Args:
        Returns:
            loss: always None in this implementation
            tb_dict: always None in this implementation
        """
        return None, None
|
||||||
Reference in New Issue
Block a user