Source code for starling.models.blocks

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from starling.models.normalization import RMSNorm


class LayerNorm(nn.Module):
    r"""Layer normalization over the channel axis for two data layouts.

    ``data_format`` describes where the channel axis lives in the input:

    * ``"channels_first"`` (the default): channels are at index 1, e.g.
      ``(batch_size, channels, height, width)``.
    * ``"channels_last"``: channels are the final axis, e.g.
      ``(batch_size, height, width, channels)``.

    Modified from:
    https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py#L119

    Note: it increases memory requirements substantially, unclear if that can
    be changed.

    Parameters
    ----------
    normalized_shape : int
        Number of channels; the learnable ``weight`` and ``bias`` have this
        length.
    eps : float, optional
        Value passed to ``F.layer_norm`` as ``eps``. Default is 1e-6.
    data_format : str, optional
        Either ``"channels_first"`` or ``"channels_last"``. Default is
        ``"channels_first"``.

    Raises
    ------
    NotImplementedError
        If ``data_format`` is not one of the two supported values.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(
                f"Unsupported data_format {data_format!r}; expected "
                "'channels_last' or 'channels_first'."
            )
        # F.layer_norm expects the normalized shape as a sequence.
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize ``x`` over its channel axis and return a tensor of the same shape."""
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

        # channels_first: move the channel axis to the end, normalize, then
        # move it back. For 4-D input this equals permute(0, 2, 3, 1) followed
        # by permute(0, 3, 1, 2), but it also works for other ranks (for
        # example (batch, channels, length)).
        x = x.movedim(1, -1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.movedim(-1, 1)
class MinPool2d(nn.Module):
    """2-D min pooling, implemented as max pooling of the negated input.

    Parameters
    ----------
    kernel_size
        Pooling window size, forwarded to ``max_pool2d``.
    stride, optional
        Window stride; falls back to ``kernel_size`` when falsy (``None``/0).
    padding, optional
        Padding forwarded to ``max_pool2d``. Default is 0.
    dilation, optional
        Dilation forwarded to ``max_pool2d``. Default is 1.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride else kernel_size
        self.padding = padding
        self.dilation = dilation

    def forward(self, x):
        """Return the window-wise minimum of ``x``."""
        # min(x) == -max(-x): negate, max-pool, negate back.
        negated = x.mul(-1)
        pooled_max = F.max_pool2d(
            negated,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
        return pooled_max.mul(-1)
class ResizeConv2d(nn.Module):
    """Upsample with ``F.interpolate`` and then apply a convolution.

    This module uses F.interpolate for upsampling followed by a convolutional
    layer, instead of ConvTranspose2d. This approach helps to avoid
    checkerboard artifacts that are common with ConvTranspose2d
    (https://distill.pub/2016/deconv-checkerboard/). It is particularly useful
    in the decoder part of the network.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        norm: torch.nn.Module,
        activation: "str | nn.Module | type | None",
        padding: int,
        size: int = None,
        scale_factor: int = None,
        mode: str = "nearest",
    ):
        """
        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        kernel_size : int
            Size of the convolutional kernel.
        norm : torch.nn.Module
            Normalization layer class to use (e.g., nn.InstanceNorm2d), or
            None for no normalization. ``nn.GroupNorm`` is built with 32
            groups; any other class is called as ``norm(out_channels)``.
        activation : nn.Module instance, nn.Module subclass, str or None
            Activation applied after normalization. ``None`` disables it.
            A module instance is used as-is and a module class is
            instantiated with no arguments. For backward compatibility,
            ``nn.ReLU`` and any other value (e.g. a string such as
            ``"relu"``) yield ``nn.ReLU(inplace=True)``.
        padding : int
            Padding for the convolutional layer.
        size : int, optional
            Spatial size of the output tensor. If None, scale_factor is used.
            Default is None.
        scale_factor : int, optional
            Scale factor for upsampling. Default is None.
        mode : str, optional
            Mode for upsampling. Default is "nearest".
        """
        super().__init__()

        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride=1, padding=padding
            ),
            self._build_norm(norm, out_channels),
            self._build_activation(activation),
        )

    @staticmethod
    def _build_norm(norm, out_channels):
        """Instantiate the normalization layer, or Identity when ``norm`` is None."""
        if norm is None:
            return nn.Identity()
        if norm is nn.GroupNorm:
            # GroupNorm takes (num_groups, num_channels); 32 groups are used.
            return norm(32, out_channels)
        return norm(out_channels)

    @staticmethod
    def _build_activation(activation):
        """Turn the ``activation`` argument into a module."""
        if activation is None:
            return nn.Identity()
        if isinstance(activation, nn.Module):
            return activation
        if (
            isinstance(activation, type)
            and issubclass(activation, nn.Module)
            and activation is not nn.ReLU
        ):
            return activation()
        # Legacy behavior: every other non-None value selected an in-place ReLU.
        return nn.ReLU(inplace=True)

    def forward(self, x):
        """Resize ``x`` to ``size``/``scale_factor`` and run the conv stack."""
        x = F.interpolate(
            x, size=self.size, scale_factor=self.scale_factor, mode=self.mode
        )
        return self.conv(x)
