import math
from typing import List, Union
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.amp import autocast
from torch.functional import F
from torch.optim.lr_scheduler import (
CosineAnnealingLR,
CosineAnnealingWarmRestarts,
LambdaLR,
OneCycleLR,
)
from starling.data.schedulers import (
cosine_beta_schedule,
linear_beta_schedule,
sigmoid_beta_schedule,
)
from starling.models.vae import VAE
# Adapted from https://github.com/Camaltra/this-is-not-real-aerial-imagery/blob/main/src/ai/diffusion_process.py
# and https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/classifier_free_guidance.py#L720
# Process-wide setting applied at import time: lets float32 matrix
# multiplications trade some precision for speed (see the PyTorch docs for
# exactly what the "high" level permits on a given backend).
torch.set_float32_matmul_precision("high")
# Helper function
[docs]
class DiffusionModel(pl.LightningModule):
    """
    Denoising diffusion probabilistic model for latent space generation.

    Implements the diffusion process described in:
    - Sohl-Dickstein et al. (2015): Nonequilibrium Thermodynamics
    - Ho et al. (2020): Denoising Diffusion Probabilistic Models
    - Rombach et al. (2021): High-resolution image synthesis with latent diffusion
    """

    # Beta-schedule name -> schedule function. ``__init__`` looks its
    # ``beta_scheduler`` argument up here and raises ValueError when the name
    # is not one of these keys.
    SCHEDULER_MAPPING = {
        "linear": linear_beta_schedule,
        "cosine": cosine_beta_schedule,
        "sigmoid": sigmoid_beta_schedule,
    }
[docs]
def __init__(
    self,
    model: nn.Module,
    sequence_encoder: nn.Module,
    distance_map_encoder: nn.Module,
    beta_scheduler: str = "cosine",
    timesteps: int = 1000,
    set_lr: float = 1e-4,
    min_snr_loss: bool = False,
    min_snr_gamma: float = 5.0,
    config_scheduler: str = "LinearWarmupCosineAnnealingLR",
) -> None:
    """
    Set up a discrete-time denoising-diffusion model that works on a latent space.

    The framework follows Sohl-Dickstein et al. [1], Ho et al. [2] and
    Rombach et al. [3].

    References
    ----------
    1) Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N. & Ganguli, S.
       Deep Unsupervised Learning using Nonequilibrium Thermodynamics.
       in Proceedings of the 32nd International Conference on Machine Learning
       (eds. Bach, F. & Blei, D.) vol. 37 2256–2265 (PMLR, Lille, France, 07--09 Jul 2015).
    2) Ho, J., Jain, A. & Abbeel, P. Denoising Diffusion Probabilistic Models. arXiv [cs.LG] (2020).
    3) Rombach, R., Blattmann, A., Lorenz, D., Esser, P. & Ommer, B.
       High-resolution image synthesis with latent diffusion models. arXiv [cs.CV] (2021).

    Parameters
    ----------
    model : nn.Module
        The denoising network. It is called as ``model(x_noised, t, labels, mask)``
        and its output is compared against the sampled noise.
    sequence_encoder : nn.Module
        Called as ``sequence_encoder(sequences, sequence_mask, ionic_strength)``
        to produce the conditioning labels handed to ``model``.
    distance_map_encoder : nn.Module
        When not None, this value is passed to ``VAE.load_from_checkpoint`` and
        the loaded VAE is frozen. NOTE(review): despite the annotation this is
        presumably a checkpoint path rather than a module — confirm with callers.
        When None, batches are used as latents without encoding.
    beta_scheduler : str, optional
        Key into ``SCHEDULER_MAPPING`` ("linear", "cosine" or "sigmoid"),
        by default "cosine"
    timesteps : int, optional
        Value passed to the beta schedule function, by default 1000
    set_lr : float, optional
        Stored as ``self.set_lr`` (learning-rate setting), by default 1e-4
    min_snr_loss : bool, optional
        Whether ``p_loss`` applies min-SNR loss weighting, by default False
    min_snr_gamma : float, optional
        The gamma used by the min-SNR weighting, by default 5.0
    config_scheduler : str, optional
        Stored as ``self.config_scheduler`` (learning-rate scheduler name),
        by default "LinearWarmupCosineAnnealingLR"

    Raises
    ------
    ValueError
        If ``beta_scheduler`` is not a key of ``SCHEDULER_MAPPING``
    """
    super().__init__()

    # Must run before any local is rebound; the three sub-model arguments are
    # deliberately excluded from the saved hyperparameters.
    self.save_hyperparameters(
        ignore=["model", "sequence_encoder", "distance_map_encoder"]
    )

    self.model = model
    self.sequence_encoder = sequence_encoder

    if distance_map_encoder is None:
        self.distance_map_encoder = None
    else:
        self.distance_map_encoder = VAE.load_from_checkpoint(distance_map_encoder)
        self.__freeze_distance_map_encoder()

    # Optimizer / LR-scheduler settings
    self.set_lr = set_lr
    self.config_scheduler = config_scheduler

    # Resolve the beta schedule by name
    schedule_fn = self.SCHEDULER_MAPPING.get(beta_scheduler)
    self.beta_scheduler_fn = schedule_fn
    if schedule_fn is None:
        raise ValueError(f"unknown beta schedule {beta_scheduler}")

    self.min_snr_loss = min_snr_loss
    self.min_snr_gamma = min_snr_gamma

    # Starts at 1.0 and is overwritten by _initialize_latent_scaling on the
    # first training batch; used to scale the latents (cf. Reference 3).
    self.register_buffer(
        "latent_space_scaling_factor", torch.tensor(1.0, dtype=torch.float32)
    )

    # Quantities of the forward/reverse diffusion process
    betas = schedule_fn(timesteps)
    alphas = 1.0 - betas
    cumulative = torch.cumprod(alphas, dim=0)
    # Same series shifted right by one step, with 1.0 in front
    cumulative_prev = F.pad(cumulative[:-1], (1, 0), value=1.0)

    for buffer_name, values in (
        ("betas", betas),
        ("alphas_cumprod", cumulative),
        ("alphas_cumprod_prev", cumulative_prev),
        ("sqrt_recip_alphas", torch.sqrt(1.0 / alphas)),
        ("sqrt_alphas_cumprod", torch.sqrt(cumulative)),
        ("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - cumulative)),
        (
            "posterior_variance",
            betas * (1.0 - cumulative_prev) / (1.0 - cumulative),
        ),
    ):
        self.register_buffer(buffer_name, values)

    # Number of steps actually produced by the schedule
    self.num_timesteps = int(betas.shape[0])
    self.monitor = "epoch_val_loss"
def __freeze_distance_map_encoder(self):
self.distance_map_encoder.eval()
for param in self.distance_map_encoder.parameters():
param.requires_grad = False
# Mixed precision (autocast) is disabled for this function because it caused numerical instability here.
[docs]
@autocast(device_type="cuda", enabled=False)
def q_sample(
    self, x_start: torch.Tensor, t: int, noise: torch.Tensor = None
) -> torch.Tensor:
    """
    Noise ``x_start`` to the level that corresponds to timestep ``t``.

    Runs with CUDA autocast disabled (see the decorator).

    Parameters
    ----------
    x_start : torch.Tensor
        The un-noised tensor
    t : int
        Timestep(s) of the diffusion process. Annotated as ``int``, but
        ``forward`` passes a tensor of per-sample timesteps.
    noise : torch.Tensor, optional
        Noise to mix in; drawn with ``torch.randn_like(x_start)`` when None

    Returns
    -------
    torch.Tensor
        ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``,
        with the coefficients obtained through the ``extract`` helper
    """
    eps = torch.randn_like(x_start) if noise is None else noise

    # `extract` is defined outside this block; presumably it gathers the
    # per-timestep coefficient and reshapes it to broadcast against x_start.
    target_shape = x_start.shape
    signal_scale = extract(self.sqrt_alphas_cumprod, t, target_shape)
    noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, target_shape)

    return signal_scale * x_start + noise_scale * eps
[docs]
def sequence2labels(
    self, sequences: List, sequence_mask, ionic_strength
) -> torch.Tensor:
    """
    Turn sequences into conditioning labels by delegating to ``self.sequence_encoder``.

    Parameters
    ----------
    sequences : List
        The sequences to encode
    sequence_mask
        Forwarded unchanged to the sequence encoder
    ionic_strength
        Forwarded unchanged to the sequence encoder

    Returns
    -------
    torch.Tensor
        Whatever ``self.sequence_encoder(sequences, sequence_mask, ionic_strength)`` returns
    """
    return self.sequence_encoder(sequences, sequence_mask, ionic_strength)
[docs]
def p_loss(
    self,
    x_start: torch.Tensor,
    t: int,
    labels: torch.Tensor,
    mask: torch.Tensor,
    ionic_strengths: torch.Tensor,
    noise: torch.Tensor = None,
) -> torch.Tensor:
    """
    Noise ``x_start``, run the denoiser and return the MSE between the
    sampled noise and the predicted noise.

    Parameters
    ----------
    x_start : torch.Tensor
        The un-noised tensor
    t : int
        Timestep(s) along the diffusion process. Annotated as ``int``, but
        ``forward`` passes a tensor with one timestep per batch element,
        which is also what the min-SNR branch indexes with.
    labels : torch.Tensor
        Sequences to condition on; converted with ``sequence2labels``
    mask : torch.Tensor
        Passed to both ``sequence2labels`` and the denoising model
    ionic_strengths : torch.Tensor
        Passed to ``sequence2labels``
    noise : torch.Tensor, optional
        Noise to add; drawn with ``torch.randn_like(x_start)`` when None

    Returns
    -------
    torch.Tensor
        Scalar loss. With ``self.min_snr_loss`` the per-sample MSE is
        weighted by ``min(snr, gamma) / snr`` before averaging over the
        batch; otherwise it is the plain mean squared error.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    # NOTE: offset noise
    # (https://www.crosslabs.org/blog/diffusion-with-offset-noise)
    # is intentionally not applied here.

    # Noise the input, build the conditioning labels, predict the noise
    x_noised = self.q_sample(x_start, t, noise=noise)
    labels = self.sequence2labels(labels, mask, ionic_strengths)
    predicted_noise = self.model(x_noised, t, labels, mask)

    if not self.min_snr_loss:
        return F.mse_loss(noise, predicted_noise)

    # Min-SNR weighting, Section 3.4 of https://arxiv.org/abs/2303.09556,
    # adapted from:
    # https://github.com/huggingface/diffusers/blob/78a78515d64736469742e5081337dbcf60482750/examples/text_to_image/train_text_to_image.py#L927
    snr = self.compute_snr(t)

    # min(snr, gamma) / snr == min(gamma / snr, 1): the weight is capped at 1.
    # Clamping from above also maps the gamma / 0 = inf case (zero SNR) to 1.0.
    snr_weight = torch.clamp(self.min_snr_gamma / snr, max=1.0)

    # Per-element loss -> mean over every non-batch dim -> weighted batch mean
    per_element = F.mse_loss(noise, predicted_noise, reduction="none")
    per_sample = per_element.mean(dim=list(range(1, per_element.dim())))
    return (per_sample * snr_weight).mean()
[docs]
def forward(
    self, x: torch.Tensor, labels: torch.Tensor, mask, ionic_strengths
) -> torch.Tensor:
    """
    Draw one random timestep per batch element and return the diffusion loss.

    Parameters
    ----------
    x : torch.Tensor
        The tensor to noise/denoise; must be 4-dimensional
    labels : torch.Tensor
        Sequences to condition the model on
    mask
        Forwarded unchanged to ``p_loss``
    ionic_strengths
        Forwarded unchanged to ``p_loss``

    Returns
    -------
    torch.Tensor
        The loss returned by ``p_loss``
    """
    # The unpacking doubles as a rank check: a non-4-D input raises here
    batch_size, _, _, _ = x.shape

    # Uniform integer timesteps in [0, num_timesteps), on the input's device
    timestamps = torch.randint(
        0, self.num_timesteps, (batch_size,), device=x.device, dtype=torch.long
    )
    return self.p_loss(x, timestamps, labels, mask, ionic_strengths)
def _initialize_latent_scaling(self, latent_encoding: torch.Tensor) -> None:
"""
Initialize the latent space mean and standard deviation using the first batch
for z-scoring (standardization).
Parameters
----------
latent_encoding : torch.Tensor
Batch of encoded latent vectors
"""
# Calculate local mean and standard deviation
# local_mean = latent_encoding.mean()
local_std = latent_encoding.std()
# Gather from all processes and compute global mean and standard deviation
# gathered_mean = self.all_gather(local_mean)
gathered_std = self.all_gather(local_std)
# mean_mean = gathered_mean.mean()
mean_std = 1 / gathered_std.mean()
# Update the registered buffers with computed values
# self.latent_space_mean = mean_mean.float().to(self.device)
self.latent_space_scaling_factor = mean_std.float().to(self.device)
[docs]
def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
    """
    Compute and log the diffusion loss for one training batch.

    ``batch`` is indexed by the keys "data", "sequence", "attention_mask"
    and "ionic_strengths". The loss is logged as "train_loss" and returned.
    """
    latents = batch["data"]
    sequences = batch["sequence"]
    attention_mask = batch["attention_mask"]
    ionic_strengths = batch["ionic_strengths"]

    # With an encoder configured, "data" is encoded first (no gradients) and
    # the mode of the returned object is used as the latent.
    encoder = self.distance_map_encoder
    if encoder is not None:
        with torch.no_grad():
            latents = encoder.encode(latents).mode()

    # The scaling factor is derived once, from the very first training batch
    is_first_batch = self.global_step == 0 and batch_idx == 0
    if is_first_batch:
        self._initialize_latent_scaling(latents)

    # Rescale the latents by the stored factor
    latents = latents * self.latent_space_scaling_factor

    loss = self.forward(
        latents,
        labels=sequences,
        mask=attention_mask,
        ionic_strengths=ionic_strengths,
    )

    self.log("train_loss", loss, prog_bar=True, batch_size=latents.size(0))
    return loss
[docs]
def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
    """
    Compute and log the diffusion loss for one validation batch.

    ``batch`` is indexed by the keys "data", "sequence", "attention_mask"
    and "ionic_strengths". The loss is logged as "epoch_val_loss" with
    ``sync_dist=True`` and returned.
    """
    latents = batch["data"]
    sequences = batch["sequence"]
    attention_mask = batch["attention_mask"]
    ionic_strengths = batch["ionic_strengths"]

    # With an encoder configured, "data" is encoded first (no gradients) and
    # the mode of the returned object is used as the latent.
    encoder = self.distance_map_encoder
    if encoder is not None:
        with torch.no_grad():
            latents = encoder.encode(latents).mode()

    # Rescale with the stored factor; it is never recomputed during validation
    latents = latents * self.latent_space_scaling_factor

    loss = self.forward(
        latents,
        labels=sequences,
        mask=attention_mask,
        ionic_strengths=ionic_strengths,
    )

    self.log(
        "epoch_val_loss",
        loss,
        prog_bar=True,
        sync_dist=True,
        batch_size=latents.size(0),
    )
    return loss
[docs]
def compute_snr(self, timesteps):
    """
    Return the signal-to-noise ratio at the given timesteps.

    Follows https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    ``timesteps`` is used to index ``self.alphas_cumprod``; the result is
    ``(sqrt(a) / sqrt(1 - a)) ** 2`` for each selected cumulative alpha ``a``.
    """
    # Select the cumulative alphas first; the remaining maths is element-wise
    cumulative_alpha = self.alphas_cumprod[timesteps]

    alpha = cumulative_alpha**0.5
    sigma = (1.0 - cumulative_alpha) ** 0.5

    return (alpha / sigma) ** 2