starling.inference.constraints.Constraint

class Constraint

Bases: ABC

Methods

__init__

Initialize base constraint with common parameters.

apply

Apply the constraint to the given latents.

bell_shaped_schedule

Bell-shaped schedule for time-dependent guidance strength.

compute_loss

Compute the loss for this constraint without applying gradients.

cosine_weight

Cosine schedule for time-dependent guidance strength.

get_adaptive_clip_threshold

Get an adaptive clipping threshold that follows a cosine schedule.

get_time_scale

Get the time-dependent scaling factor.

initialize

Called by the sampler to set model parameters.

should_apply_guidance

Check if guidance should be applied at the current timestep.

__init__(constraint_weight=1.0, schedule='cosine', verbose=True, guidance_start=0.0, guidance_end=1.0)

Initialize base constraint with common parameters.

Parameters:
  • constraint_weight (float, default=1.0) – Weight factor for the constraint.

  • schedule (str, default="cosine") – Scheduling function for time-dependent guidance strength.

  • verbose (bool, default=True) – Whether to print debug information.

  • guidance_start (float, default=0.0) – Normalized timestep at which to start applying guidance (0.0 = beginning of sampling).

  • guidance_end (float, default=1.0) – Normalized timestep at which to stop applying guidance (1.0 = end of sampling).

initialize(encoder_model, latent_space_scaling_factor, n_steps, sequence_length)

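Called by the sampler to supply model-dependent parameters: the encoder model, the latent-space scaling factor, the number of sampling steps (n_steps), and the sequence length.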

should_apply_guidance(timestep, total_steps)

Check if guidance should be applied at the current timestep.

Parameters:
  • timestep (int) – Current diffusion timestep.

  • total_steps (int) – Total number of diffusion steps.

Returns:

True if the current timestep falls within the guidance window defined by guidance_start and guidance_end.

Return type:

bool
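A minimal sketch of the window check, written as a free function rather than a method. It assumes that timestep is the sampling-step index counting up from 0 (start of sampling) to total_steps - 1 (end), and that the window is inclusive at both ends; the library may instead count diffusion timesteps downward, in which case the mapping to progress would be reversed.

```python
def should_apply_guidance(timestep: int, total_steps: int,
                          guidance_start: float = 0.0,
                          guidance_end: float = 1.0) -> bool:
    """Sketch: True if `timestep` lies inside the normalized guidance window.

    Assumption: `timestep` counts up from 0 (first sampling step)
    to total_steps - 1 (last sampling step).
    """
    if total_steps <= 1:
        progress = 0.0
    else:
        # Map the step index onto [0.0, 1.0]: 0.0 = first step, 1.0 = last step.
        progress = timestep / (total_steps - 1)
    # Inclusive window on both ends (an assumption).
    return guidance_start <= progress <= guidance_end
```

For example, with guidance_start=0.2, guidance_end=0.8 and 11 steps, this sketch applies guidance on steps 2 through 8.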

cosine_weight(t, total_steps, s=0.008)

Cosine schedule for time-dependent guidance strength.

Parameters:
  • t (int) – Current timestep.

  • total_steps (int) – Total number of steps.

  • s (float, optional) – Smoothing parameter (default: 0.008).

Returns:

Guidance weight following cosine schedule.

Return type:

float
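The default s=0.008 matches the offset used in the cosine noise schedule of Nichol and Dhariwal's improved DDPM, so one plausible reading is the sketch below, which normalizes that curve to 1.0 at t = 0 and lets it fall toward 0.0 at t = total_steps. Both the exact formula and the direction of the decay are assumptions; the source does not state them.

```python
import math


def cosine_weight(t: int, total_steps: int, s: float = 0.008) -> float:
    """Sketch of a cosine schedule in the style of improved DDPM (assumed).

    Returns 1.0 at t = 0 and approximately 0.0 at t = total_steps.
    """
    def f(step: int) -> float:
        # The small offset s shifts the curve so its slope at t = 0 is not
        # exactly zero (cos^2 is flat at 0, but not at s / (1 + s) * pi / 2).
        return math.cos(((step / total_steps) + s) / (1 + s) * math.pi / 2) ** 2

    # Normalize so the weight starts at exactly 1.0.
    return f(t) / f(0)
```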

bell_shaped_schedule(timestep: int) → float

Bell-shaped schedule for time-dependent guidance strength.

Creates a schedule that gradually increases guidance strength, peaks 60% of the way through the sampling process, and then decreases again.

Parameters:

timestep (int) – Current diffusion timestep.

Returns:

Guidance strength factor (peaks 60% of the way through sampling).

Return type:

float
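A sketch of one bell-shaped curve with the documented peak at 60% of sampling. The Gaussian shape, the width of 0.2, and the explicit n_steps argument (the real method takes only timestep, presumably reading the step count set by initialize) are all assumptions.

```python
import math


def bell_shaped_schedule(timestep: int, n_steps: int,
                         peak: float = 0.6, width: float = 0.2) -> float:
    """Sketch: Gaussian bell over sampling progress, peaking at `peak`.

    `n_steps` and `width` are hypothetical parameters for illustration.
    """
    # Sampling progress in [0, 1], assuming timestep counts up from 0.
    progress = timestep / max(n_steps - 1, 1)
    # Equals 1.0 at the peak and decays smoothly on both sides.
    return math.exp(-((progress - peak) ** 2) / (2 * width ** 2))
```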

get_adaptive_clip_threshold(timestep)

Get an adaptive clipping threshold that follows a cosine schedule.

The threshold starts high at the beginning of sampling and gradually decreases, allowing larger gradients early on and more refined adjustments later.

Parameters:

timestep (int) – Current diffusion timestep.

Returns:

Clipping threshold for gradient magnitudes.

Return type:

float
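A sketch of a threshold that follows a cosine decay from a high value at the start of sampling to a low value at the end. The bounds max_threshold and min_threshold and the explicit n_steps argument are hypothetical; the source documents only the shape of the schedule.

```python
import math


def adaptive_clip_threshold(timestep: int, n_steps: int,
                            max_threshold: float = 1.0,
                            min_threshold: float = 0.1) -> float:
    """Sketch: cosine-decaying gradient clipping threshold.

    `n_steps`, `max_threshold` and `min_threshold` are hypothetical.
    """
    # Sampling progress in [0, 1], assuming timestep counts up from 0.
    progress = timestep / max(n_steps - 1, 1)
    # Cosine decay factor: 1.0 at the first step, 0.0 at the last.
    decay = 0.5 * (1 + math.cos(math.pi * progress))
    return min_threshold + (max_threshold - min_threshold) * decay
```

Under these assumptions, early steps tolerate gradient magnitudes up to max_threshold, while the last step clips to min_threshold.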

compute_loss(distance_maps: Tensor) → Tuple[Tensor, Tensor]

Compute the loss for this constraint without applying gradients.

Parameters:

distance_maps (torch.Tensor) – Distance maps pre-computed from the latents.

Returns:

(per_batch_loss, total_loss): the loss for each sample in the batch, and the mean of those losses.

Return type:

tuple[torch.Tensor, torch.Tensor]
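A sketch of the compute_loss contract, using plain Python lists in place of torch.Tensor so it runs without PyTorch. Which methods Constraint actually declares abstract is not documented; ConstraintSketch and MaxDistanceConstraint are hypothetical classes, and the squared-hinge penalty is only one example of a distance constraint.

```python
from abc import ABC, abstractmethod
from typing import List, Tuple

Matrix = List[List[float]]


class ConstraintSketch(ABC):
    """Hypothetical stand-in for the base class; subclasses supply compute_loss."""

    @abstractmethod
    def compute_loss(self, distance_maps: List[Matrix]) -> Tuple[List[float], float]:
        ...


class MaxDistanceConstraint(ConstraintSketch):
    """Hypothetical constraint: penalize distances above `max_distance`."""

    def __init__(self, max_distance: float):
        self.max_distance = max_distance

    def compute_loss(self, distance_maps: List[Matrix]) -> Tuple[List[float], float]:
        per_batch_loss = []
        for dmap in distance_maps:
            # Squared hinge penalty on every entry that exceeds the limit.
            violations = [max(d - self.max_distance, 0.0) ** 2
                          for row in dmap for d in row]
            # Mean penalty over all entries of this sample's distance map.
            per_batch_loss.append(sum(violations) / len(violations))
        # Mean loss over the batch, matching the (per_batch_loss, total_loss) pair.
        total_loss = sum(per_batch_loss) / len(per_batch_loss)
        return per_batch_loss, total_loss
```

Returning both values presumably lets a caller inspect per-sample violations while using the single mean for gradient guidance.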

apply(latents: Tensor, timestep: int, logger=None) → Tensor

Apply the constraint to the given latents.

get_time_scale(timestep: int) → float

Get the time-dependent scaling factor.