Source code for starling.models.vae_components

from typing import List

from torch import nn

from starling.models.blocks import (
    LayerNorm,
    ResBlockDecBasic,
    ResBlockEncBasic,
    ResizeConv2d,
)


class ResNet_Encoder(nn.Module):
    """ResNet-style encoder: a strided stem convolution followed by four stages.

    Each stage is an ``nn.ModuleList`` of ``block_type`` instances. Stage one
    is built with ``stride=1`` and stages two to four with ``stride=2``.
    """

    def __init__(
        self,
        in_channels,
        num_blocks,
        norm,
        base=64,
        block_type=ResBlockEncBasic,
    ) -> None:
        """Build the stem and the four residual stages.

        Parameters
        ----------
        in_channels
            Number of channels of the input fed to the stem convolution.
        num_blocks
            Sequence whose entries 0-3 give the number of blocks in stages
            one to four. Its length determines how many stage widths are
            computed, so it needs at least four entries.
        norm
            One of ``"batch"``, ``"instance"``, ``"layer"`` or ``"group"``.
            Selects the stem normalization and is forwarded to every block.
            Any other value raises ``KeyError``.
        base
            Output channels of the stem; stage ``i`` (0-based) is built with
            ``base * 2**i`` as its ``out_channels`` argument. With
            ``norm="group"`` the stem uses 32 groups, so ``base`` must be
            divisible by 32.
        block_type
            Block class. It is called as
            ``block(in_channels, out_channels, stride, norm=...)`` and must
            expose an ``expansion`` class attribute.
        """
        super().__init__()

        self.block_type = block_type
        self.norm = norm

        norm_layers = {
            "batch": nn.BatchNorm2d,
            "instance": nn.InstanceNorm2d,
            "layer": LayerNorm,
            "group": nn.GroupNorm,
        }

        # Stem: 7x7 convolution with stride 2 and padding 3, followed by the
        # requested normalization layer. The convolution is created before
        # the norm lookup so construction order matches the original.
        stem_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=base,
            kernel_size=7,
            stride=2,
            padding=3,
        )
        norm_cls = norm_layers[norm]
        # GroupNorm takes (num_groups, num_channels); the others take only
        # the channel count.
        stem_norm = norm_cls(32, base) if norm == "group" else norm_cls(base)
        self.first_conv = nn.Sequential(stem_conv, stem_norm)

        # Running channel count consumed (and updated) by _make_layer.
        self.in_channels = base

        stage_widths = [base * (2**idx) for idx, _ in enumerate(num_blocks)]

        build = self._make_layer
        self.layer1 = build(self.block_type, stage_widths[0], num_blocks[0], stride=1)
        self.layer2 = build(self.block_type, stage_widths[1], num_blocks[1], stride=2)
        self.layer3 = build(self.block_type, stage_widths[2], num_blocks[2], stride=2)
        self.layer4 = build(self.block_type, stage_widths[3], num_blocks[3], stride=2)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Create one stage as an ``nn.ModuleList`` of ``blocks`` blocks.

        Only the first block receives ``stride``; the remaining
        ``blocks - 1`` blocks are built with ``stride=1``. As a side effect
        ``self.in_channels`` is set to ``out_channels * block.expansion`` so
        that the next stage starts from this stage's output width.
        """
        head = block(
            self.in_channels,
            out_channels,
            stride,
            norm=self.norm,
        )
        self.in_channels = out_channels * block.expansion

        tail = [
            block(
                self.in_channels,
                out_channels,
                stride=1,
                norm=self.norm,
            )
            for _ in range(1, blocks)
        ]
        return nn.ModuleList([head, *tail])

    def forward(self, data):
        """Apply the stem, then every block of stages one to four in order."""
        data = self.first_conv(data)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for block in stage:
                data = block(data)
        return data
class ResNet_Decoder(nn.Module):
    """ResNet-style decoder: four stages of blocks followed by a resize-conv.

    The stage widths are computed with the same formula as the encoder and
    consumed in reverse order (widest first), assuming a symmetric
    encoder/decoder setup.
    """

    def __init__(
        self,
        out_channels: int,
        num_blocks: List,
        dimension: int,
        norm: str,
        block_type=ResBlockDecBasic,
        base=64,
    ) -> None:
        """Build the four decoder stages and the output layer.

        Parameters
        ----------
        out_channels
            Number of channels produced by the final ``ResizeConv2d``.
        num_blocks
            Sequence whose entries 0-3 give the number of blocks in stages
            one to four. Its length determines how many stage widths are
            computed, so it needs at least four entries.
        dimension
            Accepted for interface compatibility; not used by this class.
        norm
            Normalization identifier forwarded to every block.
        block_type
            Block class. Widths grow by a factor of 2 per stage when this is
            ``ResBlockDecBasic`` and by a factor of 4 otherwise. It must
            expose a ``contraction`` class attribute.
        base
            Narrowest stage width; also the input width of the output layer.
        """
        super().__init__()
        self.norm = norm
        self.block_type = block_type

        # Per-stage width multiplier depends on the block family.
        growth = 2 if self.block_type == ResBlockDecBasic else 4
        stage_widths = [base * (growth**idx) for idx in range(len(num_blocks))]
        self.in_channels = stage_widths[-1]

        # Stages are built from the widest width down to the narrowest.
        build = self._make_layer
        self.layer1 = build(self.block_type, stage_widths[-1], num_blocks[0], stride=2)
        self.layer2 = build(self.block_type, stage_widths[-2], num_blocks[1], stride=2)
        self.layer3 = build(self.block_type, stage_widths[-3], num_blocks[2], stride=2)
        self.layer4 = build(
            self.block_type,
            stage_widths[-4],
            num_blocks[3],
            stride=1,
            last_layer=True,
        )

        self.output_layer = ResizeConv2d(
            in_channels=stage_widths[-4],
            out_channels=out_channels,
            kernel_size=7,
            padding=3,
            norm=None,
            activation="relu",
            scale_factor=2,
        )

    def _make_layer(self, block, out_channels, blocks, stride=1, last_layer=False):
        """Create one decoder stage as an ``nn.ModuleList`` of ``blocks`` blocks.

        The first ``blocks - 1`` blocks are built with ``stride=1``. The final
        block receives ``stride`` and ``last_layer``; when ``stride > 1`` and
        the block class is ``ResBlockDecBasic`` its output width is halved.
        As a side effect ``self.in_channels`` is set to
        ``out_channels * block.contraction``.
        """
        self.in_channels = out_channels * block.contraction

        stage = [
            block(
                self.in_channels,
                out_channels,
                stride=1,
                norm=self.norm,
            )
            for _ in range(1, blocks)
        ]

        # The closing block of a strided basic stage emits half the channels.
        closing_width = out_channels
        if stride > 1 and block == ResBlockDecBasic:
            closing_width = int(out_channels / 2)

        stage.append(
            block(
                self.in_channels,
                closing_width,
                stride,
                last_layer=last_layer,
                norm=self.norm,
            )
        )
        return nn.ModuleList(stage)

    def forward(self, data):
        """Apply every block of stages one to four in order, then the output layer."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for block in stage:
                data = block(data)
        return self.output_layer(data)
