import math
import torch
import torch.nn as nn
from einops import rearrange, repeat
from torch import nn
from starling.data.positional_encodings import (
PositionalEncoding1D,
PositionalEncoding2D,
)
from starling.models.attention import CrossAttention, MultiHeadAttention, SelfAttention
[docs]
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a scalar conditioning value (e.g. a diffusion timestep)."""

    def __init__(self, dim: int, theta: int = 10000):
        """
        Store the configuration of the sinusoidal embedding. The module has no
        learnable parameters.

        Parameters
        ----------
        dim : int
            Requested embedding width. ``dim // 2`` frequencies are generated and
            each contributes one sine and one cosine feature, so the produced
            width is ``2 * (dim // 2)`` (equal to ``dim`` when ``dim`` is even).
        theta : int, optional
            Base of the geometric frequency progression: frequencies run from 1
            down to ``1 / theta``. The default value is 10000.
        """
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """
        Compute the sinusoidal embedding of ``time``.

        Parameters
        ----------
        time : torch.Tensor
            Values to embed. A new axis is inserted after the first axis and
            broadcast against the frequencies, so an input of shape
            (batch_size,) yields (batch_size, dim), while an input of shape
            (batch_size, 1) yields (batch_size, 1, dim).

        Returns
        -------
        torch.Tensor
            Sine features followed by cosine features along the last axis.
        """
        device = time.device
        # Half of the output features are sines, the other half cosines.
        half_dim = self.dim // 2
        # Log-spacing between consecutive frequencies. With a single frequency
        # there is nothing to space out, so the denominator is clamped to 1
        # instead of dividing by zero (dim == 2 or dim == 3).
        log_step = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
[docs]
class MLP(nn.Module):
    """Single-hidden-layer perceptron: ``Linear -> ReLU -> Linear``."""

    def __init__(self, input_dim: int, output_dim: int, expansion_factor: int = 4):
        """
        Build the two linear projections and the ReLU between them. The hidden
        width is ``output_dim * expansion_factor``. No normalization is applied
        inside this module.

        Parameters
        ----------
        input_dim : int
            The dimension of the input features.
        output_dim : int
            The dimension of the output features.
        expansion_factor : int, optional
            Multiplier applied to ``output_dim`` to obtain the hidden width,
            by default 4.
        """
        super().__init__()
        width = output_dim * expansion_factor
        stages = [
            nn.Linear(input_dim, width),
            nn.ReLU(),
            nn.Linear(width, output_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last axis ``input_dim``) to ``output_dim`` features."""
        projected = self.net(x)
        return projected
[docs]
class AdaLayerNorm(nn.Module):
    """Layer normalization whose scale and shift are predicted from a conditioning vector."""

    def __init__(self, embed_dim, cond_dim):
        """
        Parameters
        ----------
        embed_dim : int
            Size of the last axis of the tensor being normalized.
        cond_dim : int
            Size of the last axis of the conditioning vector.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim

        # The conditioning network emits scale and shift side by side, hence
        # twice the embedding width.
        modulation_dim = embed_dim * 2
        self.cond_mlp = nn.Sequential(
            nn.Linear(cond_dim, modulation_dim),
            nn.SiLU(),
            nn.Linear(modulation_dim, modulation_dim),
        )

        # Plain normalization only; the affine part comes from ``cond_mlp``.
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)

    def forward(self, x, cond):
        """
        Normalize ``x`` and modulate it with a conditioning-dependent affine map.

        Parameters
        ----------
        x : torch.Tensor
            Token embeddings, expected as (B, N, embed_dim).
        cond : torch.Tensor
            Conditioning vector, expected as (B, cond_dim).

        Returns
        -------
        torch.Tensor
            ``scale * LayerNorm(x) + shift`` with the same shape as ``x``.
        """
        # First half of the projection is the scale, second half the shift.
        scale, shift = self.cond_mlp(cond).chunk(2, dim=-1)
        normalized = self.norm(x)
        # Insert an axis at position 1 so the per-sample modulation broadcasts
        # across that axis of ``x``.
        return scale.unsqueeze(1) * normalized + shift.unsqueeze(1)
[docs]
class GeGLU(nn.Module):
    """GELU-gated linear unit: ``value * GELU(gate)``."""

    def __init__(self, d_in: int, d_out: int):
        """
        Gated activation in which a single linear layer produces both a value
        and a gate. The gate is passed through GELU and multiplied elementwise
        with the value, so the gate controls how much of each value feature is
        let through.

        Parameters
        ----------
        d_in : int
            The input dimension of the data. Used to initialize the linear layer.
        d_out : int
            The output dimension of the data. The linear layer emits
            ``2 * d_out`` features: one half for the value, one for the gate.
        """
        super().__init__()
        self.proj = nn.Linear(d_in, 2 * d_out)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated projection of ``x``; the last axis has size ``d_out``."""
        # First half of the projection is the value, second half the gate.
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * self.gelu(gate)
[docs]
class FeedForward(nn.Module):
    """Transformer feed-forward block: GeGLU expansion followed by a linear contraction."""

    def __init__(self, embed_dim: int):
        """
        The block widens the features to ``4 * embed_dim`` through a GeGLU
        layer and then projects them back to ``embed_dim`` with a linear layer.

        Parameters
        ----------
        embed_dim : int
            The input (and output) dimension of the data.
        """
        super().__init__()
        inner_dim = embed_dim * 4
        expand = GeGLU(embed_dim, inner_dim)
        contract = nn.Linear(inner_dim, embed_dim)
        self.net = nn.Sequential(expand, contract)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block; the last axis keeps size ``embed_dim``."""
        return self.net(x)
[docs]
class SequenceEncoder(nn.Module):
    """Encoder for token sequences, conditioned on an embedded ionic strength."""

    def __init__(
        self,
        num_layers: int,
        embed_dim: int,
        num_heads: int,
        vocab_size: int = 21,
        ionic_mask_prob: float = 0.2,
    ):
        """
        Sequence encoder. Tokens are embedded with a learned table, positional
        encodings are applied, and the result is passed through a stack of
        ``TransformerEncoder`` layers followed by a final layer normalization.
        Each layer also receives an embedding of the ionic strength.

        Parameters
        ----------
        num_layers : int
            The number of layers in the transformer encoder.
        embed_dim : int
            The embedding dimension used by every sub-module.
        num_heads : int
            The number of attention heads passed to each ``TransformerEncoder``.
        vocab_size : int, optional
            Number of distinct token ids accepted by the learned embedding
            table, by default 21.
        ionic_mask_prob : float, optional
            Probability, per batch entry and only in training mode, of
            replacing the ionic strength embedding with zeros, by default 0.2.

        Raises
        ------
        ValueError
            If ``ionic_mask_prob`` is outside ``[0, 1]``.
        """
        super().__init__()
        if not 0.0 <= ionic_mask_prob <= 1.0:
            raise ValueError(
                f"ionic_mask_prob must be in [0, 1], got {ionic_mask_prob}"
            )
        self.ionic_mask_prob = ionic_mask_prob

        self.ionic_strength_emb = nn.Sequential(
            SinusoidalPosEmb(embed_dim),
            MLP(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        self.sequence_learned_embedding = nn.Embedding(vocab_size, embed_dim)
        self.sequence_positional_encoding = PositionalEncoding1D(embed_dim)

        self.layers = nn.ModuleList(
            [TransformerEncoder(embed_dim, num_heads) for _ in range(num_layers)]
        )

        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, mask, ionic_strengths) -> torch.Tensor:
        """
        Encode a batch of token sequences.

        Parameters
        ----------
        x : torch.Tensor
            Integer token ids; they index the learned embedding table.
        mask
            Forwarded unchanged to every encoder layer as ``mask``; its
            semantics are defined by ``TransformerEncoder``.
        ionic_strengths : torch.Tensor
            Ionic strength values, one entry per batch element along the first
            axis. They are embedded here and forwarded to every encoder layer.

        Returns
        -------
        torch.Tensor
            The layer-normalized output of the last encoder layer.
        """
        # Embed the ionic strengths.
        ionic_strengths = self.ionic_strength_emb(ionic_strengths)

        if self.training and self.ionic_mask_prob > 0:
            # Randomly blank the conditioning of some batch entries so the
            # layers also see an all-zero ionic strength embedding in training.
            drop_rows = (
                torch.rand(ionic_strengths.shape[0], device=ionic_strengths.device)
                < self.ionic_mask_prob
            )
            ionic_strengths[drop_rows] = 0.0

        # Convert input sequence to embeddings
        x = self.sequence_learned_embedding(x)

        # Apply positional encodings to the embedded sequence
        x = self.sequence_positional_encoding(x)

        for layer in self.layers:
            x = layer(x, mask=mask, ionic_strengths=ionic_strengths)

        # Apply final normalization
        return self.final_norm(x)
[docs]
class DiTBlock(nn.Module):
    """Pre-norm transformer block: self attention, cross attention, feed forward."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        context_dim: int,
        norm_eps: float = 1e-5,
    ):
        """
        Transformer decoder layer. The layer consists of a self attention layer,
        a cross attention layer and a feed forward layer, each preceded by a
        layer normalization and wrapped in a residual connection. The cross
        attention layer attends from the input data to the context data
        (usually the transformer encoder output).

        Parameters
        ----------
        embed_dim : int
            The input dimension of the data. Used to initialize the layer norms,
            the attention layers and the feed forward layer.
        num_heads : int
            The number of heads in the self attention and cross attention layers.
        context_dim : int
            The dimension of the context data. Used to initialize the cross
            attention layer.
        norm_eps : float, optional
            ``eps`` of the three layer norms, by default 1e-5 (the
            ``nn.LayerNorm`` default). Models trained while ``context_dim`` was
            being passed positionally as ``eps`` can reproduce that behaviour
            with ``norm_eps=context_dim``.
        """
        super().__init__()

        # The second positional parameter of nn.LayerNorm is ``eps``, not a
        # conditioning dimension, so eps is always passed by keyword here.
        self.norm1 = nn.LayerNorm(embed_dim, eps=norm_eps)
        self.norm2 = nn.LayerNorm(embed_dim, eps=norm_eps)
        self.norm3 = nn.LayerNorm(embed_dim, eps=norm_eps)

        self.self_attention = MultiHeadAttention(embed_dim, num_heads)
        self.cross_attention = MultiHeadAttention(embed_dim, num_heads, context_dim)
        self.feed_forward = FeedForward(embed_dim)

    def forward(self, x: torch.Tensor, context, context_mask) -> torch.Tensor:
        """
        Apply the three residual sub-layers in order.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings; the last axis has size ``embed_dim``.
        context
            Context passed unchanged to the cross attention layer.
        context_mask
            Mask forwarded to the cross attention layer as ``context_mask``;
            its semantics are defined by ``MultiHeadAttention``.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``x``.
        """
        # Self attention on the pre-normalized input (query and context are
        # the same normalized tensor).
        x_normed = self.norm1(x)
        x = x + self.self_attention(query=x_normed, context=x_normed)

        # Cross attention: only the query is normalized here; the context is
        # used as given by the caller.
        x_normed = self.norm2(x)
        x = x + self.cross_attention(
            query=x_normed, context=context, context_mask=context_mask
        )

        # Feed forward on the pre-normalized input.
        x_normed = self.norm3(x)
        x = x + self.feed_forward(x_normed)

        return x