# Source code for starling.models.resnets_original

import torch.nn.functional as F
from torch import nn

from starling.models.blocks import (
    ResBlockDecBasic,
    ResBlockDecBottleneck,
    ResBlockEncBasic,
    ResBlockEncBottleneck,
)


class ResNet_Encoder_Original(nn.Module):
    """ResNet-style encoder that reduces an image to a 1x1 feature map.

    Layout: a stem (7x7 stride-2 convolution -> batch norm -> 3x3 stride-2
    max pool), four residual stages of widths 64, 128, 256 and 512, and a
    global adaptive average pool down to 1x1.

    Args:
        in_channels: Number of channels of the input tensor.
        num_blocks: Number of residual blocks in each of the four stages
            (entries 0-3 are read).
        kernel_size: Accepted for a uniform factory signature; unused here.
        dimension: Accepted for a uniform factory signature; unused here.
        block_type: Residual block class, called as
            ``block_type(in_channels, out_channels, stride)``. It must define
            an ``expansion`` attribute; ``out_channels * expansion`` is used
            as the input channel count of the following block.
        base: Accepted for a uniform factory signature; the stage widths are
            hard-coded and do not depend on it.
    """

    def __init__(
        self,
        in_channels,
        num_blocks,
        kernel_size=None,
        dimension=None,
        block_type=ResBlockEncBasic,
        base=64,
    ) -> None:
        super().__init__()
        self.block_type = block_type

        # Running channel count consumed by the next residual block; it is
        # advanced by _make_layer as each stage is built.
        self.in_channels = 64

        # Stem: the stride-2 7x7 convolution halves H and W, and the stride-2
        # MaxPool2d halves them again (a /4 reduction overall).
        # NOTE(review): torchvision's ResNet stem applies a ReLU between the
        # batch norm and the pool; this one does not. Left unchanged so that
        # existing trained weights keep producing the same outputs.
        stem_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.in_channels,
            kernel_size=7,
            stride=2,
            padding=3,
        )
        stem_norm = nn.BatchNorm2d(self.in_channels)
        stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.first_conv = nn.Sequential(stem_conv, stem_norm, stem_pool)

        # Residual stages; every stage after the first is built with stride 2.
        kind = self.block_type
        self.layer1 = self._make_layer(kind, out_channels=64, blocks=num_blocks[0], stride=1)
        self.layer2 = self._make_layer(kind, out_channels=128, blocks=num_blocks[1], stride=2)
        self.layer3 = self._make_layer(kind, out_channels=256, blocks=num_blocks[2], stride=2)
        self.layer4 = self._make_layer(kind, out_channels=512, blocks=num_blocks[3], stride=2)

        # Global pooling to a 1x1 spatial map.
        self.average_pool = nn.AdaptiveAvgPool2d((1, 1))

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one residual stage containing ``blocks`` blocks.

        Only the first block receives ``stride`` and the incoming channel
        count. Afterwards ``self.in_channels`` becomes
        ``out_channels * block.expansion``, and every remaining block is built
        with that input width and stride 1.

        Returns:
            nn.Sequential: The stage's blocks, in order.
        """
        head = block(self.in_channels, out_channels, stride)
        self.in_channels = out_channels * block.expansion
        tail = [
            block(self.in_channels, out_channels, stride=1)
            for _ in range(blocks - 1)
        ]
        return nn.Sequential(head, *tail)

    def forward(self, data):
        """Encode ``data`` into a spatially pooled feature map.

        Args:
            data: Input tensor, assumed to be (N, in_channels, H, W).

        Returns:
            The output of the four residual stages after adaptive average
            pooling to a 1x1 spatial size.
        """
        data = self.first_conv(data)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            data = stage(data)
        return self.average_pool(data)
class ResNet_Decoder_Original(nn.Module):
    """ResNet-style decoder that upsamples a latent map back to ``dimension``.

    Steps:

    1. The input is resized by nearest-neighbour interpolation to
       ``int(dimension / 2 ** (len(num_blocks) + 1))`` on each side.
    2. It passes through four residual stages of widths 512, 256, 128 and 64.
       Stages 1-3 are built with stride 2; their spatial effect depends on the
       block implementation.
    3. It passes through two stride-2 transposed convolutions, each of which
       doubles H and W.
    4. A final ReLU is applied, so the output is non-negative.

    Args:
        out_channels: Number of channels produced by the output layer.
        num_blocks: Residual block count for each of the four stages, listed
            from the widest (512) stage to the narrowest (64) stage.
        kernel_size: Kernel size of the intermediate transposed convolution
            (``reshaping_conv``); an int or an (h, w) pair. Its padding and
            output padding are derived from it so that this layer exactly
            doubles H and W for any kernel size.
        dimension: Target spatial size of the output (height == width).
        block_type: Residual decoder block class. It is called as
            ``block_type(in_channels, out_channels, stride=1)`` and
            ``block_type(in_channels, out_channels, stride, last_layer=...)``,
            and must define a ``contraction`` attribute.
        base: Accepted for a uniform factory signature; unused.
    """

    def __init__(
        self,
        out_channels,
        num_blocks,
        kernel_size,
        dimension,
        block_type=ResBlockDecBasic,
        base=64,
    ) -> None:
        super().__init__()
        self.block_type = block_type

        # Nominal encoder output width for a symmetric encoder/decoder.
        # NOTE: _make_layer overwrites self.in_channels before any block reads
        # it, so this value only survives as an attribute for compatibility.
        if self.block_type == ResBlockDecBasic:
            self.in_channels = 512
        else:
            self.in_channels = 4096

        # Side length the latent is resized to. For four stages this is
        # dimension / 32, matching the encoder's stem (/4) plus three
        # stride-2 stages (/8). (A learned resize layer could replace this.)
        self.interpolate = int(dimension / (2 ** (len(num_blocks) + 1)))

        # Unused container; kept so the module's attribute set is unchanged.
        self.layers = nn.ModuleList()
        self.layer1 = self._make_layer(self.block_type, 512, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(self.block_type, 256, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(self.block_type, 128, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(
            self.block_type, 64, num_blocks[3], stride=1, last_layer=True
        )

        # Upsampling by 2 (the learned analogue of the encoder's MaxPool).
        # Previously padding=1/output_padding=1 were hard-coded, which only
        # doubles H and W when kernel_size == 3; they are now derived from
        # kernel_size. The values are identical for kernel_size == 3.
        reshape_padding, reshape_output_padding = self._doubling_padding(kernel_size)
        self.reshaping_conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=64,
                out_channels=64,
                kernel_size=kernel_size,
                stride=2,
                padding=reshape_padding,
                output_padding=reshape_output_padding,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # Output layer mirroring the encoder's 7x7 stride-2 stem convolution:
        # (H - 1) * 2 - 2 * 3 + 7 + 1 = 2 * H.
        self.output_layer = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=64,
                out_channels=out_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                output_padding=1,
            ),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _doubling_padding(kernel_size):
        """Return ``(padding, output_padding)`` pairs for an exact 2x upsample.

        For a stride-2 ConvTranspose2d the output size is
        ``2*H - 2*padding + kernel - 2 + output_padding``. Choosing
        ``padding = (kernel - 1) // 2`` and
        ``output_padding = 2 + 2*padding - kernel`` (1 for odd kernels,
        0 for even kernels) gives exactly ``2*H``.
        """
        if isinstance(kernel_size, (tuple, list)):
            sizes = tuple(kernel_size)
        else:
            sizes = (kernel_size, kernel_size)
        padding = tuple((k - 1) // 2 for k in sizes)
        output_padding = tuple(2 + 2 * p - k for p, k in zip(padding, sizes))
        return padding, output_padding

    def _make_layer(self, block, out_channels, blocks, stride=1, last_layer=False):
        """Build one decoder stage of ``blocks`` residual blocks.

        ``self.in_channels`` is set to ``out_channels * block.contraction``.
        Then ``blocks - 1`` stride-1 blocks are appended, followed by a final
        block that carries ``stride`` and ``last_layer``. For
        ``ResBlockDecBasic`` with ``stride > 1``, that final block is given
        ``out_channels // 2``.

        Returns:
            nn.Sequential: The stage's blocks, in order.
        """
        layers = []
        self.in_channels = out_channels * block.contraction
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels, stride=1))
        if stride > 1 and block == ResBlockDecBasic:
            out_channels = int(out_channels / 2)
        layers.append(
            block(self.in_channels, out_channels, stride, last_layer=last_layer)
        )
        return nn.Sequential(*layers)

    def forward(self, data):
        """Decode ``data`` into an output of spatial size ``dimension``.

        Args:
            data: Latent tensor, assumed to be (N, C, H, W), e.g. the
                encoder's 1x1 pooled output.

        Returns:
            The decoded tensor after the final ReLU.
        """
        data = F.interpolate(data, size=(self.interpolate, self.interpolate))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            data = stage(data)
        data = self.reshaping_conv(data)
        return self.output_layer(data)
# Factory functions for the ResNet-18/34/50/101/152 encoder and decoder variants.
def Resnet18_Encoder(in_channels, kernel_size, dimension, base):
    """Build a ResNet-18 encoder: basic blocks, stage depths 2-2-2-2.

    All arguments are forwarded unchanged to ``ResNet_Encoder_Original``.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet_Encoder_Original(
        in_channels=in_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockEncBasic,
        base=base,
    )
def Resnet18_Decoder(out_channels, kernel_size, dimension, base):
    """Build a ResNet-18 decoder: basic blocks, stage depths 2-2-2-2.

    All arguments are forwarded unchanged to ``ResNet_Decoder_Original``.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet_Decoder_Original(
        out_channels=out_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockDecBasic,
        base=base,
    )
def Resnet34_Encoder(in_channels, kernel_size, dimension, base):
    """Build a ResNet-34 encoder: basic blocks, stage depths 3-4-6-3.

    All arguments are forwarded unchanged to ``ResNet_Encoder_Original``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet_Encoder_Original(
        in_channels=in_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockEncBasic,
        base=base,
    )
def Resnet34_Decoder(out_channels, kernel_size, dimension, base):
    """Build a ResNet-34 decoder: basic blocks, stage depths 3-6-4-3.

    The depths are the ResNet-34 encoder's 3-4-6-3 in reverse, because the
    decoder's stages run from the widest to the narrowest. All arguments are
    forwarded unchanged to ``ResNet_Decoder_Original``.
    """
    stage_depths = [3, 6, 4, 3]
    return ResNet_Decoder_Original(
        out_channels=out_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockDecBasic,
        base=base,
    )
def Resnet50_Encoder(in_channels, kernel_size, dimension, base):
    """Build a ResNet-50 encoder: bottleneck blocks, stage depths 3-4-6-3.

    All arguments are forwarded unchanged to ``ResNet_Encoder_Original``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet_Encoder_Original(
        in_channels=in_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockEncBottleneck,
        base=base,
    )
def Resnet50_Decoder(out_channels, kernel_size, dimension, base):
    """Build a ResNet-50 decoder: bottleneck blocks, stage depths 3-6-4-3.

    The depths are the ResNet-50 encoder's 3-4-6-3 in reverse. All arguments
    are forwarded unchanged to ``ResNet_Decoder_Original``.
    """
    stage_depths = [3, 6, 4, 3]
    return ResNet_Decoder_Original(
        out_channels=out_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockDecBottleneck,
        base=base,
    )
def Resnet101_Encoder(in_channels, kernel_size, dimension, base):
    """Build a ResNet-101 encoder: bottleneck blocks, stage depths 3-4-23-3.

    All arguments are forwarded unchanged to ``ResNet_Encoder_Original``.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet_Encoder_Original(
        in_channels=in_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockEncBottleneck,
        base=base,
    )
