import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from starling.models.normalization import RMSNorm
[docs]
class MultiHeadAttention(nn.Module):
[docs]
def __init__(self, embed_dim: int, num_heads: int, context_dim: int = None):
"""
Multi-head attention module supporting both self- and cross-attention.
Parameters
----------
embed_dim : int
Dimension of the query input (and output)
num_heads : int
Number of attention heads
context_dim : int, optional
Dimension of context input. If None, defaults to `embed_dim` (i.e., self-attention).
"""
super().__init__()
self.embed_dim = embed_dim
self.context_dim = context_dim or embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
# Projections
self.query_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.key_proj = nn.Linear(self.context_dim, embed_dim, bias=False)
self.value_proj = nn.Linear(self.context_dim, embed_dim, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim)
[docs]
def forward(self, query, context, query_mask=None, context_mask=None):
"""
query: (B, N, Dq) — tokens to be conditioned
context: (B, S, Dc) — conditioning source (or None for self-attention)
"""
B, N, _ = query.shape
_, S, _ = context.shape
# Project to Q, K, V
Q = self.query_proj(query)
K = self.key_proj(context)
V = self.value_proj(context)
# Reshape for multi-head attention
Q = rearrange(Q, "b n (h d) -> b h n d", h=self.num_heads)
K = rearrange(K, "b s (h d) -> b h s d", h=self.num_heads)
V = rearrange(V, "b s (h d) -> b h s d", h=self.num_heads)
# Build attention mask (broadcasted)
if query_mask is not None or context_mask is not None:
if query_mask is None:
query_mask = torch.ones(B, N, device=query.device, dtype=torch.bool)
if context_mask is None:
context_mask = torch.ones(B, S, device=query.device, dtype=torch.bool)
attn_mask = rearrange(query_mask, "b n -> b 1 n 1") & rearrange(
context_mask, "b s -> b 1 1 s"
) # (B, 1, N, S)
attn_mask = repeat(attn_mask, "b 1 n s -> b h n s", h=self.num_heads)
attn_mask = attn_mask.bool()
else:
attn_mask = None
# Scaled dot-product attention (Fused version in PyTorch ≥2.0)
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)
# Merge heads
out = rearrange(out, "b h n d -> b n (h d)")
return self.out_proj(out)
[docs]
class CrossAttention(nn.Module):
[docs]
def __init__(
self,
embed_dim: int,
num_heads: int,
context_dim: int,
) -> None:
"""
Cross-attention between query (tokens) and context (e.g., protein sequence).
Parameters
----------
embed_dim : int
Dimensionality of the query tokens
num_heads : int
Number of attention heads
context_dim : int
Dimensionality of the context (keys/values)
"""
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.context_dim = context_dim
assert self.head_dim * num_heads == embed_dim, (
"embed_dim must be divisible by num_heads"
)
self.query_norm = nn.LayerNorm(embed_dim)
self.context_norm = nn.LayerNorm(context_dim)
self.query_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.key_proj = nn.Linear(context_dim, embed_dim, bias=False)
self.value_proj = nn.Linear(context_dim, embed_dim, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim)
[docs]
def forward(self, query, context, query_mask=None, context_mask=None):
"""
query: (B, N, D) — tokens to be conditioned
context: (B, S, C) — context (e.g., sequence embeddings)
"""
B, N, D = query.shape
_, S, _ = context.shape
# Normalize and project
query = self.query_norm(query)
context = self.context_norm(context)
Q = self.query_proj(query) # (B, N, D)
K = self.key_proj(context) # (B, S, D)
V = self.value_proj(context) # (B, S, D)
# Multi-head reshape
Q = rearrange(Q, "b n (h d) -> b h n d", h=self.num_heads)
K = rearrange(K, "b s (h d) -> b h s d", h=self.num_heads)
V = rearrange(V, "b s (h d) -> b h s d", h=self.num_heads)
# Attention masks
if query_mask is not None or context_mask is not None:
if query_mask is None:
query_mask = torch.ones((B, N), device=query.device)
if context_mask is None:
context_mask = torch.ones((B, S), device=query.device)
query_mask = rearrange(query_mask, "b n -> b 1 n 1")
context_mask = rearrange(context_mask, "b s -> b 1 1 s")
attention_mask = query_mask * context_mask # (B, 1, N, S)
attention_mask = repeat(
attention_mask, "b 1 n s -> b h n s", h=self.num_heads
)
attention_mask = attention_mask.bool()
else:
attention_mask = None
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=attention_mask)
out = rearrange(out, "b h n d -> b n (h d)") # back to (B, N, D)
return self.out_proj(out)
[docs]
class SelfAttention(nn.Module):
[docs]
def __init__(
self, embed_dim: int, num_heads: int, channels_last: bool = False
) -> None:
"""
This is a basic self-attention module. It uses linear layers to project
the input into query, key, and value matrices, then performs scaled dot-product
attention on these matrices. The output is then projected back to the original
embedding dimension. Commonly used in transformer models.
Parameters
----------
embed_dim : int
Dimension of the input embedding
num_heads : int
Number of heads for multi-head attention
channels_last : bool, optional
Whether the input has channels last format, if not it will be rearranged, by default False
"""
super(SelfAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.channels_last = channels_last
assert self.head_dim * num_heads == embed_dim, (
"embed_dim must be divisible by num_heads"
)
self.query_norm = nn.LayerNorm(embed_dim)
self.query_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.key_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.value_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim)
[docs]
def forward(self, x, attention_mask=None):
input_dim = x.dim()
if input_dim == 4:
batch_size, channels, height, width = x.size()
elif input_dim == 3:
batch_size, seq_len, channels = x.size()
else:
raise ValueError("Input dimension not supported")
if not self.channels_last and input_dim == 4:
x = rearrange(x, "b c h w -> b h w c")
# Prenormalization
x = self.query_norm(x)
# Linear projection for the query
Q = self.query_proj(x)
K = self.key_proj(x)
V = self.value_proj(x)
if input_dim == 4:
# If input is 4D (images)
Q = rearrange(Q, "b x y (h d) -> b h (x y) d", h=self.num_heads)
K = rearrange(K, "b x y (h d) -> b h (x y) d", h=self.num_heads)
V = rearrange(V, "b x y (h d) -> b h (x y) d", h=self.num_heads)
elif input_dim == 3:
# If input is 3D (text)
Q = rearrange(Q, "b x (h d) -> b h x d", h=self.num_heads)
K = rearrange(K, "b x (h d) -> b h x d", h=self.num_heads)
V = rearrange(V, "b x (h d) -> b h x d", h=self.num_heads)
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.expand(
batch_size, self.num_heads, seq_len, seq_len
)
attention_output = F.scaled_dot_product_attention(
Q, K, V, attn_mask=attention_mask
)
# Concatenate heads and reshape back to original dimensions
if input_dim == 4:
attention_output = rearrange(
attention_output, "b h (x y) d -> b x y (h d)", x=height, y=width
)
elif input_dim == 3:
attention_output = rearrange(
attention_output, "b h x d -> b x (h d)", x=seq_len
)
attention_output = self.out_proj(attention_output)
if not self.channels_last and input_dim == 4:
attention_output = rearrange(attention_output, "b h w c -> b c h w")
return attention_output
# The attention pooling could be used as an additional conditioning mechanism where its concatenated with
# timestep embeddings and then added to ResNet blocks (either in the middle or at the beginning)
# - Imagen seems to this at the beginning of the ResNet blocks
[docs]
class AttentionPooling(nn.Module):
[docs]
def __init__(self, feature_dim, hidden_dim):
super(AttentionPooling, self).__init__()
self.attention = nn.Sequential(
nn.SiLU(), # Swish activation function
nn.Linear(feature_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1),
)
[docs]
def forward(self, x):
# x: input features of shape (batch_size, num_features, feature_dim)
batch_size, num_features, feature_dim = x.size()
# Compute attention scores
attention_scores = self.attention(x) # shape: (batch_size, num_features, 1)
attention_weights = torch.softmax(
attention_scores, dim=1
) # shape: (batch_size, num_features, 1)
# Compute weighted sum of features
pooled_features = torch.sum(
attention_weights * x, dim=1
) # shape: (batch_size, feature_dim)
return pooled_features
[docs]
class SelfAttentionConv(nn.Module):
[docs]
def __init__(self, embed_dim: int, num_heads: int, kernel_size: int = 1) -> None:
"""
SelfAttentionConv module for use in UNet models. This module is used to
perform self-attention on 2D data. It is used to attend to spatial features
in the 2D data, effectively allowing the model to learn spatial relationships
between pixels.
Parameters
----------
embed_dim : int
Dimension of the input embedding
num_heads : int
Number of heads for multi-head attention
kernel_size : int, optional
Size of the kernel for generating query, key, and value matrices, by default 1
"""
super(SelfAttentionConv, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, (
"embed_dim must be divisible by num_heads"
)
self.query_conv = nn.Conv2d(
embed_dim, embed_dim, kernel_size=kernel_size, padding=kernel_size // 2
)
self.key_conv = nn.Conv2d(
embed_dim, embed_dim, kernel_size=kernel_size, padding=kernel_size // 2
)
self.value_conv = nn.Conv2d(
embed_dim, embed_dim, kernel_size=kernel_size, padding=kernel_size // 2
)
self.out_conv = nn.Sequential(
nn.Conv2d(embed_dim, embed_dim, kernel_size=1), RMSNorm(embed_dim)
)
[docs]
def forward(self, x: torch.Tensor):
batch_size, channels, height, width = x.size()
# Convolutional projections
Q = self.query_conv(x)
K = self.key_conv(x)
V = self.value_conv(x)
# Reshape to (batch_size, num_heads, head_dim, height * width)
Q = Q.view(batch_size, self.num_heads, self.head_dim, -1)
K = K.view(batch_size, self.num_heads, self.head_dim, -1)
V = V.view(batch_size, self.num_heads, self.head_dim, -1)
# Transpose for multi-head attention (batch_size, num_heads, height * width, head_dim)
Q = Q.transpose(2, 3)
K = K.transpose(2, 3)
V = V.transpose(2, 3)
# Scaled Dot-Product Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
attention_weights = F.softmax(scores, dim=-1)
attention_output = torch.matmul(attention_weights, V)
# Concatenate heads and reshape back to original dimensions
attention_output = (
attention_output.transpose(2, 3)
.contiguous()
.view(batch_size, self.embed_dim, height, width)
)
attention_output = self.out_conv(attention_output)
return attention_output