starling.inference.constraints.StericClashConstraint

class StericClashConstraint

Bases: Constraint

Methods

__init__

Initialize the steric clash constraint; options shared with the base Constraint are passed through **kwargs.

apply

Apply the constraint to the given latents.

bell_shaped_schedule

Bell-shaped schedule for time-dependent guidance strength.

compute_loss

Compute steric clash loss based on distance maps.

cosine_weight

Cosine schedule for time-dependent guidance strength.

get_adaptive_clip_threshold

Get an adaptive clipping threshold that follows a cosine schedule.

get_time_scale

Get the time-dependent scaling factor.

initialize

Called by the sampler to set model parameters.

should_apply_guidance

Check if guidance should be applied at the current timestep.
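The schedule helpers listed above are named but not defined on this page. The functions below are a rough sketch of what such schedules commonly compute. They assume a normalized timestep t in [0, 1] (0.0 = start of sampling, matching the guidance_start convention). The exact formulas, the "bell" schedule key, and the center, width, min_clip and max_clip parameters are assumptions for illustration. They are not STARLING's implementation.

```python
import math


def cosine_weight(t: float) -> float:
    # Assumed form: decays smoothly from 1.0 at t=0 to 0.0 at t=1.
    return 0.5 * (1.0 + math.cos(math.pi * t))


def bell_shaped_schedule(t: float, center: float = 0.5, width: float = 0.15) -> float:
    # Assumed form: Gaussian bump equal to 1.0 at `center`.
    # `center` and `width` are hypothetical parameters.
    return math.exp(-((t - center) ** 2) / (2.0 * width**2))


def get_time_scale(t: float, schedule: str = "cosine") -> float:
    # Dispatch on the schedule name, as the `schedule` option suggests.
    # Only "cosine" is documented; the "bell" key is hypothetical.
    schedules = {"cosine": cosine_weight, "bell": bell_shaped_schedule}
    if schedule not in schedules:
        raise ValueError(f"unknown schedule: {schedule!r}")
    return schedules[schedule](t)


def adaptive_clip_threshold(
    t: float, min_clip: float = 0.1, max_clip: float = 1.0
) -> float:
    # Assumed form: interpolate between hypothetical clip bounds along
    # the cosine schedule (max_clip at t=0, min_clip at t=1).
    return min_clip + (max_clip - min_clip) * cosine_weight(t)
```

Under these assumptions, get_time_scale(0.0) returns 1.0 and get_time_scale(1.0) returns 0.0. The actual direction and shape of STARLING's schedules may differ.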

__init__(steric_clash_definition=5.0, force_constant=2.0, **kwargs)

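Initialize the steric clash constraint. The first two parameters are specific to this class. The remaining parameters belong to the base Constraint and are passed through **kwargs.

Parameters:
  • steric_clash_definition (float, default=5.0) – Distance threshold in Å; pairwise distances below it are penalized as clashes (see compute_loss)

  • force_constant (float, default=2.0) – Force constant of the clash penalty; larger values penalize clashes more strongly
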
  • constraint_weight (float, default=1.0) – Weight factor for the constraint

  • schedule (str, default="cosine") – Scheduling function for time-dependent guidance strength

  • verbose (bool, default=True) – Whether to print debug information

  • guidance_start (float, default=0.0) – Normalized timestep to start applying guidance (0.0 = beginning)

  • guidance_end (float, default=1.0) – Normalized timestep to stop applying guidance (1.0 = end)
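guidance_start and guidance_end define a window in normalized time. The sketch below shows how such a window check could work. It assumes the sampler's step index is mapped linearly onto [0, 1]. The function name mirrors should_apply_guidance, but its signature and the step-to-time mapping are hypothetical.

```python
def should_apply_guidance(
    step: int,
    num_steps: int,
    guidance_start: float = 0.0,
    guidance_end: float = 1.0,
) -> bool:
    # Hypothetical signature: map the step index linearly onto [0, 1],
    # with 0.0 at the first sampling step and 1.0 at the last.
    if num_steps < 2:
        t = 0.0
    else:
        t = step / (num_steps - 1)
    # Guidance is active only inside the inclusive window.
    return guidance_start <= t <= guidance_end
```

For example, with 11 sampling steps, guidance_start=0.2 and guidance_end=0.8, this sketch enables guidance on steps 2 through 8. The defaults (0.0 and 1.0) enable it on every step.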

compute_loss(distance_maps: Tensor) → Tuple[Tensor, Tensor]

Compute steric clash loss based on distance maps. This loss penalizes pairwise distances below a clash threshold (steric_clash_definition, default 5.0 Å) using a flat-bottom potential: distances at or above the threshold contribute no penalty.

Parameters:

distance_maps (torch.Tensor) – Pre-computed distance maps from the latents

Returns:

Per-batch loss and mean loss

Return type:

Tuple[torch.Tensor, torch.Tensor]
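The sketch below shows a flat-bottom clash penalty of the kind compute_loss describes. It rests on several assumptions:

- distance_maps has shape (batch, n, n) and is in Å.
- The penalty is harmonic, force_constant * (threshold - d)^2, below the threshold.
- The penalty is summed over unique pairs.

The function name and the min_separation parameter (used to skip self-distances or sequence neighbors) are hypothetical. The real method may weight, mask, or reduce differently.

```python
from typing import Tuple

import torch


def steric_clash_loss(
    distance_maps: torch.Tensor,
    steric_clash_definition: float = 5.0,
    force_constant: float = 2.0,
    min_separation: int = 1,  # hypothetical; 1 skips only self-distances
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Assumed input: (batch, n, n) symmetric pairwise distances in Angstrom.
    n = distance_maps.shape[-1]
    # Upper-triangular mask counts each unordered pair once and drops
    # pairs with |i - j| < min_separation (including the zero diagonal).
    pair_mask = torch.ones(
        n, n, dtype=distance_maps.dtype, device=distance_maps.device
    ).triu(diagonal=min_separation)
    # Flat bottom: zero at or above the threshold, harmonic below it.
    violation = torch.relu(steric_clash_definition - distance_maps)
    energy = force_constant * violation.pow(2) * pair_mask
    # Per-batch loss (sum over pairs) and its mean over the batch.
    per_batch = energy.sum(dim=(-2, -1))
    return per_batch, per_batch.mean()
```

With the default min_separation=1, only the zero-valued diagonal is excluded. This page does not say whether STARLING also skips adjacent residues, which can sit closer than 5.0 Å (consecutive Cα atoms, for example, are about 3.8 Å apart).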