starling.samplers.plms_sampler.PLMSSampler

class PLMSSampler

Bases: object

Methods

__init__

generate_labels – Generate labels to condition the generative process on.

p_sample_plms

sample – Sample the generative process using the DDIM model.

__init__(ddpm_model, encoder_model, n_steps, ionic_strength=150, ddim_discretize='uniform', schedule='linear', **kwargs)

generate_labels(labels: str) → Tensor

Generate labels to condition the generative process on.

Parameters:

labels (str) – A sequence to generate labels from.

Returns:

The labels to condition the generative process on.

Return type:

torch.Tensor

sample(num_conformations: int, labels: Tensor, repeat_noise: bool = False, temperature: float = 1.0, show_per_step_progress_bar: bool = True, batch_count: int = 1, max_batch_count: int = 1, constraint=None) → Tensor

Sample the generative process using the DDIM model.

Parameters:
  • num_conformations (int) – Number of conformations to generate.

  • labels (torch.Tensor) – The labels to condition the generative process on.

  • repeat_noise (bool, optional) – By default False.

  • temperature (float, optional) – By default 1.0.

  • show_per_step_progress_bar (bool, optional) – Whether to show a per-step progress bar, by default True.

  • batch_count (int, optional) – The batch count for the progress bar, by default 1

  • max_batch_count (int, optional) – The maximum batch count for the progress bar, by default 1

Returns:

The generated distance maps.

Return type:

torch.Tensor
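
The repeat_noise and temperature parameters are not described above. In many diffusion samplers, parameters with these names control whether every sample in a batch shares one noise draw and how strongly the injected noise is scaled. The sketch below illustrates that common convention only; whether PLMSSampler behaves this way is an assumption, and make_noise is a hypothetical helper that is not part of starling.

```python
import random


def make_noise(batch_size, dim, repeat_noise=False, temperature=1.0, seed=None):
    """Hypothetical helper: draw Gaussian noise for a batch of samples.

    Illustrates a common sampler convention, not starling's actual code.
    """
    rng = random.Random(seed)
    if repeat_noise:
        # One draw, copied to every sample in the batch.
        row = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        rows = [list(row) for _ in range(batch_size)]
    else:
        # An independent draw for each sample.
        rows = [[rng.gauss(0.0, 1.0) for _ in range(dim)]
                for _ in range(batch_size)]
    # Temperature rescales the noise; 0.0 removes it entirely.
    return [[temperature * value for value in row] for row in rows]


shared = make_noise(batch_size=3, dim=4, repeat_noise=True, seed=0)
silent = make_noise(batch_size=2, dim=4, temperature=0.0, seed=0)
```

Under this convention, a lower temperature would give less varied outputs and repeat_noise=True would make the stochastic part identical across the batch.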

p_sample_plms(x, c, t, attention_mask, index, temperature=1.0, old_eps=None, t_next=None)
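
p_sample_plms has no description here. Its old_eps and t_next parameters match the shape of the pseudo linear multistep (PLMS) update used in other diffusion codebases, where the current noise estimate is blended with up to three earlier ones using Adams-Bashforth coefficients. The sketch below shows that standard blending step; that starling follows it exactly is an assumption, and plms_combine_eps is a hypothetical name.

```python
def plms_combine_eps(e_t, old_eps, e_t_next=None):
    """Hypothetical helper: blend noise estimates as in a standard PLMS step.

    e_t      -- noise estimate at the current timestep
    old_eps  -- earlier estimates, oldest first (only the last three are used)
    e_t_next -- optional look-ahead estimate, used only when there is no history
    """
    history = len(old_eps)
    if history == 0:
        # No history yet: use the raw estimate, or average it with a
        # look-ahead estimate (an improved-Euler style first step).
        return e_t if e_t_next is None else (e_t + e_t_next) / 2
    if history == 1:
        # Second-order Adams-Bashforth
        return (3 * e_t - old_eps[-1]) / 2
    if history == 2:
        # Third-order Adams-Bashforth
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    # Fourth-order Adams-Bashforth
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24


# Works on floats here; the same arithmetic applies to arrays or tensors.
blended = plms_combine_eps(4.0, [1.0, 2.0, 3.0])
```

A caller would typically append each new estimate to old_eps after every step and keep only the most recent three.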