def Resnet101_Decoder(out_channels, kernel_size, dimension, base):
    """Build a ResNet-101 decoder: bottleneck blocks, stage depths 3-23-4-3.

    The depths are the ResNet-101 encoder's 3-4-23-3 in reverse. All
    arguments are forwarded unchanged to ``ResNet_Decoder_Original``.
    """
    stage_depths = [3, 23, 4, 3]
    return ResNet_Decoder_Original(
        out_channels=out_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockDecBottleneck,
        base=base,
    )
def Resnet152_Encoder(in_channels, kernel_size, dimension, base):
    """Build a ResNet-152 encoder: bottleneck blocks, stage depths 3-8-36-3.

    All arguments are forwarded unchanged to ``ResNet_Encoder_Original``.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNet_Encoder_Original(
        in_channels=in_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockEncBottleneck,
        base=base,
    )
def Resnet152_Decoder(out_channels, kernel_size, dimension, base):
    """Build a ResNet-152 decoder: bottleneck blocks, stage depths 3-36-8-3.

    The depths are the ResNet-152 encoder's 3-8-36-3 in reverse. All
    arguments are forwarded unchanged to ``ResNet_Decoder_Original``.
    """
    stage_depths = [3, 36, 8, 3]
    return ResNet_Decoder_Original(
        out_channels=out_channels,
        num_blocks=stage_depths,
        kernel_size=kernel_size,
        dimension=dimension,
        block_type=ResBlockDecBottleneck,
        base=base,
    )