class ConditionalSequential(nn.Sequential):
    """``nn.Sequential`` variant that can pass a condition to every child."""

    def forward(self, x, condition=None):
        """Run ``x`` through the child modules in registration order.

        When ``condition`` is ``None`` each child is called as ``module(x)``;
        otherwise each child is called as ``module(x, condition)``.
        """
        # Zero or one extra positional argument, decided once up front.
        extra_args = () if condition is None else (condition,)
        for layer in self._modules.values():
            x = layer(x, *extra_args)
        return x
# Current implementations of ResNets
def Resnet18_Encoder(in_channels, norm, base):
    """Return a ``ResNet_Encoder`` with two blocks in each of its four stages.

    ``in_channels``, ``norm`` and ``base`` are forwarded unchanged. Note the
    parameter order here is ``(in_channels, norm, base)``.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet_Encoder(
        in_channels,
        num_blocks=stage_depths,
        norm=norm,
        base=base,
    )
def Resnet18_Decoder(out_channels, dimension, base, norm):
    """Return a ``ResNet_Decoder`` with two blocks in each of its four stages.

    ``out_channels``, ``dimension``, ``base`` and ``norm`` are forwarded
    unchanged.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet_Decoder(
        out_channels,
        num_blocks=stage_depths,
        dimension=dimension,
        norm=norm,
        base=base,
    )
def Resnet34_Encoder(in_channels, base, norm):
    """Return a ``ResNet_Encoder`` with 3, 4, 6 and 3 blocks in stages one to four.

    ``in_channels``, ``base`` and ``norm`` are forwarded unchanged. Note the
    parameter order here is ``(in_channels, base, norm)``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet_Encoder(
        in_channels,
        num_blocks=stage_depths,
        norm=norm,
        base=base,
    )
def Resnet34_Decoder(out_channels, dimension, base, norm):
    """Return a ``ResNet_Decoder`` with 3, 4, 6 and 3 blocks in stages one to four.

    ``out_channels``, ``dimension``, ``base`` and ``norm`` are forwarded
    unchanged.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet_Decoder(
        out_channels,
        num_blocks=stage_depths,
        dimension=dimension,
        norm=norm,
        base=base,
    )