starling.inference.constraints.StericClashConstraint
- class StericClashConstraint
Bases: Constraint
Methods
__init__ – Initialize base constraint with common parameters.
apply – Apply the constraint to the given latents.
bell_shaped_schedule – Bell-shaped schedule for time-dependent guidance strength.
compute_loss – Compute steric clash loss based on distance maps.
cosine_weight – Cosine schedule for time-dependent guidance strength.
get_adaptive_clip_threshold – Get an adaptive clipping threshold that follows a cosine schedule.
get_time_scale – Get the time-dependent scaling factor.
initialize – Called by the sampler to set model parameters.
should_apply_guidance – Check if guidance should be applied at the current timestep.
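The methods table names several time-dependent schedules but does not give their formulas. The sketch below is a rough illustration only. It assumes three things: a cosine weight that decays from 1.0 at the start of sampling (t = 0) to 0.0 at the end (t = 1), a Gaussian-shaped bell schedule, and a clipping threshold interpolated by the cosine weight. The functions `cosine_decay`, `bell_curve` and `adaptive_clip`, their `center`/`width`/`max_clip`/`min_clip` parameters, and all formulas are hypothetical stand-ins. They are not the library's `cosine_weight`, `bell_shaped_schedule` or `get_adaptive_clip_threshold`.

```python
import math


def cosine_decay(t: float) -> float:
    """Hypothetical cosine schedule: 1.0 at t=0, smoothly down to 0.0 at t=1."""
    t = min(max(t, 0.0), 1.0)  # clamp to the normalized timestep range [0, 1]
    return 0.5 * (1.0 + math.cos(math.pi * t))


def bell_curve(t: float, center: float = 0.5, width: float = 0.2) -> float:
    """Hypothetical bell-shaped schedule: Gaussian bump peaking (value 1.0) at `center`."""
    t = min(max(t, 0.0), 1.0)
    return math.exp(-0.5 * ((t - center) / width) ** 2)


def adaptive_clip(t: float, max_clip: float = 10.0, min_clip: float = 1.0) -> float:
    """Hypothetical clip threshold: moves from max_clip to min_clip along the cosine decay."""
    return min_clip + (max_clip - min_clip) * cosine_decay(t)


# Print how the three assumed schedules evolve over normalized time.
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"t={t:.2f}  cosine={cosine_decay(t):.3f}  "
          f"bell={bell_curve(t):.3f}  clip={adaptive_clip(t):.2f}")
```

Under these assumptions, a cosine schedule concentrates guidance early in sampling, while a bell-shaped schedule concentrates it mid-trajectory.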
- __init__(steric_clash_definition=5.0, force_constant=2.0, **kwargs)
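Initialize the steric clash constraint. The parameters steric_clash_definition and force_constant are specific to this class; the common base-constraint parameters listed below are accepted through **kwargs.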
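- Parameters:
steric_clash_definition (float, default=5.0) – Distance threshold (Å) below which pairwise distances are penalized as clashes
force_constant (float, default=2.0) – Force constant scaling the clash penalty
constraint_weight (float, default=1.0) – Weight factor for the constraint
schedule (str, default="cosine") – Name of the schedule used for time-dependent guidance strength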
verbose (bool, default=True) – Whether to print debug information
guidance_start (float, default=0.0) – Normalized timestep to start applying guidance (0.0 = beginning)
guidance_end (float, default=1.0) – Normalized timestep to stop applying guidance (1.0 = end)
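One way to read constraint_weight, schedule, guidance_start and guidance_end together is shown below. Guidance is gated to a window of normalized timesteps and, inside that window, scaled by the constraint weight times the schedule value. This is a minimal sketch under several assumptions that this page does not confirm: the window is inclusive at both ends, the schedule is evaluated on the raw normalized timestep, and the factors combine by simple multiplication. The helper `guidance_scale` and the `cosine` stand-in are hypothetical.

```python
import math
from typing import Callable


def cosine(t: float) -> float:
    # Hypothetical stand-in for the "cosine" schedule: 1.0 at t=0, 0.0 at t=1.
    return 0.5 * (1.0 + math.cos(math.pi * t))


def should_apply_guidance(t: float, guidance_start: float = 0.0,
                          guidance_end: float = 1.0) -> bool:
    # Assumption: guidance is active on the closed window [guidance_start, guidance_end].
    return guidance_start <= t <= guidance_end


def guidance_scale(t: float, constraint_weight: float = 1.0,
                   guidance_start: float = 0.0, guidance_end: float = 1.0,
                   schedule: Callable[[float], float] = cosine) -> float:
    # Outside the window the constraint contributes nothing.
    if not should_apply_guidance(t, guidance_start, guidance_end):
        return 0.0
    # Inside the window: constant weight times the time-dependent schedule value.
    return constraint_weight * schedule(t)


# Example: guidance restricted to the middle 60% of sampling, with weight 2.0.
for t in (0.1, 0.2, 0.5, 0.8, 0.9):
    print(t, round(guidance_scale(t, constraint_weight=2.0,
                                  guidance_start=0.2, guidance_end=0.8), 3))
```

Keeping the window check separate from the schedule makes it easy to swap in a different schedule without touching the gating logic.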
- compute_loss(distance_maps: Tensor) → Tuple[Tensor, Tensor]
Compute steric clash loss based on distance maps. The loss uses a flat-bottom potential: distances at or above a threshold (default 5.0 Å) incur no penalty, and only distances below the threshold are penalized.
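A flat-bottom clash penalty can be sketched without PyTorch as follows. This is a minimal sketch under several assumptions. Each distance map is a symmetric N × N matrix of distances in Å. Each pair is counted once, and the diagonal is skipped. A violating pair contributes force_constant × (threshold − d)², which pairs the force_constant argument from the signature above with a harmonic penalty; that pairing is an assumption and is not stated on this page. The function `flat_bottom_clash_loss` is hypothetical. It uses nested lists in place of `torch.Tensor` and returns the per-batch losses and their mean, mirroring the documented return values.

```python
from typing import List, Tuple


def flat_bottom_clash_loss(
    distance_maps: List[List[List[float]]],
    threshold: float = 5.0,
    force_constant: float = 2.0,
) -> Tuple[List[float], float]:
    """Hypothetical flat-bottom clash loss over a batch of symmetric distance maps."""
    if not distance_maps:
        raise ValueError("distance_maps must contain at least one map")
    per_batch = []
    for dmap in distance_maps:
        n = len(dmap)
        total = 0.0
        # Upper triangle only: each pair once, self-distances on the diagonal skipped.
        for i in range(n):
            for j in range(i + 1, n):
                violation = threshold - dmap[i][j]
                # Flat bottom: pairs at or beyond the threshold cost nothing.
                if violation > 0.0:
                    total += force_constant * violation ** 2  # assumed harmonic penalty
        per_batch.append(total)
    mean_loss = sum(per_batch) / len(per_batch)
    return per_batch, mean_loss


a = [[0.0, 6.0, 7.0], [6.0, 0.0, 8.0], [7.0, 8.0, 0.0]]  # no pair closer than 5 Å
b = [[0.0, 4.0, 3.0], [4.0, 0.0, 5.0], [3.0, 5.0, 0.0]]  # two pairs closer than 5 Å
per_batch, mean_loss = flat_bottom_clash_loss([a, b])
print(per_batch, mean_loss)  # → [0.0, 10.0] 5.0
```

In the second map, the 4.0 Å pair contributes 2 × 1² = 2 and the 3.0 Å pair contributes 2 × 2² = 8. The pair at exactly 5.0 Å sits on the flat bottom and contributes nothing.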
- Parameters:
distance_maps (torch.Tensor) – Pre-computed distance maps from the latents
- Returns:
Tuple of the per-batch loss and the mean loss over the batch
- Return type: