starling.inference.constraints.Constraint
- class Constraint
Bases: ABC
Methods
__init__ – Initialize base constraint with common parameters.
Apply the constraint to the given latents.
bell_shaped_schedule – Bell-shaped schedule for time-dependent guidance strength.
compute_loss – Compute the loss for this constraint without applying gradients.
cosine_weight – Cosine schedule for time-dependent guidance strength.
get_adaptive_clip_threshold – Get an adaptive clipping threshold that follows a cosine schedule.
Get the time-dependent scaling factor.
initialize – Called by the sampler to set model parameters.
should_apply_guidance – Check if guidance should be applied at the current timestep.
- __init__(constraint_weight=1.0, schedule='cosine', verbose=True, guidance_start=0.0, guidance_end=1.0)
Initialize base constraint with common parameters.
- Parameters:
constraint_weight (float, default=1.0) – Weight factor for the constraint
schedule (str, default="cosine") – Name of the schedule used to vary guidance strength over the sampling process
verbose (bool, default=True) – Whether to print debug information
guidance_start (float, default=0.0) – Normalized timestep to start applying guidance (0.0 = beginning)
guidance_end (float, default=1.0) – Normalized timestep to stop applying guidance (1.0 = end)
- initialize(encoder_model, latent_space_scaling_factor, n_steps, sequence_length)
Called by the sampler to set model parameters.
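Because Constraint derives from ABC, concrete constraints are written as subclasses. The stand-in below is a minimal sketch of the lifecycle implied by the two signatures above: the constructor stores the guidance parameters, and the sampler later calls initialize() with model details. It is not the starling implementation; the class names SketchConstraint and NoOpConstraint, the abstract method name apply, its signature, and the attribute names are all hypothetical.

```python
from abc import ABC, abstractmethod


class SketchConstraint(ABC):
    """Hypothetical stand-in mirroring the documented Constraint signatures."""

    def __init__(self, constraint_weight=1.0, schedule="cosine", verbose=True,
                 guidance_start=0.0, guidance_end=1.0):
        # Parameters documented for Constraint.__init__
        self.constraint_weight = constraint_weight
        self.schedule = schedule
        self.verbose = verbose
        self.guidance_start = guidance_start
        self.guidance_end = guidance_end
        # Filled in later by the sampler via initialize()
        self.encoder_model = None
        self.latent_space_scaling_factor = None
        self.n_steps = None
        self.sequence_length = None

    def initialize(self, encoder_model, latent_space_scaling_factor, n_steps, sequence_length):
        # Called by the sampler once the model details are known
        self.encoder_model = encoder_model
        self.latent_space_scaling_factor = latent_space_scaling_factor
        self.n_steps = n_steps
        self.sequence_length = sequence_length

    @abstractmethod
    def apply(self, latents, timestep):
        """Hypothetical name and signature for 'apply the constraint to the given latents'."""


class NoOpConstraint(SketchConstraint):
    def apply(self, latents, timestep):
        # Returns the latents unchanged; a real constraint would adjust them
        return latents
```

Usage: construct the subclass with the guidance parameters, then call initialize(None, 1.0, 50, 128) the way a sampler would before sampling starts. Instantiating SketchConstraint directly raises TypeError because apply is abstract.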
- should_apply_guidance(timestep, total_steps)
Check if guidance should be applied at the current timestep.
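The guidance_start and guidance_end parameters define a window in normalized time. A minimal sketch of the check, assuming timesteps count upward from 0 and are normalized by total_steps (a sampler that counts timesteps downward would use 1 - timestep / total_steps instead). The free function and its extra keyword arguments are illustrative, not the library's code.

```python
def should_apply_guidance(timestep, total_steps, guidance_start=0.0, guidance_end=1.0):
    """Sketch: guidance is active while normalized progress lies in [guidance_start, guidance_end].

    Assumes timestep counts up from 0 (start of sampling) to total_steps (end).
    """
    progress = timestep / total_steps  # normalized timestep in [0, 1]
    return guidance_start <= progress <= guidance_end
```

With guidance_start=0.2 and guidance_end=0.8 over 100 steps, guidance is active from step 20 through step 80 inclusive.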
- cosine_weight(t, total_steps, s=0.008)
Cosine schedule for time-dependent guidance strength.
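The default s=0.008 is the same small offset used in the cosine noise schedule of improved DDPM, which suggests a form like the one below. This is a sketch under that assumption: it evaluates f(t) = cos²(((t/T) + s) / (1 + s) · π/2) and normalizes by f(0) so the weight starts at 1. Whether starling normalizes this way, and whether t counts up or down, is not stated here.

```python
import math


def cosine_weight(t, total_steps, s=0.008):
    """Sketch of a cosine guidance weight, assuming the improved-DDPM cosine form.

    f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), normalized so the weight at t = 0 is 1.
    """
    def f(step):
        return math.cos(((step / total_steps) + s) / (1 + s) * math.pi / 2) ** 2

    return f(t) / f(0)
```

The weight is 1 at t = 0 and decays smoothly to roughly 0 at t = total_steps; s keeps the curve from being too flat near t = 0.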
- bell_shaped_schedule(timestep: int) → float
Bell-shaped schedule for time-dependent guidance strength.
Creates a schedule that gradually increases guidance strength, peaks 60% of the way through the sampling process, and then decreases again.
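A Gaussian bump in normalized progress is one way to get the rise, 60% peak, and fall described above. In this sketch, n_steps and width are assumptions. The documented method takes only timestep, so it presumably reads the step count stored by initialize(), and its actual curve shape and width are not specified here.

```python
import math


def bell_shaped_schedule(timestep, n_steps, peak=0.6, width=0.2):
    """Sketch of a bell-shaped guidance schedule peaking 60% of the way through sampling.

    Gaussian in normalized progress; `width` (the standard deviation) is an assumed value.
    """
    progress = timestep / n_steps  # assumes timestep counts up from 0
    return math.exp(-((progress - peak) ** 2) / (2 * width ** 2))
```

The value is 1.0 at the peak and falls off symmetrically on either side. A smaller width concentrates guidance more tightly around the 60% mark.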
- get_adaptive_clip_threshold(timestep)
Get an adaptive clipping threshold that follows a cosine schedule.
The threshold starts high at the beginning of sampling and gradually decreases, allowing larger gradients early on and more refined adjustments later.
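A common way to make a value decay along a cosine curve is half-cosine interpolation between a maximum and a minimum. This sketch assumes that form. The bounds max_clip and min_clip, and the n_steps argument, are hypothetical illustrative values rather than library defaults.

```python
import math


def get_adaptive_clip_threshold(timestep, n_steps, max_clip=1.0, min_clip=0.1):
    """Sketch: cosine decay of a gradient-clipping threshold from max_clip to min_clip.

    max_clip and min_clip are illustrative values, not library defaults.
    """
    progress = timestep / n_steps  # 0 at the start of sampling, 1 at the end (assumed)
    return min_clip + 0.5 * (max_clip - min_clip) * (1 + math.cos(math.pi * progress))
```

Early steps therefore tolerate gradients up to max_clip, while late steps are clipped near min_clip, which matches the "large early, refined later" behavior described above.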
- compute_loss(distance_maps: Tensor) → Tuple[Tensor, Tensor]
Compute the loss for this constraint without applying gradients.
- Parameters:
distance_maps (torch.Tensor) – Pre-computed distance maps from the latents
- Returns:
(per_batch_loss, total_loss) – Per-sample losses and their mean over the batch
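To illustrate the (per_batch_loss, total_loss) contract, here is a sketch of what a subclass's loss might compute. The penalty itself is hypothetical: it penalizes entries of the distance maps that fall below a minimum distance. It also assumes distance_maps has shape (batch, L, L). The function name and min_distance are illustrative, not part of starling.

```python
import torch


def min_distance_loss(distance_maps: torch.Tensor, min_distance: float = 1.0):
    """Hypothetical compute_loss body: penalize distances below min_distance.

    Assumes distance_maps has shape (batch, L, L).
    """
    # Amount by which each distance falls short of the (hypothetical) minimum
    violation = torch.relu(min_distance - distance_maps)
    # Squared violation averaged over every non-batch dimension -> shape (batch,)
    per_batch_loss = violation.pow(2).flatten(start_dim=1).mean(dim=1)
    # Mean over the batch, matching the documented "mean loss"
    total_loss = per_batch_loss.mean()
    return per_batch_loss, total_loss
```

Returning the per-sample vector alongside the scalar lets a caller log or weight individual samples, while the scalar is what gradients would typically be taken of.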
- Return type: