# Source code for starling.models.unet

import math
from typing import List, Sequence

import torch
from torch import nn

from starling.models.blocks import ResBlockEncBasic, ResizeConv2d
from starling.models.normalization import RMSNorm
from starling.models.transformer import SpatialTransformer


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim: int, theta: int = 10000):
        """
        Generates sinusoidal positional embeddings used by the denoising-diffusion
        models to encode the timestep. Half of the output columns are sines and
        half are cosines of the timestep multiplied by geometrically spaced
        frequencies (from 1 down to 1 / theta).

        Parameters
        ----------
        dim : int
            Number of embedding features returned per timestep. If ``dim`` is
            odd, the last column is zero-padded so the output still has exactly
            ``dim`` features.
        theta : int, optional
            Base of the geometric frequency progression; the lowest frequency
            is ``1 / theta``. The default value is 10000.
        """
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the positional (timestep) embeddings.

        Parameters
        ----------
        time : torch.Tensor
            Timesteps as a 1-D tensor of shape (batch_size,). Note that a
            (batch_size, 1) input broadcasts to (batch_size, 1, dim).

        Returns
        -------
        torch.Tensor
            Positional (timestep) embeddings of shape (batch_size, dim) for a
            1-D input: ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]``.
        """
        device = time.device

        # Number of distinct frequencies: half of the columns are used for
        # sine and the other half for cosine.
        half_dim = self.dim // 2

        # Spacing between successive log-frequencies. The max() guard avoids a
        # ZeroDivisionError when half_dim == 1 (dim of 2 or 3); with a single
        # frequency the spacing is irrelevant because arange(1) == [0].
        emb = math.log(self.theta) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = time[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)

        # sin/cos pairs always produce an even width; pad one zero column for an
        # odd dim so downstream layers sized with `dim` receive `dim` features.
        if self.dim % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1))
        return emb
class ConditionalSequential(nn.Sequential):
    """Sequential container whose children all receive the same condition.

    Each child is invoked as ``child(out, condition)``, where ``out`` is the
    previous child's result (``x`` for the first child). With no children,
    ``x`` is returned unchanged.
    """

    def forward(self, x, condition):
        out = x
        # Iterating an nn.Sequential yields its registered child modules in order.
        for child in self:
            out = child(out, condition)
        return out
class Downsample(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, norm: str, num_groups: int = 32):
        """
        A convolutional block that halves the spatial dimensions of the input
        tensor: a 3x3 convolution with stride 2 and padding 1, followed by a
        normalization layer and an (in-place) ReLU activation.

        Parameters
        ----------
        in_channels : int
            The number of features in the input tensor.
        out_channels : int
            The number of features in the output tensor.
        norm : str
            The normalization layer to be used in the block. Choose from batch,
            instance, rms, or group. Any other value raises ``KeyError``.
        num_groups : int, optional
            Requested number of groups when ``norm == "group"``. GroupNorm
            requires ``out_channels`` to be divisible by the group count, so
            ``gcd(num_groups, out_channels)`` is used. This equals
            ``num_groups`` whenever ``out_channels`` is a multiple of it. The
            default is 32.
        """
        super().__init__()
        normalization = {
            "batch": nn.BatchNorm2d,
            "instance": nn.InstanceNorm2d,
            "rms": RMSNorm,
            "group": nn.GroupNorm,
        }

        # Build the convolution first so parameter-initialisation order (and
        # therefore seeded RNG consumption) is unchanged.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

        if norm == "group":
            # A hard-coded 32 groups crashes for widths such as 16 or 48; fall
            # back to the largest group count that divides out_channels.
            norm_layer = normalization[norm](math.gcd(num_groups, out_channels), out_channels)
        else:
            norm_layer = normalization[norm](out_channels)

        self.conv = nn.Sequential(conv, norm_layer, nn.ReLU(inplace=True))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> norm -> ReLU; spatial size becomes ceil(H/2) x ceil(W/2)."""
        x = self.conv(x)
        return x
