import math
from abc import ABC
from typing import Tuple
import torch
from einops import rearrange, reduce
from tqdm.auto import tqdm
from starling.utilities import helix_dm
def symmetrize_distance_maps(dist_maps: torch.Tensor) -> torch.Tensor:
    """
    Symmetrize a batch of distance maps in PyTorch.

    The strict upper triangle of each map is kept, mirrored onto the lower
    triangle, and the diagonal is set to zero. The input is not modified.

    Parameters
    ----------
    dist_maps : torch.Tensor
        Tensor of shape (..., N, N) representing pairwise distances, e.g.
        (B, C, N, N) or (B, N, N). Any number of leading dimensions is
        supported.

    Returns
    -------
    torch.Tensor
        Symmetrized distance maps with zero diagonal, same shape as the input.

    Raises
    ------
    ValueError
        If the input has fewer than two dimensions or its last two
        dimensions are not equal.
    """
    if dist_maps.dim() < 2 or dist_maps.shape[-1] != dist_maps.shape[-2]:
        raise ValueError(
            "dist_maps must have shape (..., N, N), got "
            f"{tuple(dist_maps.shape)}"
        )
    # torch.triu zeroes the diagonal and the lower triangle of the last two
    # dims; adding its transpose reflects the upper triangle downwards.
    # This is out-of-place, so the caller's tensor is untouched and autograd
    # routes each mirrored value's gradient back to its upper-triangle source.
    upper = torch.triu(dist_maps, diagonal=1)
    return upper + upper.transpose(-2, -1)
