starling.models.vae.VAE

class VAE

Bases: LightningModule

Methods

__init__

The variational autoencoder (VAE) model that is used to learn the latent space of protein distance maps.

add_module

Add a child module to the current module.

all_gather

Gather tensors or collections of tensors from multiple processes.

apply

Apply fn recursively to every submodule (as returned by .children()) as well as self.

backward

Called to perform backward on the loss returned in training_step().

bfloat16

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers

Return an iterator over module buffers.

children

Return an iterator over immediate children modules.

clip_gradients

Handles gradient clipping internally.

compile

Compile this Module's forward using torch.compile().

configure_callbacks

Configure model-specific callbacks.

configure_gradient_clipping

Perform gradient clipping for the optimizer parameters.

configure_model

Hook to create modules in a strategy and precision aware context.

configure_optimizers

Configure the optimizer and the learning rate scheduler for the model.

configure_sharded_model

Deprecated.

cpu

See torch.nn.Module.cpu().

cuda

Moves all model parameters and buffers to the GPU.

decode

Decodes the latent space back into the original data

double

See torch.nn.Module.double().

encode

Encodes the data into the latent space, returning the mean and log variance.

eval

Set the module in evaluation mode.

extra_repr

Return the extra representation of the module.

float

See torch.nn.Module.float().

forward

Forward pass of the VAE

freeze

Freeze all params for inference.

gaussian_likelihood

Calculates the likelihood of input data given latent space (p(x|z)) under Gaussian assumption.

get_buffer

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state

Return any extra state to include in the module's state_dict.

get_parameter

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule

Return the submodule given by target if it exists, otherwise throw an error.

half

See torch.nn.Module.half().

ipu

Move all model parameters and buffers to the IPU.

load_from_checkpoint

Primary way of loading a model from a checkpoint.

load_state_dict

Copy parameters and buffers from state_dict into this module and its descendants.

log

Log a key, value pair.

log_dict

Log a dictionary of values at once.

lr_scheduler_step

Override this method to adjust the default way the Trainer calls each scheduler.

lr_schedulers

Returns the learning rate scheduler(s) that are being used during training.

manual_backward

Call this directly from your training_step() when doing optimizations manually.

modules

Return an iterator over all modules in the network.

mtia

Move all model parameters and buffers to the MTIA.

named_buffers

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

on_after_backward

Called after loss.backward() and before optimizers are stepped.

on_after_batch_transfer

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

on_before_backward

Called before loss.backward().

on_before_batch_transfer

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_before_optimizer_step

Called before optimizer.step().

on_before_zero_grad

Called after training_step() and before optimizer.zero_grad().

on_fit_end

Called at the very end of fit.

on_fit_start

Called at the very beginning of fit.

on_load_checkpoint

Called by Lightning to restore your model.

on_predict_batch_end

Called in the predict loop after the batch.

on_predict_batch_start

Called in the predict loop before anything happens for that batch.

on_predict_end

Called at the end of predicting.

on_predict_epoch_end

Called at the end of predicting.

on_predict_epoch_start

Called at the beginning of predicting.

on_predict_model_eval

Called when the predict loop starts.

on_predict_start

Called at the beginning of predicting.

on_save_checkpoint

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

on_test_batch_end

Called in the test loop after the batch.

on_test_batch_start

Called in the test loop before anything happens for that batch.

on_test_end

Called at the end of testing.

on_test_epoch_end

Called in the test loop at the very end of the epoch.

on_test_epoch_start

Called in the test loop at the very beginning of the epoch.

on_test_model_eval

Called when the test loop starts.

on_test_model_train

Called when the test loop ends.

on_test_start

Called at the beginning of testing.

on_train_batch_end

Called in the training loop after the batch.

on_train_batch_start

Called in the training loop before anything happens for that batch.

on_train_end

Called at the end of training before logger experiment is closed.

on_train_epoch_end

Calculate and log the mean training losses for the epoch.

on_train_epoch_start

Called in the training loop at the very beginning of the epoch.

on_train_start

Called at the beginning of training after sanity check.

on_validation_batch_end

Called in the validation loop after the batch.

on_validation_batch_start

Called in the validation loop before anything happens for that batch.

on_validation_end

Called at the end of validation.

on_validation_epoch_end

Called in the validation loop at the very end of the epoch.

on_validation_epoch_start

Called in the validation loop at the very beginning of the epoch.

on_validation_model_eval

Called when the validation loop starts.

on_validation_model_train

Called when the validation loop ends.

on_validation_model_zero_grad

Called by the training loop to release gradients before entering the validation loop.

on_validation_start

Called at the beginning of validation.

optimizer_step

Override this method to adjust the default way the Trainer calls the optimizer.

optimizer_zero_grad

Override this method to change the default behaviour of optimizer.zero_grad().

optimizers

Returns the optimizer(s) that are being used during training.

parameters

Return an iterator over module parameters.

predict_dataloader

An iterable or collection of iterables specifying prediction samples.

predict_step

Step function called during predict().

prepare_data

Use this to download and prepare data.

print

Prints only from process 0.

register_backward_hook

Register a backward hook on the module.

register_buffer

Add a buffer to the module.

register_forward_hook

Register a forward hook on the module.

register_forward_pre_hook

Register a forward pre-hook on the module.

register_full_backward_hook

Register a backward hook on the module.

register_full_backward_pre_hook

Register a backward pre-hook on the module.

register_load_state_dict_post_hook

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook

Register a pre-hook to be run before module's load_state_dict() is called.

register_module

Alias for add_module().

register_parameter

Add a parameter to the module.

register_state_dict_post_hook

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook

Register a pre-hook for the state_dict() method.

reparameterize

Reparameterization trick that allows gradients to flow through the non-random part of the sampling process.

requires_grad_

Change if autograd should record operations on parameters in this module.

save_hyperparameters

Save arguments to hparams attribute.

set_extra_state

Set extra state contained in the loaded state_dict.

set_submodule

Set the submodule given by target if it exists, otherwise throw an error.

setup

Set up the model, including optional compilation.

share_memory

See torch.Tensor.share_memory_().

state_dict

Return a dictionary containing references to the whole state of the module.

symmetrize

Symmetrizes the reconstructed data so that the weights can learn other patterns.

teardown

Called at the end of fit (train + validate), validate, test, or predict.

test_dataloader

An iterable or collection of iterables specifying test samples.

test_step

Operates on a single batch of data from the test set.

to

See torch.nn.Module.to().

to_empty

Move the parameters and buffers to the specified device without copying storage.

to_onnx

Saves the model in ONNX format.

to_torchscript

By default compiles the whole model to a torch.jit.ScriptModule.

toggle_optimizer

Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.

toggled_optimizer

Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.

train

Set the module in training mode.

train_dataloader

An iterable or collection of iterables specifying training samples.

training_step

Training step of the VAE, compatible with PyTorch Lightning.

transfer_batch_to_device

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

type

See torch.nn.Module.type().

unfreeze

Unfreeze all parameters for training.

untoggle_optimizer

Resets the state of required gradients that were toggled with toggle_optimizer().

vae_loss

Calculates the loss of the VAE, using the sum between the KLD loss of the latent space to N(0, I) and either mean squared error between the reconstructed data and the ground truth or the negative log likelihood of the input data given the latent space under a Gaussian assumption.

val_dataloader

An iterable or collection of iterables specifying validation samples.

validation_step

Validation step of the VAE, compatible with PyTorch Lightning.

xpu

Move all model parameters and buffers to the XPU.

zero_grad

Reset gradients of all model parameters.

Attributes

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

T_destination

automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

call_super_init

current_epoch

The current epoch in the Trainer, or 0 if not attached.

device

device_mesh

Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.

dtype

dump_patches

example_input_array

The example input array is a specification of what the module can consume in the forward() method.

fabric

global_rank

The index of the current process across all nodes and devices.

global_step

Total training batches seen across all epochs.

hparams

The collection of hyperparameters saved with save_hyperparameters().

hparams_initial

The collection of hyperparameters saved with save_hyperparameters().

local_rank

The index of the current process within a single node.

logger

Reference to the logger object in the Trainer.

loggers

Reference to the list of loggers in the Trainer.

on_gpu

Returns True if this model is currently located on a GPU.

strict_loading

Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).

trainer

training

__init__(model_type: str, in_channels: int, latent_dim: int, dimension: int, loss_type: str, KLD_weight: float, lr_scheduler: str, set_lr: float, norm: str = 'instance', base: int = 64, optimizer: str = 'SGD', KLD_warmup_fraction: float = 0, KLD_scheduler_type: str = 'cyclical', compile_mode: str = 'max-autotune', weights_type: str = None) → None

The variational autoencoder (VAE) model that is used to learn the latent space of protein distance maps. The model is based on the ResNet architecture and uses a Gaussian distribution to model the latent space. The model is trained using the evidence lower bound (ELBO) loss, which is a combination of the reconstruction loss and the Kullback-Leibler divergence loss. The reconstruction loss can be either mean squared error or negative log likelihood. The weights for the reconstruction loss can be calculated based on the distance between residues in the ground truth distance map. The model can be trained using different learning rate schedulers and the learning rate can be set manually.

References

  1. Kingma, D. P. & Welling, M. Auto-Encoding Variational Bayes. arXiv [stat.ML] (2013).

  2. Rombach, R., Blattmann, A., Lorenz, D., Esser, P. & Ommer, B. High-resolution image synthesis with latent diffusion models. arXiv [cs.CV] (2021).

Parameters:
  • model_type (str) – What ResNet architecture to use for the encoder and decoder portion of the VAE

  • in_channels (int) – Number of input channels in the input data

  • latent_dim (int) – The number of channels in the latent space representation of the data

  • dimension (int) – The size of the image in the height and width dimensions (i.e., distance maps)

  • loss_type (str) – The type of loss to use for the reconstruction loss. Options are “mse” and “nll”

  • weights_type (str) – The type of weights to use for the reconstruction loss. Options are “linear”, “reciprocal”, and “equal”

  • KLD_weight (float) – The weight to apply to the KLD loss in the ELBO loss function. The KLD loss regularizes the latent space

  • lr_scheduler (str) – The learning rate scheduler to use for training the model. Options are “CosineAnnealingWarmRestarts”, “OneCycleLR”, and “CosineAnnealingLR”

  • set_lr (float) – The learning rate to use for training the model

  • norm (str, optional) – The normalization layer to use in the ResNet architecture, by default “instance”

  • base (int, optional) – The base (starting) number of channels to use in the ResNet architecture, by default 64

  • optimizer (str, optional) – The optimizer to use for training the model, by default “SGD”

setup(stage=None)

Set up the model, including optional compilation.

encode(data: Tensor) → List[Tuple[Tensor, Tensor]]

Encodes the data into the latent space, returning the mean and log variance.

Parameters:

data (torch.Tensor) – Data in the shape of (batch, channel, height, width)

Returns:

Return the mean and log variance of the latent space

Return type:

List[Tuple[torch.Tensor, torch.Tensor]]

decode(latents: Tensor) → Tensor

Decodes the latent space back into the original data

Parameters:

latents (torch.Tensor) – latents in the shape of (batch, channel, height, width)

Returns:

Returns the reconstructed data

Return type:

torch.Tensor

reparameterize(mu: Tensor, logvar: Tensor) → Tensor

Reparameterization trick that allows gradients to flow through the non-random part of the sampling process. See the paper for more details: https://arxiv.org/abs/1312.6114

Parameters:
  • mu (torch.Tensor) – A tensor containing means of the latent space

  • logvar (torch.Tensor) – A tensor containing the log variance of the latent space

Returns:

Returns the latent encoding

Return type:

torch.Tensor
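
Example: a minimal sketch of the reparameterization trick from Kingma & Welling, not necessarily the exact code of this method. The function name reparameterize_sketch and the tensor shapes are assumptions made for illustration. The sample is written as mu + eps * std, so all randomness sits in eps and gradients can reach mu and logvar.

```python
import torch


def reparameterize_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z ~ N(mu, exp(logvar)) as a differentiable function of mu and logvar."""
    std = torch.exp(0.5 * logvar)   # log variance -> standard deviation
    eps = torch.randn_like(std)     # the only random quantity; it needs no gradient
    return mu + eps * std


# Hypothetical latent shape (batch, channel, height, width), chosen for illustration.
mu = torch.zeros(2, 1, 4, 4, requires_grad=True)
logvar = torch.zeros(2, 1, 4, 4, requires_grad=True)

z = reparameterize_sketch(mu, logvar)
z.sum().backward()                  # gradients reach mu and logvar through the sample
print(z.shape, mu.grad.shape, logvar.grad.shape)
```

Sampling z directly from a distribution object would hide the dependence on mu and logvar from autograd; writing the sample as a deterministic function of the two tensors plus independent noise avoids that.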

gaussian_likelihood(data_reconstructed: Tensor, log_std: Tensor, data: Tensor) → Tensor

Calculates the likelihood of input data given latent space (p(x|z)) under Gaussian assumption. The reconstructed data is treated as the mean of the Gaussian distributions and log_std is a tensor of learned log standard deviations.

Parameters:
  • data_reconstructed (torch.Tensor) – A tensor containing the reconstructed data that will be treated as the mean to parameterize the Gaussian distribution

  • log_std (torch.Tensor) – The learned log standard deviations of the Gaussian distribution

  • data (torch.Tensor) – The ground truth data that the likelihood will be calculated against

Returns:

Returns the likelihood of the input data given the latent space

Return type:

torch.Tensor
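
Example: a framework-free sketch of the Gaussian log-density that a likelihood of this kind is built on, assuming (as is usual) that the computation is done in log space. The helper names and the sample values are hypothetical; the same expression applies elementwise to tensors.

```python
import math


def gaussian_log_likelihood(x: float, mean: float, log_std: float) -> float:
    """Log density of x under N(mean, exp(log_std)^2)."""
    var = math.exp(2.0 * log_std)
    return -0.5 * math.log(2.0 * math.pi) - log_std - (x - mean) ** 2 / (2.0 * var)


def total_log_likelihood(data, reconstruction, log_std: float) -> float:
    """Sum the per-element log densities; the reconstruction acts as the mean."""
    return sum(
        gaussian_log_likelihood(x, m, log_std)
        for x, m in zip(data, reconstruction)
    )


# Hypothetical distances (ground truth) and their reconstructions.
data = [4.0, 6.5, 9.0]
reconstruction = [4.2, 6.0, 9.1]

ll = total_log_likelihood(data, reconstruction, log_std=0.0)
perfect = total_log_likelihood(data, data, log_std=0.0)
print(ll, perfect)
```

A reconstruction that matches the data exactly gives the largest value for a fixed log_std, which is why the negative of this quantity can serve as a reconstruction loss.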

vae_loss(data_reconstructed: Tensor, data: Tensor, mu: Tensor, logvar: Tensor) → dict

Calculates the loss of the VAE, using the sum between the KLD loss of the latent space to N(0, I) and either mean squared error between the reconstructed data and the ground truth or the negative log likelihood of the input data given the latent space under a Gaussian assumption. Additional loss is added to ensure the contacts are reconstructed correctly.

Parameters:
  • data_reconstructed (torch.Tensor) – Reconstructed data; output of the VAE

  • data (torch.Tensor) – Ground truth data, input to the VAE

  • mu (torch.Tensor) – Means of the normal distributions of the latent space

  • logvar (torch.Tensor) – Log variances of the normal distributions of the latent space

  • KLD_weight (int, optional) – How much weight to give the regularization term of the latent space. Setting this lower than 1 leads to a less regular and less interpretable latent space. By default None

Returns:

Returns a dictionary containing the total loss, reconstruction loss, and KLD loss

Return type:

dict

Raises:

ValueError – If the loss type is not mse or elbo
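
Example: a framework-free sketch of how the two terms combine, covering only the mean-squared-error branch. It assumes the standard closed form for the KL divergence between a diagonal Gaussian and N(0, I); the function names and dictionary keys are hypothetical, and the residue-distance weighting and the additional contact loss are omitted.

```python
import math


def kld_to_standard_normal(mu, logvar) -> float:
    """KL( N(mu, exp(logvar)) || N(0, I) ), summed over latent elements."""
    return sum(
        -0.5 * (1.0 + lv - m * m - math.exp(lv))
        for m, lv in zip(mu, logvar)
    )


def mse(reconstruction, data) -> float:
    """Mean squared error between reconstruction and ground truth."""
    return sum((r - d) ** 2 for r, d in zip(reconstruction, data)) / len(data)


def vae_loss_sketch(reconstruction, data, mu, logvar, kld_weight: float) -> dict:
    """Total loss = reconstruction loss + kld_weight * KLD."""
    recon_loss = mse(reconstruction, data)
    kld = kld_to_standard_normal(mu, logvar)
    # Key names are illustrative, not the library's.
    return {"loss": recon_loss + kld_weight * kld, "recon": recon_loss, "KLD": kld}


losses = vae_loss_sketch(
    reconstruction=[1.0, 2.0],
    data=[1.0, 2.5],
    mu=[0.0, 0.0],
    logvar=[0.0, 0.0],
    kld_weight=1e-4,
)
print(losses)
```

The KLD term is zero when the encoder output already equals N(0, I) and grows as mu moves away from 0 or the variance moves away from 1, so kld_weight controls how strongly the latent space is pulled toward the prior.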

forward(data: Tensor) → List[Tuple[Tensor, Tensor, Tensor]]

Forward pass of the VAE

Parameters:

data (torch.Tensor) – Data in the shape of (batch, channel, height, width) to pass through the VAE

Returns:

Returns the reconstructed data, the mean of the latent space, and the log variance

Return type:

List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

training_step(batch: dict, batch_idx) → Tensor

Training step of the VAE, compatible with PyTorch Lightning.

Parameters:
  • batch (dict) – A batch of data read in using the DataLoader

  • batch_idx (int) – Batch number the model is on during training

Returns:

Total training loss of this batch

Return type:

torch.Tensor

on_train_epoch_end() → None

Calculate and log the mean training losses for the epoch. Reset the loss accumulators for the next epoch.

validation_step(batch: Tensor, batch_idx) → Tensor

Validation step of the VAE, compatible with PyTorch Lightning. This is called after each epoch.

Parameters:
  • batch (torch.Tensor) – A batch of data read in using the DataLoader

  • batch_idx (int) – Batch number the model is on during the validation of the model

Returns:

Total validation loss of this batch

Return type:

torch.Tensor

configure_optimizers()

Configure the optimizer and the learning rate scheduler for the model. Here I am using NVIDIA-suggested settings for learning rate and weight decay. For ResNet50, NVIDIA reports the best performance with CosineAnnealingLR and an initial learning rate of 0.256 for a batch size of 256, scaled linearly up or down for other batch sizes. The weight decay is set to 1/32768 for all parameters except the batch normalization layers. For further information, see: https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch

Returns:

Returns the optimizer and the learning rate scheduler

Return type:

List

Raises:

ValueError – If the scheduler is not implemented
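
Example: the linear learning-rate scaling described above, written out as plain arithmetic. This is only an illustration of the rule; the helper name scaled_lr is hypothetical, and in this class the learning rate itself is supplied through set_lr.

```python
# Reference settings quoted in the description above.
REFERENCE_LR = 0.256
REFERENCE_BATCH_SIZE = 256
WEIGHT_DECAY = 1 / 32768  # applied to all parameters except batch normalization


def scaled_lr(batch_size: int) -> float:
    """Scale the reference learning rate linearly with the batch size."""
    return REFERENCE_LR * batch_size / REFERENCE_BATCH_SIZE


for batch_size in (64, 256, 1024):
    print(batch_size, scaled_lr(batch_size))

print("weight decay:", WEIGHT_DECAY)
```

Under this rule, halving the batch size halves the initial learning rate, so a value derived this way could be passed as set_lr when the batch size differs from 256.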

symmetrize(data_reconstructed: Tensor) → Tensor

Symmetrizes the reconstructed data so that the weights can learn other patterns. The loss is calculated only on the reconstruction faithfulness of the upper triangle of the distance map.

Parameters:

data_reconstructed (torch.Tensor) – Reconstructed data; output of the decoder

Returns:

Symmetric version of the reconstructed data

Return type:

torch.Tensor
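
Example: one way to symmetrize a batch of square maps, assuming the upper triangle is mirrored onto the lower triangle. The source does not state whether the method mirrors or averages the two halves, so this is a sketch under that assumption; the function name symmetrize_sketch and the input values are hypothetical.

```python
import torch


def symmetrize_sketch(x: torch.Tensor) -> torch.Tensor:
    """Mirror the upper triangle of the last two dimensions onto the lower triangle."""
    upper = torch.triu(x)                       # diagonal and everything above it
    strictly_upper = torch.triu(x, diagonal=1)  # everything above the diagonal
    return upper + strictly_upper.transpose(-1, -2)


# Hypothetical decoder output of shape (batch, channel, height, width).
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
sym = symmetrize_sketch(x)
print(sym[0, 0])
```

With this construction the lower triangle carries no independent information, which is consistent with computing the loss on the upper triangle only.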

on_train_start()

Called at the beginning of training after sanity check.