class ResBlockEncBasic(nn.Module):
    """Basic two-convolution residual block for the encoder.

    A basic residual block commonly used in ResNet architectures like ResNet18
    and ResNet34. It consists of two convolutional layers with a ReLU
    activation function in between. The (optionally projected) input is added
    to the output of the second convolutional layer, followed by a ReLU
    activation function. Optionally, the block can be conditioned on a
    timestep/class embedding: the embedding is projected to a per-channel
    scale and shift (FiLM) that is applied to the output of the first
    convolution before normalization and activation.
    """

    expansion = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        norm: str,
        timestep: int = None,
        kernel_size: int = 3,
    ) -> None:
        """
        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        stride : int
            Stride of the first convolutional layer.
        norm : str
            Normalization layer to use. Options are "batch", "instance",
            "layer", "rms", and "group" ("group" uses 32 groups).
        timestep : int, optional
            Dimension of class labels/timesteps for conditioning. If None, no
            conditioning layers are created. Default is None.
        kernel_size : int, optional
            Kernel size for convolutional layers; None is treated as 3.
            Default is 3.

        Raises
        ------
        KeyError
            If ``norm`` is not one of the supported names.
        """
        super().__init__()

        kernel_size = 3 if kernel_size is None else kernel_size
        padding = self._same_padding(kernel_size)

        normalization = {
            "batch": nn.BatchNorm2d,
            "instance": nn.InstanceNorm2d,
            "layer": LayerNorm,
            "rms": RMSNorm,
            "group": nn.GroupNorm,
        }

        def make_norm(channels):
            # GroupNorm is the only option whose constructor takes a group count.
            norm_cls = normalization[norm]
            if norm == "group":
                return norm_cls(32, channels)
            return norm_cls(channels)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            padding=padding,
            kernel_size=kernel_size,
        )
        self.norm1 = make_norm(out_channels)
        self.activation1 = nn.ReLU(inplace=True)

        if timestep is not None:
            # Projects the conditioning vector to 2 * C values: scale and shift.
            self.time_mlp = nn.Sequential(
                nn.SiLU(inplace=False),
                nn.Linear(timestep, out_channels * 2),
            )

        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                stride=1,
                padding=padding,
                kernel_size=kernel_size,
            ),
            make_norm(out_channels),
        )

        if stride > 1 or in_channels != out_channels:
            # Projection shortcut so the identity matches the main branch shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=stride,
                    padding=0,
                ),
                make_norm(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

        self.activation = nn.ReLU(inplace=True)

    @staticmethod
    def _same_padding(kernel_size):
        """Return the padding that keeps the spatial size for an odd kernel at stride 1.

        Gives 1, 2 and 3 for kernel sizes 3, 5 and 7 (the previously
        hard-coded table) and extends to other odd sizes. Non-scalar kernel
        specifications fall back to the former default of 1.
        """
        try:
            return int(kernel_size) // 2
        except (TypeError, ValueError):
            return 1

    def forward(self, data, timestep=None):
        """Apply the residual block.

        ``timestep`` may only be passed when the block was constructed with a
        ``timestep`` dimension, since ``time_mlp`` exists only in that case.
        """
        # Set up the shortcut connection if necessary
        identity = self.shortcut(data)

        # First convolution
        data = self.conv1(data)

        # Add timestep conditioning if provided using FiLM
        if timestep is not None:
            timestep = self.time_mlp(timestep)
            timestep = rearrange(timestep, "b c -> b c 1 1")
            # See the following link for explanation of scale, shift for
            # timestep/class conditioning
            # https://distill.pub/2018/feature-wise-transformations/
            scale, shift = timestep.chunk(2, dim=1)
            data = data * (scale + 1) + shift

        # Add normalization and activation function after timestep conditioning
        data = self.norm1(data)
        data = self.activation1(data)

        # Second convolution
        data = self.conv2(data)

        # Add the input and run it through activation function
        data += identity
        return self.activation(data)
class ResBlockDecBasic(nn.Module):
    """Basic two-convolution residual block for the decoder.

    A basic residual block commonly used in ResNet architectures like ResNet18
    and ResNet34. It consists of a shape-preserving convolution followed by a
    second convolution that, when ``stride > 1``, is preceded by an
    interpolation layer (upsampling). The (optionally resized/projected) input
    is added to the output of the second convolutional layer, followed by a
    ReLU activation function. This block is used in the decoder part of the
    network.
    """

    contraction = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        norm: str,
        last_layer=None,
        kernel_size: int = None,
    ) -> None:
        """
        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        stride : int
            Upsampling factor. Values greater than 1 upsample by that factor
            in both the main branch and the shortcut.
        norm : str
            Normalization layer to use. Options are "batch", "instance",
            "layer", and "group" ("group" uses 32 groups).
        last_layer : optional
            Accepted for interface compatibility; not used by this block.
        kernel_size : int, optional
            Kernel size for convolutional layers; None is treated as 3.
            Default is None.

        Raises
        ------
        KeyError
            If ``norm`` is not one of the supported names.
        """
        super().__init__()

        kernel_size = 3 if kernel_size is None else kernel_size
        padding = self._same_padding(kernel_size)

        normalization = {
            "batch": nn.BatchNorm2d,
            "instance": nn.InstanceNorm2d,
            "layer": LayerNorm,
            "group": nn.GroupNorm,
        }
        norm_cls = normalization[norm]

        def make_norm(channels):
            # GroupNorm is the only option whose constructor takes a group count.
            return norm_cls(32, channels) if norm == "group" else norm_cls(channels)

        # First convolution which doesn't change the shape of the tensor
        # (b, c, h, w) -> (b, c, h, w), stride = 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                stride=1,
                padding=padding,
                kernel_size=kernel_size,
            ),
            make_norm(in_channels),
            nn.ReLU(inplace=True),
        )

        if stride > 1:
            # Upsample (and change channels) in both branches so they line up.
            self.conv2 = ResizeConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                norm=norm_cls,
                activation=None,
                padding=padding,
                scale_factor=stride,
                mode="nearest",
            )
            self.shortcut = ResizeConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                norm=norm_cls,
                activation=None,
                padding=0,
                scale_factor=stride,
                mode="nearest",
            )
        else:
            self.conv2 = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    stride=1,
                    padding=padding,
                    kernel_size=kernel_size,
                ),
                make_norm(out_channels),
            )
            if in_channels != out_channels:
                # Without a projection the residual add would mix tensors with
                # different channel counts; mirror the encoder block and use a
                # 1x1 convolution on the shortcut.
                self.shortcut = nn.Sequential(
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                    ),
                    make_norm(out_channels),
                )
            else:
                self.shortcut = nn.Sequential()

        self.activation = nn.ReLU(inplace=True)

    @staticmethod
    def _same_padding(kernel_size):
        """Return the padding that keeps the spatial size for an odd kernel at stride 1.

        Gives 1, 2 and 3 for kernel sizes 3, 5 and 7 (the previously
        hard-coded table) and extends to other odd sizes. Non-scalar kernel
        specifications fall back to the former default of 1.
        """
        try:
            return int(kernel_size) // 2
        except (TypeError, ValueError):
            return 1

    def forward(self, data):
        """Apply the residual block to ``data``."""
        # Setup the shortcut connection if necessary
        identity = self.shortcut(data)

        # First convolution of the data
        data = self.conv1(data)

        # Second convolution of the data
        data = self.conv2(data)

        # Connect the input data to the output of convolutions
        data += identity

        # Run it through the activation function
        return self.activation(data)
class ResBlockEncBottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions) for the encoder.

    Each convolution is followed by BatchNorm2d and ReLU. The 3x3 convolution
    carries the stride, and the last 1x1 convolution widens the output to
    ``int(out_channels * expansion)`` channels. The shortcut is a strided 1x1
    convolution with BatchNorm2d whenever the input shape would not match the
    main branch, and an empty ``nn.Sequential`` otherwise.
    """

    expansion = 4

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        expansion=4,
    ) -> None:
        super().__init__()

        self.expansion = expansion
        widened = int(out_channels * self.expansion)

        # 1x1 convolution: in_channels -> out_channels
        squeeze = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        self.conv1 = nn.Sequential(
            squeeze, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)
        )

        # 3x3 convolution carrying the stride
        spatial = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
        )
        self.conv2 = nn.Sequential(
            spatial, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)
        )

        # 1x1 convolution: out_channels -> widened
        expand = nn.Conv2d(
            in_channels=out_channels, out_channels=widened, kernel_size=1
        )
        self.conv3 = nn.Sequential(
            expand, nn.BatchNorm2d(widened), nn.ReLU(inplace=True)
        )

        shapes_match = stride == 1 and in_channels == widened
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(
                in_channels=in_channels,
                out_channels=widened,
                kernel_size=1,
                stride=stride,
            )
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(widened))

        self.activation = nn.ReLU(inplace=True)

    def forward(self, data):
        """Run the three convolutions, add the shortcut and apply ReLU."""
        residual = self.shortcut(data)
        features = self.conv3(self.conv2(self.conv1(data)))
        return self.activation(features + residual)
class ResBlockDecBottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions) for the decoder.

    Each convolution is followed by BatchNorm2d and ReLU. When ``stride != 1``
    the middle layer is a transposed convolution that upsamples. The output
    width of the final 1x1 convolution is ``out_channels`` times:

    * 1 when ``last_layer`` is truthy,
    * ``contraction`` when ``stride == 1`` (and not the last layer),
    * ``int(contraction / 2)`` otherwise.

    The shortcut is a 1x1 transposed convolution with BatchNorm2d when
    ``stride != 1`` or ``last_layer`` is truthy, and an empty
    ``nn.Sequential`` otherwise.
    """

    contraction = 4

    def __init__(
        self, in_channels, out_channels, stride, contraction=4, last_layer=False
    ) -> None:
        super().__init__()

        self.contraction = contraction
        upsampling = stride != 1

        # 1x1 convolution: in_channels -> out_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # 3x3 layer: transposed convolution when upsampling, plain otherwise
        if upsampling:
            middle = nn.ConvTranspose2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride,
                output_padding=1,
                padding=1,
            )
        else:
            middle = nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        self.conv2 = nn.Sequential(
            middle, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)
        )

        # Width multiplier of the final 1x1 convolution
        if last_layer:
            factor = 1
        elif stride == 1:
            factor = self.contraction
        else:
            factor = int(self.contraction / 2)

        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels * factor,
                kernel_size=1,
            ),
            nn.BatchNorm2d(out_channels * factor),
            nn.ReLU(inplace=True),
        )

        if upsampling or last_layer:
            shortcut_factor = 1 if last_layer else int(self.contraction / 2)
            shortcut_channels = int(out_channels * shortcut_factor)
            self.shortcut = nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=shortcut_channels,
                    kernel_size=1,
                    stride=stride,
                    output_padding=1 if stride > 1 else 0,
                ),
                nn.BatchNorm2d(shortcut_channels),
            )
        else:
            self.shortcut = nn.Sequential()

        self.activation = nn.ReLU(inplace=True)

    def forward(self, data):
        """Run the three convolutions, add the shortcut and apply ReLU."""
        residual = self.shortcut(data)
        features = self.conv3(self.conv2(self.conv1(data)))
        return self.activation(features + residual)