class Constraint(ABC):
    """Base class for guidance constraints applied to diffusion latents.

    Subclasses implement :meth:`compute_loss`, which maps decoded distance
    maps to a per-sample loss. :meth:`apply` differentiates that loss with
    respect to the latents and moves the latents down the gradient, scaled
    by a time schedule and clipped per sample.
    """

    # Upper bound on the relative per-sample loss scale used in ``apply``.
    MAX_LOSS_SCALE = 2.0
    # Upper bound on the L2 norm of each sample's latent update in ``apply``.
    MAX_UPDATE_NORM = 1.0

    def __init__(
        self,
        constraint_weight=1.0,
        schedule="cosine",
        verbose=True,
        guidance_start=0.0,
        guidance_end=1.0,
    ):
        """Initialize base constraint with common parameters.

        Parameters
        ----------
        constraint_weight : float, default=1.0
            Weight factor for the constraint.
        schedule : str, default="cosine"
            Scheduling function for time-dependent guidance strength:
            "cosine", "bell_shaped", or anything else for a linear ramp.
        verbose : bool, default=True
            Whether ``apply`` forwards metrics to a supplied logger.
        guidance_start : float, default=0.0
            Fraction of sampling completed at which guidance starts
            (0.0 = beginning).
        guidance_end : float, default=1.0
            Fraction of sampling completed at which guidance stops
            (1.0 = end).
        """
        # These are set by ``initialize`` (called by the sampler).
        self.encoder_model = None
        self.latent_space_scaling_factor = None
        self.n_steps = None
        self.device = None
        self.sequence_length = None
        # User-controlled parameters
        self.constraint_weight = constraint_weight
        self.schedule = schedule
        self.verbose = verbose
        self.guidance_start = guidance_start
        self.guidance_end = guidance_end

    def _setup_constraint(self):
        """Set up constraint-specific resources (hook for subclasses)."""
        pass

    def initialize(
        self, encoder_model, latent_space_scaling_factor, n_steps, sequence_length
    ):
        """Store sampler/model parameters and run the subclass setup hook.

        Parameters
        ----------
        encoder_model
            Model exposing ``decode(latents)`` and a ``device`` attribute.
        latent_space_scaling_factor : float
            Latents are divided by this before decoding.
        n_steps : int
            Total number of diffusion steps.
        sequence_length : int
            Number of residues in the sequence being sampled.

        Returns
        -------
        Constraint
            ``self``, to allow chaining.
        """
        self.encoder_model = encoder_model
        self.latent_space_scaling_factor = latent_space_scaling_factor
        self.n_steps = n_steps
        self.device = encoder_model.device
        self.sequence_length = sequence_length
        self._setup_constraint()
        return self

    def should_apply_guidance(self, timestep, total_steps):
        """
        Check if guidance should be applied at the current timestep.

        Timesteps count down from ``total_steps`` to 0, so
        ``1 - timestep / total_steps`` is the fraction of sampling completed.
        Guidance applies while that fraction lies in
        ``[guidance_start, guidance_end]``.

        Parameters
        ----------
        timestep : int
            Current diffusion timestep.
        total_steps : int
            Total number of diffusion steps.

        Returns
        -------
        bool
            True if current timestep is within the guidance window.
        """
        fraction_complete = 1 - timestep / total_steps
        return self.guidance_start <= fraction_complete <= self.guidance_end

    def cosine_weight(self, t, total_steps, s=0.008):
        """
        Cosine schedule for time-dependent guidance strength.

        Returns ``cos(pi/2 * t / total_steps) ** 2``: 0 at ``t == total_steps``
        (start of sampling) rising to 1 at ``t == 0`` (end of sampling).

        Parameters
        ----------
        t : int
            Current timestep.
        total_steps : int
            Total number of steps.
        s : float, optional
            Unused; kept for backward compatibility.

        Returns
        -------
        float
            Guidance weight following cosine schedule.
        """
        t_scaled = t / total_steps
        return math.cos(t_scaled * math.pi / 2) ** 2

    def bell_shaped_schedule(self, timestep: int) -> float:
        """Bell-shaped schedule for time-dependent guidance strength.

        Computes ``sin(pi * x) * exp(-(x - 0.6)**2 / 0.1)`` with
        ``x = timestep / n_steps``. This is 0 at both ends of sampling. The
        Gaussian factor is centred at ``x = 0.6``; combined with the sine
        envelope, the maximum falls slightly below that (about ``x = 0.57``).

        Parameters
        ----------
        timestep : int
            Current diffusion timestep.

        Returns
        -------
        float
            Guidance strength factor.
        """
        normalized_t = timestep / self.n_steps
        return math.sin(normalized_t * math.pi) * math.exp(
            -((normalized_t - 0.6) ** 2) / 0.1
        )

    def get_adaptive_clip_threshold(self, timestep):
        """
        Get an adaptive clipping threshold that follows a cosine schedule.

        The threshold is 2.0 at the start of sampling
        (``timestep == n_steps``) and decays to 1.0 at the end
        (``timestep == 0``).

        Parameters
        ----------
        timestep : int
            Current diffusion timestep.

        Returns
        -------
        float
            Clipping threshold for gradient magnitudes.
        """
        max_threshold = 2.0
        min_threshold = 1.0
        fraction_complete = 1 - (timestep / self.n_steps)
        cosine_factor = math.cos(fraction_complete * math.pi / 2)
        return min_threshold + cosine_factor**2 * (max_threshold - min_threshold)

    def compute_loss(
        self, distance_maps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the loss for this constraint without applying gradients.

        Parameters
        ----------
        distance_maps : torch.Tensor
            Pre-computed distance maps from the latents.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            (per_batch_loss, total_loss): individual sample losses and mean loss.
        """
        raise NotImplementedError("Subclasses should implement compute_loss")

    @staticmethod
    def _per_sample(values: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Reshape per-sample ``values`` (B,) to broadcast over ``like`` (B, ...)."""
        return values.reshape(-1, *([1] * (like.dim() - 1)))

    def apply(self, latents: torch.Tensor, timestep: int, logger=None) -> torch.Tensor:
        """Apply one guidance step to ``latents`` and return the updated latents.

        The latents are decoded to distance maps, which are then symmetrized.
        The constraint loss is differentiated with respect to the latents.
        The gradient is scaled by ``constraint_weight``, by the time schedule,
        and by each sample's loss relative to the batch mean (capped at
        ``MAX_LOSS_SCALE``). Each sample's update is then clipped to L2 norm
        ``MAX_UPDATE_NORM``. Outside the guidance window ``latents`` is
        returned unchanged.
        """
        if not self.should_apply_guidance(timestep, self.n_steps):
            return latents

        # Re-enable autograd even if the sampler runs under inference mode.
        with torch.inference_mode(False):
            latents_copy = latents.clone().requires_grad_(True)
            scaled_latents = latents_copy / self.latent_space_scaling_factor
            distance_maps = self.encoder_model.decode(scaled_latents)
            distance_maps = symmetrize_distance_maps(distance_maps)

            per_batch_loss, loss = self.compute_loss(distance_maps)
            base_grad = torch.autograd.grad(loss, latents_copy)[0]

            time_scale = self.get_time_scale(timestep)

            # Samples with above-average loss get proportionally larger
            # updates. The scale is only a multiplier, so it needs no graph.
            per_batch_loss = per_batch_loss.detach()
            mean_loss = per_batch_loss.mean()
            if mean_loss > 1e-6:
                loss_scale = per_batch_loss / mean_loss
            else:
                # Near-zero mean loss: avoid dividing by ~0.
                loss_scale = torch.ones_like(per_batch_loss)
            loss_scale = torch.clamp(loss_scale, max=self.MAX_LOSS_SCALE)
            loss_scale = self._per_sample(loss_scale, base_grad)

            update = -self.constraint_weight * time_scale * loss_scale * base_grad

            # Clip each sample's update to at most MAX_UPDATE_NORM.
            grad_norms = update.flatten(start_dim=1).norm(dim=1)
            clip_factors = (self.MAX_UPDATE_NORM / (grad_norms + 1e-6)).clamp(max=1.0)
            update = update * self._per_sample(clip_factors, update)

            if logger is not None and self.verbose:
                logger.update(
                    timestep,
                    self.__class__.__name__,
                    {
                        "loss": loss.item(),
                        "grad_norm": update.norm().item(),
                        "time_scale": time_scale,
                        "min_loss_scale": loss_scale.min().item(),
                        "max_loss_scale": loss_scale.max().item(),
                    },
                )

        return latents + update.detach()

    def get_time_scale(self, timestep: int) -> float:
        """Get the time-dependent scaling factor for ``self.schedule``.

        "cosine" uses :meth:`cosine_weight`, and "bell_shaped" uses
        :meth:`bell_shaped_schedule`. Any other value gives the linear ramp
        ``1 - timestep / n_steps``.
        """
        if self.schedule == "cosine":
            return self.cosine_weight(timestep, total_steps=self.n_steps)
        if self.schedule == "bell_shaped":
            return self.bell_shaped_schedule(timestep)
        return 1.0 - (timestep / self.n_steps)
class BondConstraint(Constraint):
    """Flat-bottom harmonic restraint on consecutive-residue (i, i+1) distances."""

    def __init__(self, bond_length=3.81, tolerance=0.0, force_constant=2.0, **kwargs):
        """
        Parameters
        ----------
        bond_length : float, default=3.81
            Ideal distance between residues i and i+1.
        tolerance : float, default=0.0
            Deviation from ``bond_length`` allowed before any penalty applies.
        force_constant : float, default=2.0
            Force constant of the harmonic penalty.
        **kwargs
            Passed to :class:`Constraint`.
        """
        super().__init__(**kwargs)
        self.bond_length = bond_length
        self.tolerance = tolerance
        self.force_constant = force_constant

    def compute_loss(
        self, distance_maps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute bond loss based on distance maps.

        Penalizes deviations of each (i, i+1) distance from ``bond_length``.
        The penalty is ``0.5 * force_constant * excess**2``, where ``excess``
        is the part of the deviation beyond ``tolerance`` (a flat-bottom
        potential). The penalty is averaged over the bonds of each sample.

        Parameters
        ----------
        distance_maps : torch.Tensor
            Pre-computed distance maps from the latents, indexed as
            (batch, channel, residue, residue).

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Per-batch loss of shape (B,) and its mean.
        """
        n = self.sequence_length
        distance_maps = distance_maps[:, :, :n, :n]
        # (i, i+1) distances taken along the last two axes, giving (B, C, n-1).
        # Indexing the last two axes instead of squeezing keeps the batch axis
        # when B == 1.
        bonds = torch.diagonal(distance_maps, offset=1, dim1=-2, dim2=-1)
        deviation = torch.abs(bonds - self.bond_length)
        excess = torch.nn.functional.relu(deviation - self.tolerance)
        energy = (0.5 * self.force_constant * excess**2).flatten(start_dim=1)
        # Mean over bonds. The count is clamped so that a 1-residue sequence
        # (no bonds) gives 0 instead of NaN.
        per_batch_loss = energy.sum(dim=1) / max(energy.shape[1], 1)
        return per_batch_loss, per_batch_loss.mean()
class StericClashConstraint(Constraint):
    """Penalize non-bonded residue pairs that are closer than a clash cutoff."""

    def __init__(self, steric_clash_definition=5.0, force_constant=2.0, **kwargs):
        """
        Parameters
        ----------
        steric_clash_definition : float, default=5.0
            Distance below which a residue pair counts as clashing.
        force_constant : float, default=2.0
            Force constant of the harmonic penalty.
        **kwargs
            Passed to :class:`Constraint`.
        """
        super().__init__(**kwargs)
        self.steric_clash_definition = steric_clash_definition
        self.force_constant = force_constant

    def compute_loss(
        self, distance_maps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute steric clash loss based on distance maps.

        Each pair (i, j) with j >= i + 2 that is closer than
        ``steric_clash_definition`` contributes
        ``0.5 * force_constant * (steric_clash_definition - d_ij)**2``.
        The sum is divided by the number of such pairs.

        Parameters
        ----------
        distance_maps : torch.Tensor
            Pre-computed distance maps from the latents, indexed as
            (batch, channel, residue, residue).

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Per-batch loss of shape (B,) and its mean.
        """
        n = self.sequence_length
        # Strict upper triangle, also skipping the (i, i+1) bonded neighbours.
        mask = torch.triu(
            torch.ones(n, n, device=distance_maps.device),
            diagonal=2,
        )
        distance_maps = distance_maps[:, :, :n, :n]
        # Only distances shorter than the cutoff are penalized.
        deviation = torch.relu(self.steric_clash_definition - distance_maps)
        steric_clash = 0.5 * self.force_constant * deviation**2
        steric_clash = steric_clash * mask
        steric_clash = reduce(steric_clash, "b c h w -> b", "sum")
        # Sequences shorter than 3 residues have no masked pairs. Clamping
        # the count avoids a 0/0 NaN loss (and NaN gradients).
        normalization_factor = mask.sum().clamp_min(1.0)
        per_batch_loss = steric_clash / normalization_factor
        return per_batch_loss, per_batch_loss.mean()
class HelicityConstraint(Constraint):
    """Restrain a residue range towards a reference helical distance map."""

    # Side length of the reference map and mask. The decoded distance maps
    # are assumed to share this size (TODO confirm against the decoder).
    MAP_SIZE = 384

    def __init__(
        self, resid_start, resid_end, tolerance=0.0, force_constant=2.0, **kwargs
    ):
        """
        Parameters
        ----------
        resid_start : int
            First residue index of the helical region (inclusive).
        resid_end : int
            End residue index of the helical region (exclusive, slice style).
        tolerance : float, default=0.0
            Deviation from the reference allowed before any penalty applies.
        force_constant : float, default=2.0
            Force constant of the harmonic penalty.
        **kwargs
            Passed to :class:`Constraint`.
        """
        super().__init__(**kwargs)
        self.resid_start = resid_start
        self.resid_end = resid_end
        self.tolerance = tolerance
        self.force_constant = force_constant
        # Device tensors, built in _setup_constraint once the model is known.
        self.helix_ref = None
        self.mask = None
        self.weights = None

    def _setup_constraint(self):
        """Build the reference helix map and region mask on ``self.device``."""
        if self.encoder_model is None:
            return
        size = self.MAP_SIZE
        self.helix_ref = torch.from_numpy(helix_dm(L=size)).to(self.device)
        # Strict upper triangle of the [resid_start, resid_end) block.
        self.mask = torch.zeros((size, size), device=self.device)
        span = self.resid_end - self.resid_start
        self.mask[
            self.resid_start : self.resid_end, self.resid_start : self.resid_end
        ] = torch.triu(torch.ones((span, span), device=self.device), diagonal=1)
        # NOTE(review): inverse-distance weights are computed but not used by
        # compute_loss.
        self.weights = 1.0 / (self.helix_ref + 1e-2)

    def compute_loss(
        self, distance_maps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Flat-bottom harmonic deviation from the helix over the masked region.

        Parameters
        ----------
        distance_maps : torch.Tensor
            Distance maps indexed (batch, channel, residue, residue). They must
            broadcast against a ``MAP_SIZE`` x ``MAP_SIZE`` reference.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Per-batch loss of shape (B,) and its mean.
        """
        # Match the reference to the decoder's dtype so the loss is not
        # silently promoted (e.g. to float64 if helix_dm returns float64).
        helix_ref = self.helix_ref.to(dtype=distance_maps.dtype)
        deviation = torch.abs(distance_maps - helix_ref)
        # Only deviations beyond the tolerance are penalized.
        excess = torch.nn.functional.relu(deviation - self.tolerance)
        region_loss = 0.5 * self.force_constant * (excess**2) * self.mask
        # A region of fewer than two residues has an empty mask; clamping the
        # count avoids a 0/0 NaN loss.
        normalization_factor = self.mask.sum().clamp_min(1.0)
        per_batch_loss = (
            reduce(region_loss, "b c h w -> b", "sum") / normalization_factor
        )
        return per_batch_loss, per_batch_loss.mean()
class DistanceConstraint(Constraint):
    """Flat-bottom harmonic restraint on the distance between two residues."""

    def __init__(
        self, resid1, resid2, target, tolerance=0.0, force_constant=2.0, **kwargs
    ):
        """Create constraint for distance between two residues.

        Parameters
        ----------
        resid1, resid2 : int
            Indices of the two residues whose distance is restrained.
        target : float
            Desired distance between the two residues.
        tolerance : float, default=0.0
            Deviation from ``target`` allowed before any penalty applies.
        force_constant : float, default=2.0
            Force constant of the harmonic penalty.
        **kwargs
            Passed to :class:`Constraint`.
        """
        super().__init__(**kwargs)
        self.resid1, self.resid2 = resid1, resid2
        self.target = target
        self.tolerance = tolerance
        self.force_constant = force_constant

    def compute_loss(
        self, distance_maps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(per_batch_loss, mean_loss)`` for the residue-pair distance.

        ``distance_maps`` is indexed (batch, channel, residue, residue). It
        must have exactly one channel, because the (B, 1) pair distances are
        reshaped to (B,).
        """
        pair_distance = distance_maps[:, :, self.resid1, self.resid2]
        overshoot = torch.nn.functional.relu(
            (pair_distance - self.target).abs() - self.tolerance
        )
        energy = 0.5 * self.force_constant * overshoot.pow(2)
        per_sample = rearrange(energy, "b 1 -> b")
        return per_sample, per_sample.mean()
class RgConstraint(Constraint):
    """Flat-bottom harmonic restraint on the radius of gyration (Rg)."""

    def __init__(self, target, tolerance=0.0, force_constant=2.0, **kwargs):
        """Create constraint for radius of gyration (Rg).

        Parameters
        ----------
        target : float
            Target Rg value in Angstroms.
        tolerance : float, default=0.0
            Allowed deviation from target before penalty applies.
        force_constant : float, default=2.0
            Force constant for the harmonic potential.
        **kwargs
            Additional parameters passed to parent Constraint class.
        """
        super().__init__(**kwargs)
        self.target = target
        self.tolerance = tolerance
        self.force_constant = force_constant

    def __compute_rg(self, distance_maps: torch.Tensor) -> torch.Tensor:
        """Calculate radius of gyration from distance maps.

        Rg = sqrt(sum_ij d_ij^2 / (2 * N^2)), where the sum runs over all
        ordered pairs of the first N = ``sequence_length`` residues.

        Parameters
        ----------
        distance_maps : torch.Tensor
            Protein distance maps indexed (batch, channel, residue, residue).

        Returns
        -------
        torch.Tensor
            Calculated Rg values for each protein in the batch, shape (B,).
        """
        n = self.sequence_length
        distance_maps = distance_maps[:, :, :n, :n]
        sum_sq = reduce(torch.square(distance_maps), "b c h w -> b", "sum")
        # Normalize with a Python float so no extra tensor (and no device
        # mismatch with distance_maps) is involved.
        mean_sq = sum_sq / (2.0 * n * n)
        # sqrt has an infinite derivative at 0; clamping keeps gradients finite
        # for degenerate (all-zero) maps without changing realistic values.
        return torch.sqrt(mean_sq.clamp_min(1e-12))

    def compute_loss(
        self, distance_maps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute loss based on deviation from target Rg.

        Parameters
        ----------
        distance_maps : torch.Tensor
            Pre-computed distance maps from the latents.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Per-batch loss and mean loss.
        """
        predicted_rg = self.__compute_rg(distance_maps)
        deviation = torch.abs(predicted_rg - self.target)
        # Only deviations beyond the tolerance are penalized.
        excess = torch.nn.functional.relu(deviation - self.tolerance)
        per_batch_loss = 0.5 * self.force_constant * excess**2
        return per_batch_loss, per_batch_loss.mean()
class ReConstraint(Constraint):
    """Flat-bottom harmonic restraint on the end-to-end distance (Re)."""

    def __init__(self, target, tolerance=0.0, force_constant=2.0, **kwargs):
        """Create constraint for end-to-end distance.

        Parameters
        ----------
        target : float
            Target distance between the first and last residue.
        tolerance : float, default=0.0
            Deviation from ``target`` allowed before any penalty applies.
        force_constant : float, default=2.0
            Force constant of the harmonic penalty.
        **kwargs
            Passed to :class:`Constraint`.
        """
        super().__init__(**kwargs)
        self.target = target
        self.tolerance = tolerance
        self.force_constant = force_constant

    def compute_loss(
        self, distance_maps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(per_batch_loss, mean_loss)`` for the end-to-end distance.

        ``distance_maps`` is indexed (batch, channel, residue, residue) and
        must have exactly one channel.
        """
        # Residues occupy indices 0 .. sequence_length - 1, so the last residue
        # is at sequence_length - 1. Index sequence_length would read padding,
        # or go out of bounds for a full-length map.
        last_residue = self.sequence_length - 1
        distances = distance_maps[:, :, 0, last_residue]
        deviation = torch.abs(distances - self.target)
        # Only deviations beyond the tolerance are penalized.
        excess = torch.nn.functional.relu(deviation - self.tolerance)
        per_batch_loss = 0.5 * self.force_constant * excess**2
        per_batch_loss = rearrange(per_batch_loss, "b 1 -> b")
        return per_batch_loss, per_batch_loss.mean()
class MultiConstraint(Constraint):
    """Combines multiple constraints into a single optimization step.

    Sub-constraint losses are summed, each multiplied by that constraint's
    ``constraint_weight``, and a single gradient step is taken on the total.
    The step uses this object's own ``schedule``; sub-constraint schedules are
    not consulted. Each sub-constraint's ``guidance_start``/``guidance_end``
    window is honoured: a sub-constraint only contributes at timesteps inside
    its own window.
    """

    def __init__(
        self,
        constraints,
        schedule="cosine",
        verbose=True,
    ):
        """
        Parameters
        ----------
        constraints : list
            Constraint objects to combine. Their ``constraint_weight``,
            ``guidance_start`` and ``guidance_end`` are captured here.
        schedule : str, default="cosine"
            Time schedule for the combined step ("cosine", "bell_shaped";
            any other value gives a linear ramp).
        verbose : bool, default=True
            Whether metrics are forwarded to a logger passed to ``apply``.
        """
        super().__init__(schedule=schedule, verbose=verbose)
        self.constraints = constraints
        self.constraint_weights = [
            constraint.constraint_weight for constraint in constraints
        ]
        self.guidance_starts = [constraint.guidance_start for constraint in constraints]
        self.guidance_ends = [constraint.guidance_end for constraint in constraints]
        # Timestep of the apply() call in progress; None outside apply().
        self._current_timestep = None

    def initialize(
        self, encoder_model, latent_space_scaling_factor, n_steps, sequence_length
    ):
        """Initialize this object and all sub-constraints with model parameters."""
        super().initialize(
            encoder_model, latent_space_scaling_factor, n_steps, sequence_length
        )
        for constraint in self.constraints:
            constraint.initialize(
                encoder_model, latent_space_scaling_factor, n_steps, sequence_length
            )
        return self

    def should_apply_guidance(self, timestep, total_steps):
        """Return True if at least one sub-constraint's window covers ``timestep``.

        With no sub-constraints this is False, so ``apply`` leaves the latents
        unchanged instead of failing on an empty loss.
        """
        fraction_complete = 1 - timestep / total_steps
        return any(
            start <= fraction_complete <= end
            for start, end in zip(self.guidance_starts, self.guidance_ends)
        )

    def _is_active(self, index, timestep):
        """Whether sub-constraint ``index`` contributes at ``timestep`` (None = always)."""
        if timestep is None:
            return True
        fraction_complete = 1 - timestep / self.n_steps
        return (
            self.guidance_starts[index]
            <= fraction_complete
            <= self.guidance_ends[index]
        )

    def apply(self, latents: torch.Tensor, timestep: int, logger=None) -> torch.Tensor:
        """Apply the combined step, using only sub-constraints active at ``timestep``."""
        self._current_timestep = timestep
        try:
            return super().apply(latents, timestep, logger)
        finally:
            self._current_timestep = None

    def compute_loss(
        self, distance_maps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute weighted combination of the active sub-constraint losses.

        Outside ``apply`` (no current timestep) every sub-constraint is
        included. Returns ``(None, 0.0)`` if no sub-constraint contributes.
        """
        total_per_batch_loss = None
        total_loss = 0.0
        for i, (constraint, weight) in enumerate(
            zip(self.constraints, self.constraint_weights)
        ):
            if not self._is_active(i, self._current_timestep):
                continue
            per_batch_loss, mean_loss = constraint.compute_loss(distance_maps)
            weighted_per_batch = weight * per_batch_loss
            if total_per_batch_loss is None:
                total_per_batch_loss = weighted_per_batch
            else:
                total_per_batch_loss = total_per_batch_loss + weighted_per_batch
            total_loss += weight * mean_loss
        return total_per_batch_loss, total_loss
class ConstraintLogger:
    """tqdm progress-bar logger for constraint metrics.

    Intended to be passed as ``logger`` to a constraint's ``apply``. Each
    ``update`` call records the latest metrics per constraint name and shows
    them in the bar's postfix.
    """

    def __init__(self, n_steps, verbose=True, update_freq=1):
        """
        Parameters
        ----------
        n_steps : int
            Total number of diffusion steps (stored; not read by this class).
        verbose : bool, default=True
            If False, no progress bar is created and ``update`` is a no-op.
        update_freq : int, default=1
            Stored; not read by this class.
        """
        self.n_steps = n_steps
        self.verbose = verbose
        self.update_freq = update_freq
        # Latest metrics dict per constraint name.
        self.constraint_data = {}
        self.progress_bar = None
        # NOTE(review): never set or read in this class.
        self.start_time = None
        # Number of update() calls that reached an active progress bar.
        self.steps_applied = 0

    def setup(self):
        """
        Set up the progress bar for constraint logging.

        Closes any bar left from a previous run and resets the step counter
        and stored metrics. If verbose mode is enabled, creates a new tqdm bar.
        """
        self.close()
        self.steps_applied = 0
        self.constraint_data = {}
        if self.verbose:
            self.progress_bar = tqdm(
                desc="Applying constraints",
                position=1,
                leave=False,
            )

    def update(self, timestep, constraint_name, metrics):
        """
        Update logger with new constraint metrics.

        Does nothing unless verbose and ``setup`` has created a bar.

        Parameters
        ----------
        timestep : int
            Current diffusion timestep (currently not displayed).
        constraint_name : str
            Name of the constraint being logged.
        metrics : dict
            Constraint metrics; "loss" and "grad_norm" are displayed.
        """
        if not self.verbose or self.progress_bar is None:
            return
        self.steps_applied += 1
        self.constraint_data[constraint_name] = metrics
        self.progress_bar.n = self.steps_applied

        status_parts = []
        for name, data in self.constraint_data.items():
            loss = data.get("loss", 0)
            grad_norm = data.get("grad_norm", 0)
            status_parts.append(f"{name[:3]} loss: {loss:.4f} grad: {grad_norm:.2f}")
        self.progress_bar.set_postfix_str(" | ".join(status_parts))
        self.progress_bar.refresh()

    def close(self):
        """
        Close the progress bar, if any, and forget it.

        Safe to call repeatedly. After closing, ``update`` is a no-op until
        ``setup`` is called again.
        """
        if self.progress_bar is not None:
            self.progress_bar.close()
            self.progress_bar = None