Add File
This commit is contained in:
@@ -0,0 +1,65 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class BasicBlock1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, **kwargs):
|
||||||
|
"""
|
||||||
|
Initializes convolutional block
|
||||||
|
Args:
|
||||||
|
in_channels: int, Number of input channels
|
||||||
|
out_channels: int, Number of output channels
|
||||||
|
**kwargs: Dict, Extra arguments for nn.Conv2d
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.conv = nn.Conv1d(in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
**kwargs)
|
||||||
|
self.bn = nn.BatchNorm1d(out_channels)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, features):
|
||||||
|
"""
|
||||||
|
Applies convolutional block
|
||||||
|
Args:
|
||||||
|
features: (B, C_in, H, W), Input features
|
||||||
|
Returns:
|
||||||
|
x: (B, C_out, H, W), Output features
|
||||||
|
"""
|
||||||
|
x = self.conv(features)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class BasicBlock2D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, **kwargs):
|
||||||
|
"""
|
||||||
|
Initializes convolutional block
|
||||||
|
Args:
|
||||||
|
in_channels: int, Number of input channels
|
||||||
|
out_channels: int, Number of output channels
|
||||||
|
**kwargs: Dict, Extra arguments for nn.Conv2d
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.conv = nn.Conv2d(in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
**kwargs)
|
||||||
|
self.bn = nn.BatchNorm2d(out_channels)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, features):
|
||||||
|
"""
|
||||||
|
Applies convolutional block
|
||||||
|
Args:
|
||||||
|
features: (B, C_in, H, W), Input features
|
||||||
|
Returns:
|
||||||
|
x: (B, C_out, H, W), Output features
|
||||||
|
"""
|
||||||
|
x = self.conv(features)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
Reference in New Issue
Block a user