def instance_norm(features, eps=1e-6, **kwargs):
    """Build an affine ``nn.InstanceNorm2d`` for ``features`` channels.

    Extra keyword arguments are forwarded to ``nn.InstanceNorm2d``.
    """
    layer = nn.InstanceNorm2d(features, eps=eps, affine=True, **kwargs)
    return layer
def layer_norm(out_channels, starting_dimension, **kwargs):
    """Build an ``nn.LayerNorm`` over ``[out_channels, d, d]``.

    The spatial size ``d`` is ``starting_dimension`` divided by
    ``4 * (out_channels / 64)`` and truncated to an int. Extra keyword
    arguments are accepted but not used.
    """
    side = int(starting_dimension / (4 * (out_channels / 64)))
    normalized_shape = [out_channels, side, side]
    return nn.LayerNorm(normalized_shape)
class vanilla_Encoder(nn.Module):
    """Plain convolutional encoder: a stack of Conv2d -> BatchNorm2d -> ReLU stages.

    One strided stage is created per entry of ``out_channels``, followed by a
    final 3x3, stride-1, unpadded stage that keeps ``out_channels[-1]``
    channels. The same ``kernel_size`` and ``stride`` are used for every
    strided stage; padding is 2 for kernel size 5, 3 for kernel size 7 and 1
    otherwise.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride) -> None:
        super().__init__()

        if kernel_size == 5:
            padding = 2
        elif kernel_size == 7:
            padding = 3
        else:
            padding = 1

        def stage(conv, width):
            # Every stage shares the conv -> batch norm -> ReLU layout.
            return nn.Sequential(conv, nn.BatchNorm2d(width), nn.ReLU(inplace=True))

        stages = []
        current = in_channels
        for width in out_channels:
            strided = nn.Conv2d(
                current,
                out_channels=width,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
            stages.append(stage(strided, width))
            current = width

        # Final stage: fixed 3x3 kernel, no padding, channel count unchanged.
        final_width = out_channels[-1]
        head = nn.Conv2d(
            in_channels=current,
            out_channels=final_width,
            kernel_size=3,
            stride=1,
            padding=0,
        )
        stages.append(stage(head, final_width))

        self.encoder = nn.Sequential(*stages)

    def forward(self, data):
        """Encode ``data`` with the convolutional stack."""
        return self.encoder(data)
class vanilla_Decoder(nn.Module):
    """Plain transposed-convolution decoder.

    The stack starts with a 3x3, stride-1, unpadded transposed convolution on
    ``out_channels[0]`` channels, then one strided transposed-convolution
    stage per consecutive pair in ``out_channels`` (each with BatchNorm2d and
    ReLU), and ends with a stride-1 Conv2d + ReLU output layer on
    ``out_channels[-1]`` channels. ``in_channels`` is accepted but not used.
    Padding is 2 for kernel size 5, 3 for kernel size 7 and 1 otherwise.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride) -> None:
        super().__init__()

        if kernel_size == 5:
            padding = 2
        elif kernel_size == 7:
            padding = 3
        else:
            padding = 1

        def stage(layer, width):
            # Transposed conv -> batch norm -> ReLU.
            return nn.Sequential(layer, nn.BatchNorm2d(width), nn.ReLU(inplace=True))

        first_width = out_channels[0]
        stem = nn.ConvTranspose2d(
            in_channels=first_width,
            out_channels=first_width,
            kernel_size=3,
            stride=1,
            padding=0,
        )
        stages = [stage(stem, first_width)]

        for idx in range(1, len(out_channels)):
            source, target = out_channels[idx - 1], out_channels[idx]
            upsample = nn.ConvTranspose2d(
                source,
                target,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=1,
            )
            stages.append(stage(upsample, target))

        # Final output layer
        last_width = out_channels[-1]
        stages.append(
            nn.Sequential(
                nn.Conv2d(
                    in_channels=last_width,
                    out_channels=last_width,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                ),
                nn.ReLU(inplace=True),
            )
        )

        self.decoder = nn.Sequential(*stages)

    def forward(self, data):
        """Decode ``data`` with the transposed-convolution stack."""
        return self.decoder(data)
class DownsampleBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU block, typically used with ``stride > 1``.

    Padding is 2 for kernel size 5, 3 for kernel size 7 and 1 otherwise.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()

        if kernel_size == 5:
            padding = 2
        elif kernel_size == 7:
            padding = 3
        else:
            padding = 1

        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, data):
        """Apply the convolution, batch norm and ReLU to ``data``."""
        return self.conv(data)
class UpsampleBlock(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> ReLU block with ``output_padding=1``.

    Padding is 2 for kernel size 5, 3 for kernel size 7 and 1 otherwise.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()

        if kernel_size == 5:
            padding = 2
        elif kernel_size == 7:
            padding = 3
        else:
            padding = 1

        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=1,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv_transpose = nn.Sequential(*layers)

    def forward(self, data):
        """Apply the transposed convolution, batch norm and ReLU to ``data``."""
        return self.conv_transpose(data)