class ResnetLayer(nn.Module):
    """A stack of ``num_blocks`` ResBlockEncBasic blocks applied in sequence.

    The first block maps ``in_channels`` to ``out_channels``; every later block
    keeps ``out_channels``. Each block is built as
    ``ResBlockEncBasic(c_in, out_channels, 1, norm, timestep_dim, class_dim)``
    and called as ``block(x, time)``. The meaning of the positional ``1`` and
    of ``class_dim`` is defined by ResBlockEncBasic (TODO confirm; presumably
    stride and optional class-conditioning width).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        norm,
        num_blocks,
        timestep_dim,
        class_dim=None,
    ):
        super().__init__()
        # Input width of each block: only the first one sees in_channels.
        input_widths = [in_channels if idx == 0 else out_channels for idx in range(num_blocks)]
        self.layer = nn.ModuleList(
            [
                ResBlockEncBasic(width, out_channels, 1, norm, timestep_dim, class_dim)
                for width in input_widths
            ]
        )
        # Width a subsequent block would receive: out_channels once any block
        # exists, otherwise the untouched in_channels.
        self.in_channels = out_channels if input_widths else in_channels

    def forward(self, x, time):
        """Run ``x`` through every block in order, passing ``time`` to each."""
        for res_block in self.layer:
            x = res_block(x, time)
        return x
class CrossAttentionResnetLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm: str,
        num_blocks: int,
        attention_heads: int,
        timestep_dim: int,
        label_dim: int,
    ):
        """
        ``num_blocks`` pairs of (ResNet block, spatial transformer). In each
        pair the ResNet block processes the feature map using the timestep
        embedding. The spatial transformer then attends from the feature map
        to the context data (protein sequences).

        Parameters
        ----------
        in_channels : int
            Features in the input tensor (consumed by the first ResNet block only).
        out_channels : int
            Features produced by every block in the layer.
        norm : str
            Normalization used inside the ResNet blocks: batch, instance, rms,
            or group.
        num_blocks : int
            Number of ResNet + spatial transformer pairs.
        attention_heads : int
            Number of heads in each spatial transformer's attention.
        timestep_dim : int
            Dimension of the timestep embeddings.
        label_dim : int
            Dimension of the context data (protein sequences).
        """
        super().__init__()
        res_blocks = []
        attn_blocks = []
        # Construct the two kinds of module interleaved (res, attn, res, attn, ...)
        # so parameter initialisation consumes the RNG in the same order.
        for idx in range(num_blocks):
            width = in_channels if idx == 0 else out_channels
            res_blocks.append(ResBlockEncBasic(width, out_channels, 1, norm, timestep_dim))
            attn_blocks.append(SpatialTransformer(out_channels, attention_heads, label_dim))
        self.layer = nn.ModuleList(res_blocks)
        self.transformer = nn.ModuleList(attn_blocks)
        # Width a subsequent block would receive.
        self.in_channels = out_channels if num_blocks > 0 else in_channels

    def forward(
        self,
        x: torch.Tensor,
        time: torch.Tensor,
        sequence_label: torch.Tensor,
        sequence_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply each (ResNet block, spatial transformer) pair in order.

        Parameters
        ----------
        x : torch.Tensor
            Feature map to process.
        time : torch.Tensor
            Timestep embeddings, passed to every ResNet block.
        sequence_label : torch.Tensor
            Context data (protein sequences), passed to every spatial
            transformer as ``context``.
        sequence_mask : torch.Tensor
            Passed to every spatial transformer as ``mask``; its exact
            semantics are defined by SpatialTransformer (TODO confirm).

        Returns
        -------
        torch.Tensor
            Output of the last spatial transformer (or ``x`` if num_blocks == 0).
        """
        for res_block, attn in zip(self.layer, self.transformer):
            x = attn(res_block(x, time), context=sequence_label, mask=sequence_mask)
        return x
class UNetConditional(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        base: int,
        norm: str,
        blocks: Sequence[int] = (2, 2, 2),
        middle_blocks: int = 2,
        labels_dim: int = 512,
        sinusoidal_pos_emb_theta: int = 10000,
        attention_heads: int = 8,
    ):
        """
        A U-Net built from ResNet + spatial-transformer layers. It processes the
        input tensor conditioned on a timestep and on context data (protein
        sequences).

        Structure:

        - Encoder: ``conv_in`` plus three levels, each followed by a stride-2
          downsample. Channel widths are base, 2*base, 4*base, then 8*base
          after the last downsample.
        - Middle layer.
        - Decoder: three upsample + skip-concatenate + layer stages.
        - A 1x1 output convolution.

        Parameters
        ----------
        in_channels : int
            The number of features in the input tensor.
        out_channels : int
            The number of features in the output tensor.
        base : int
            The base number of features; level ``i`` uses ``base * 2**i``.
            Also used as the sinusoidal timestep-embedding width; the timestep
            MLP outputs ``4 * base`` features.
        norm : str
            The normalization layer to be used in the blocks. Choose from
            batch, instance, rms, or group.
        blocks : Sequence[int], optional
            Number of ResNet + spatial transformer blocks per resolution level.
            At least three entries are required (only the first three are
            used). By default (2, 2, 2).
        middle_blocks : int, optional
            Number of ResNet + spatial transformer blocks in the middle
            section, by default 2.
        labels_dim : int, optional
            The dimension of the context data (i.e., protein sequences), by
            default 512.
        sinusoidal_pos_emb_theta : int, optional
            Frequency base for the positional (timestep) embeddings, by default
            10000.
        attention_heads : int, optional
            Number of attention heads in every spatial transformer, by default 8.

        Raises
        ------
        ValueError
            If ``blocks`` has fewer than three entries.
        """
        super().__init__()

        if len(blocks) < 3:
            raise ValueError(
                "UNetConditional has three resolution levels; 'blocks' must "
                f"have at least 3 entries, got {len(blocks)}"
            )

        normalization = {
            "batch": nn.BatchNorm2d,
            "instance": nn.InstanceNorm2d,
            "rms": RMSNorm,
            "group": nn.GroupNorm,
        }

        self.norm = norm
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_dim = base * 4
        self.base = base
        self.labels_dim = labels_dim
        self.attention_heads = attention_heads

        # Time embeddings: sinusoidal features of width `base`, then a small MLP.
        self.time_emb = SinusoidalPosEmb(self.base, theta=sinusoidal_pos_emb_theta)
        self.time_mlp = nn.Sequential(
            self.time_emb,
            nn.Linear(self.base, self.time_dim),
            nn.SiLU(inplace=False),
            nn.Linear(self.time_dim, self.time_dim),
        )

        all_in_channels = [base * (2**i) for i in range(len(blocks) + 1)]
        c0, c1, c2, c3 = all_in_channels[:4]

        def cross_attention_layer(in_ch: int, out_ch: int, num_blocks: int) -> nn.Module:
            # Every U-Net stage shares norm, head count, timestep and label widths.
            return CrossAttentionResnetLayer(
                in_ch,
                out_ch,
                self.norm,
                num_blocks,
                attention_heads,
                self.time_dim,
                self.labels_dim,
            )

        def upsample(in_ch: int, out_ch: int) -> nn.Module:
            # Decoder upsampling stage (x2 spatial, 3x3 conv, norm, ReLU).
            return ResizeConv2d(
                in_ch,
                out_ch,
                kernel_size=3,
                padding=1,
                scale_factor=2,
                norm=normalization[norm],
                activation="relu",
            )

        # Encoder part of UNet. Submodules are created in the same order as
        # before so state_dict keys and seeded initialisation are unchanged.
        self.conv_in = cross_attention_layer(in_channels, c0, blocks[0])
        self.encoder_layer1 = cross_attention_layer(c0, c0, blocks[0])
        self.downsample1 = Downsample(c0, c1, norm)
        self.encoder_layer2 = cross_attention_layer(c1, c1, blocks[1])
        self.downsample2 = Downsample(c1, c2, norm)
        self.encoder_layer3 = cross_attention_layer(c2, c2, blocks[2])
        self.downsample3 = Downsample(c2, c3, norm)

        # Middle of the UNet
        self.middle = cross_attention_layer(c3, c3, middle_blocks)

        # Decoder part of UNet; inputs are doubled by the skip concatenation.
        self.upconv1 = upsample(c3, c2)
        self.decoder_layer1 = cross_attention_layer(c2 * 2, c2, blocks[2])
        self.upconv2 = upsample(c2, c1)
        self.decoder_layer2 = cross_attention_layer(c1 * 2, c1, blocks[1])
        self.upconv3 = upsample(c1, c0)
        # NOTE(review): this level mirrors encoder_layer1 (built with blocks[0])
        # but uses blocks[1]. That looks like a typo, but it is kept so existing
        # checkpoints still load. It makes no difference when
        # blocks[0] == blocks[1], as with the default.
        self.decoder_layer3 = cross_attention_layer(c0 * 2, c0, blocks[1])

        self.conv_out = nn.Conv2d(c0, out_channels, kernel_size=1)

    def forward(
        self,
        x: torch.Tensor,
        time: torch.Tensor,
        labels: torch.Tensor,
        sequence_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the UNet architecture.

        Parameters
        ----------
        x : torch.Tensor
            Data to pass through the UNet. Three stride-2 downsamples are
            followed by three x2 upsamples whose outputs are concatenated with
            the encoder features. The spatial size therefore presumably must be
            divisible by 8 (TODO confirm against ResizeConv2d).
        time : torch.Tensor
            Raw diffusion timesteps. They are embedded here via ``self.time_mlp``
            (sinusoidal embedding + MLP).
        labels : torch.Tensor
            Context data (protein sequences) used by every spatial transformer
            to guide the prediction. Required.
        sequence_mask : torch.Tensor
            Mask forwarded to every spatial transformer; its exact semantics are
            defined by SpatialTransformer (TODO confirm).

        Returns
        -------
        torch.Tensor
            Output of the UNet with ``out_channels`` features.
        """
        time_emb = self.time_mlp(time)

        # Initial layer
        x = self.conv_in(x, time_emb, labels, sequence_mask)

        # Encoder: keep each level's output for the decoder skip connections.
        # No clone is needed, since nothing modifies these tensors in place:
        # downsampling starts with a convolution that allocates a new tensor,
        # and torch.cat copies.
        skip1 = self.encoder_layer1(x, time_emb, labels, sequence_mask)
        x = self.downsample1(skip1)
        skip2 = self.encoder_layer2(x, time_emb, labels, sequence_mask)
        x = self.downsample2(skip2)
        skip3 = self.encoder_layer3(x, time_emb, labels, sequence_mask)
        x = self.downsample3(skip3)

        # Mid UNet
        x = self.middle(x, time_emb, labels, sequence_mask)

        # Decoder with skip connections from the matching encoder level
        x = torch.cat((self.upconv1(x), skip3), dim=1)
        x = self.decoder_layer1(x, time_emb, labels, sequence_mask)
        x = torch.cat((self.upconv2(x), skip2), dim=1)
        x = self.decoder_layer2(x, time_emb, labels, sequence_mask)
        x = torch.cat((self.upconv3(x), skip1), dim=1)
        x = self.decoder_layer3(x, time_emb, labels, sequence_mask)

        # Final 1x1 convolution to the requested number of output features
        return self.conv_out(x)