starling.models.vae.VAE
- class VAE
Bases: LightningModule
The variational autoencoder (VAE) model that is used to learn the latent space of protein distance maps.
Methods
add_module – Add a child module to the current module.
all_gather – Gather tensors or collections of tensors from multiple processes.
apply – Apply fn recursively to every submodule (as returned by .children()) as well as self.
backward – Called to perform backward on the loss returned in training_step().
bfloat16 – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers – Return an iterator over module buffers.
children – Return an iterator over immediate children modules.
clip_gradients – Handles gradient clipping internally.
compile – Compile this Module's forward using torch.compile().
configure_callbacks – Configure model-specific callbacks.
configure_gradient_clipping – Perform gradient clipping for the optimizer parameters.
configure_model – Hook to create modules in a strategy and precision aware context.
configure_optimizers – Configure the optimizer and the learning rate scheduler for the model.
configure_sharded_model – Deprecated.
cpu
cuda – Moves all model parameters and buffers to the GPU.
decode – Decodes the latent space back into the original data
double
encode – Takes the data and encodes it into the latent space, by returning the mean and log variance
eval – Set the module in evaluation mode.
extra_repr – Return the extra representation of the module.
float
forward – Forward pass of the VAE
freeze – Freeze all params for inference.
gaussian_likelihood – Calculates the likelihood of input data given latent space (p(x|z)) under Gaussian assumption.
get_buffer – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state – Return any extra state to include in the module's state_dict.
get_parameter – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule – Return the submodule given by target if it exists, otherwise throw an error.
half
ipu – Move all model parameters and buffers to the IPU.
load_from_checkpoint – Primary way of loading a model from a checkpoint.
load_state_dict – Copy parameters and buffers from state_dict into this module and its descendants.
log – Log a key, value pair.
log_dict – Log a dictionary of values at once.
lr_scheduler_step – Override this method to adjust the default way the Trainer calls each scheduler.
lr_schedulers – Returns the learning rate scheduler(s) that are being used during training.
manual_backward – Call this directly from your training_step() when doing optimizations manually.
modules – Return an iterator over all modules in the network.
mtia – Move all model parameters and buffers to the MTIA.
named_buffers – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
on_after_backward – Called after loss.backward() and before optimizers are stepped.
on_after_batch_transfer – Override to alter or apply batch augmentations to your batch after it is transferred to the device.
on_before_backward – Called before loss.backward().
on_before_batch_transfer – Override to alter or apply batch augmentations to your batch before it is transferred to the device.
on_before_optimizer_step – Called before optimizer.step().
on_before_zero_grad – Called after training_step() and before optimizer.zero_grad().
on_fit_end – Called at the very end of fit.
on_fit_start – Called at the very beginning of fit.
on_load_checkpoint – Called by Lightning to restore your model.
on_predict_batch_end – Called in the predict loop after the batch.
on_predict_batch_start – Called in the predict loop before anything happens for that batch.
on_predict_end – Called at the end of predicting.
on_predict_epoch_end – Called at the end of predicting.
on_predict_epoch_start – Called at the beginning of predicting.
on_predict_model_eval – Called when the predict loop starts.
on_predict_start – Called at the beginning of predicting.
on_save_checkpoint – Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
on_test_batch_end – Called in the test loop after the batch.
on_test_batch_start – Called in the test loop before anything happens for that batch.
on_test_end – Called at the end of testing.
on_test_epoch_end – Called in the test loop at the very end of the epoch.
on_test_epoch_start – Called in the test loop at the very beginning of the epoch.
on_test_model_eval – Called when the test loop starts.
on_test_model_train – Called when the test loop ends.
on_test_start – Called at the beginning of testing.
on_train_batch_end – Called in the training loop after the batch.
on_train_batch_start – Called in the training loop before anything happens for that batch.
on_train_end – Called at the end of training before logger experiment is closed.
on_train_epoch_end – Calculate and log the mean training losses for the epoch.
on_train_epoch_start – Called in the training loop at the very beginning of the epoch.
on_train_start – Called at the beginning of training after sanity check.
on_validation_batch_end – Called in the validation loop after the batch.
on_validation_batch_start – Called in the validation loop before anything happens for that batch.
on_validation_end – Called at the end of validation.
on_validation_epoch_end – Called in the validation loop at the very end of the epoch.
on_validation_epoch_start – Called in the validation loop at the very beginning of the epoch.
on_validation_model_eval – Called when the validation loop starts.
on_validation_model_train – Called when the validation loop ends.
on_validation_model_zero_grad – Called by the training loop to release gradients before entering the validation loop.
on_validation_start – Called at the beginning of validation.
optimizer_step – Override this method to adjust the default way the Trainer calls the optimizer.
optimizer_zero_grad – Override this method to change the default behaviour of optimizer.zero_grad().
optimizers – Returns the optimizer(s) that are being used during training.
parameters – Return an iterator over module parameters.
predict_dataloader – An iterable or collection of iterables specifying prediction samples.
predict_step – Step function called during predict().
prepare_data – Use this to download and prepare data.
print – Prints only from process 0.
register_backward_hook – Register a backward hook on the module.
register_buffer – Add a buffer to the module.
register_forward_hook – Register a forward hook on the module.
register_forward_pre_hook – Register a forward pre-hook on the module.
register_full_backward_hook – Register a backward hook on the module.
register_full_backward_pre_hook – Register a backward pre-hook on the module.
register_load_state_dict_post_hook – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook – Register a pre-hook to be run before module's load_state_dict() is called.
register_module – Alias for add_module().
register_parameter – Add a parameter to the module.
register_state_dict_post_hook – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook – Register a pre-hook for the state_dict() method.
reparameterize – Reparameterization trick that allows for the flow of gradients through the non-random process.
requires_grad_ – Change if autograd should record operations on parameters in this module.
save_hyperparameters – Save arguments to hparams attribute.
set_extra_state – Set extra state contained in the loaded state_dict.
set_submodule – Set the submodule given by target if it exists, otherwise throw an error.
setup – Set up the model, including optional compilation.
share_memory
state_dict – Return a dictionary containing references to the whole state of the module.
symmetrize – Symmetrizes the reconstructed data so that the weights can learn other patterns.
teardown – Called at the end of fit (train + validate), validate, test, or predict.
test_dataloader – An iterable or collection of iterables specifying test samples.
test_step – Operates on a single batch of data from the test set.
to – See torch.nn.Module.to().
to_empty – Move the parameters and buffers to the specified device without copying storage.
to_onnx – Saves the model in ONNX format.
to_torchscript – By default compiles the whole model to a torch.jit.ScriptModule.
toggle_optimizer – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
toggled_optimizer – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
train – Set the module in training mode.
train_dataloader – An iterable or collection of iterables specifying training samples.
training_step – Training step of the VAE compatible with PyTorch Lightning
transfer_batch_to_device – Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
type
unfreeze – Unfreeze all parameters for training.
untoggle_optimizer – Resets the state of required gradients that were toggled with toggle_optimizer().
vae_loss – Calculates the loss of the VAE as the sum of the KLD loss of the latent space to N(0, I) and either the mean squared error between the reconstructed data and the ground truth or the negative log likelihood of the input data given the latent space under a Gaussian assumption.
val_dataloader – An iterable or collection of iterables specifying validation samples.
validation_step – Validation step of the VAE compatible with PyTorch Lightning.
xpu – Move all model parameters and buffers to the XPU.
zero_grad – Reset gradients of all model parameters.
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
T_destination
automatic_optimization – If set to False you are responsible for calling .backward(), .step(), .zero_grad().
call_super_init
current_epoch – The current epoch in the Trainer, or 0 if not attached.
device
device_mesh – Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.
dtype
dump_patches
example_input_array – The example input array is a specification of what the module can consume in the forward() method.
fabric
global_rank – The index of the current process across all nodes and devices.
global_step – Total training batches seen across all epochs.
hparams – The collection of hyperparameters saved with save_hyperparameters().
hparams_initial – The collection of hyperparameters saved with save_hyperparameters().
local_rank – The index of the current process within a single node.
logger – Reference to the logger object in the Trainer.
loggers – Reference to the list of loggers in the Trainer.
on_gpu – Returns True if this model is currently located on a GPU.
strict_loading – Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).
trainer
training
- __init__(model_type: str, in_channels: int, latent_dim: int, dimension: int, loss_type: str, KLD_weight: float, lr_scheduler: str, set_lr: float, norm: str = 'instance', base: int = 64, optimizer: str = 'SGD', KLD_warmup_fraction: float = 0, KLD_scheduler_type: str = 'cyclical', compile_mode: str = 'max-autotune', weights_type: str = None) → None
The variational autoencoder (VAE) model that is used to learn the latent space of protein distance maps. The model is based on the ResNet architecture and uses a Gaussian distribution to model the latent space. The model is trained using the evidence lower bound (ELBO) loss, which is a combination of the reconstruction loss and the Kullback-Leibler divergence loss. The reconstruction loss can be either mean squared error or negative log likelihood. The weights for the reconstruction loss can be calculated based on the distance between residues in the ground truth distance map. The model can be trained using different learning rate schedulers and the learning rate can be set manually.
References
1) Kingma, D. P. & Welling, M. Auto-Encoding Variational Bayes. arXiv [stat.ML] (2013).
2) Rombach, R., Blattmann, A., Lorenz, D., Esser, P. & Ommer, B. High-resolution image synthesis with latent diffusion models. arXiv [cs.CV] (2021).
- Parameters:
model_type (str) – What ResNet architecture to use for the encoder and decoder portion of the VAE
in_channels (int) – Number of input channels in the input data
latent_dim (int) – The number of channels in the latent space representation of the data
dimension (int) – The size of the image in the height and width dimensions (i.e., distance maps)
loss_type (str) – The type of loss to use for the reconstruction loss. Options are “mse” and “nll”
weights_type (str) – The type of weights to use for the reconstruction loss. Options are “linear”, “reciprocal”, and “equal”
KLD_weight (float) – The weight to apply to the KLD loss in the ELBO loss function; the KLD loss regularizes the latent space
lr_scheduler (str) – The learning rate scheduler to use for training the model. Options are “CosineAnnealingWarmRestarts”, “OneCycleLR”, and “CosineAnnealingLR”
set_lr (float) – The learning rate to use for training the model
norm (str, optional) – The normalization layer to use in the ResNet architecture, by default “instance”
base (int, optional) – The base (starting) number of channels to use in the ResNet architecture, by default 64
optimizer (str, optional) – The optimizer to use for training the model, by default “SGD”
- encode(data: Tensor) → List[Tuple[Tensor, Tensor]]
Encodes the data into the latent space, returning the mean and log variance
- Parameters:
data (torch.Tensor) – Data in the shape of (batch, channel, height, width)
- Returns:
Return the mean and log variance of the latent space
- Return type:
List[Tuple[torch.Tensor, torch.Tensor]]
- decode(latents: Tensor) → Tensor
Decodes the latent space back into the original data
- Parameters:
latents (torch.Tensor) – Latents in the shape of (batch, channel, height, width)
- Returns:
Returns the reconstructed data
- Return type:
torch.Tensor
- reparameterize(mu: Tensor, logvar: Tensor) → Tensor
Reparameterization trick that allows gradients to flow through the non-random part of the sampling process. See the paper for more details: https://arxiv.org/abs/1312.6114
- Parameters:
mu (torch.Tensor) – A tensor containing the means of the latent space
logvar (torch.Tensor) – A tensor containing the log variance of the latent space
- Returns:
Returns the latent encoding
- Return type:
torch.Tensor
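As a rough illustration of the reparameterization trick described above, the following framework-free sketch samples z = mu + sigma * eps with eps drawn from N(0, 1) and sigma = exp(logvar / 2). It is not the method's actual implementation, which operates on torch tensors; the plain Python lists and the `rng` argument are assumptions made to keep the example self-contained.

```python
import math
import random


def reparameterize(mu, logvar, rng=random):
    """Sample z = mu + sigma * eps elementwise, with eps ~ N(0, 1).

    mu and logvar are plain lists of floats here (an assumption for this
    sketch; the real model works on tensors).
    """
    z = []
    for m, lv in zip(mu, logvar):
        std = math.exp(0.5 * lv)   # sigma = exp(logvar / 2)
        eps = rng.gauss(0.0, 1.0)  # the only source of randomness
        z.append(m + std * eps)    # deterministic in mu and logvar
    return z


# Usage: a fixed seed makes the draw reproducible.
sample = reparameterize([0.0, 1.0], [0.0, -1.0], rng=random.Random(0))
```

Because the noise is drawn independently of mu and logvar, the sample is a differentiable function of both, which is what lets an autograd framework backpropagate through the sampling step.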
- gaussian_likelihood(data_reconstructed: Tensor, log_std: Tensor, data: Tensor) → Tensor
Calculates the likelihood of the input data given the latent space (p(x|z)) under a Gaussian assumption. The reconstructed data is treated as the mean of the Gaussian distributions, and log_std is a tensor of learned log standard deviations.
- Parameters:
data_reconstructed (torch.Tensor) – A tensor containing the reconstructed data that will be treated as the mean to parameterize the Gaussian distribution
log_std (torch.Tensor) – Learned log standard deviations of the Gaussian distribution
data (torch.Tensor) – The ground truth data that the likelihood will be calculated against
- Returns:
Returns the likelihood of the input data given the latent space
- Return type:
torch.Tensor
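The Gaussian likelihood described above can be sketched in plain Python as a sum of per-element Gaussian log-densities, with the reconstruction as the mean and exp(log_std) as the standard deviation. This is an illustrative sketch only: the function name `gaussian_log_likelihood`, the use of lists instead of tensors, and the choice to sum over elements (rather than average) are assumptions, not details taken from the source.

```python
import math


def gaussian_log_likelihood(data_reconstructed, log_std, data):
    """Sum of log N(x | mean, std^2) over all elements.

    log N(x | mean, std^2) = -log_std - 0.5*log(2*pi) - 0.5*((x - mean)/std)^2
    """
    total = 0.0
    for mean, ls, x in zip(data_reconstructed, log_std, data):
        std = math.exp(ls)
        total += -ls - 0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - mean) / std) ** 2
    return total


# Usage: a perfect reconstruction scores higher than an imperfect one.
good = gaussian_log_likelihood([1.0, 2.0], [0.0, 0.0], [1.0, 2.0])
bad = gaussian_log_likelihood([1.0, 2.0], [0.0, 0.0], [3.0, 0.0])
```

The negative of this quantity is what would serve as the "nll" reconstruction loss.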
- vae_loss(data_reconstructed: Tensor, data: Tensor, mu: Tensor, logvar: Tensor) → dict
Calculates the loss of the VAE as the sum of the KLD loss of the latent space to N(0, I) and either the mean squared error between the reconstructed data and the ground truth or the negative log likelihood of the input data given the latent space under a Gaussian assumption. An additional loss is added to ensure the contacts are reconstructed correctly.
- Parameters:
data_reconstructed (torch.Tensor) – Reconstructed data; output of the VAE
data (torch.Tensor) – Ground truth data, input to the VAE
mu (torch.Tensor) – Means of the normal distributions of the latent space
logvar (torch.Tensor) – Log variances of the normal distributions of the latent space
KLD_weight (int, optional) – How much to weight the importance of the regularization term of the latent space. Setting this to lower than 1 will lead to a less regular and interpretable latent space, by default None
- Returns:
Returns a dictionary containing the total loss, reconstruction loss, and KLD loss
- Return type:
dict
- Raises:
ValueError – If the loss type is not mse or elbo
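To make the structure of this loss concrete, here is a minimal sketch of the MSE variant: a reconstruction term plus a weighted closed-form KL divergence between N(mu, sigma^2) and N(0, I). The helper names, the dictionary keys, and the reductions (mean for MSE, sum for KLD) are hypothetical choices for illustration; the contact-specific loss term and the per-residue weighting mentioned in the documentation are deliberately left out.

```python
import math


def kld_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def mse(data_reconstructed, data):
    """Mean squared error between reconstruction and ground truth."""
    return sum((r - x) ** 2 for r, x in zip(data_reconstructed, data)) / len(data)


def vae_loss_sketch(data_reconstructed, data, mu, logvar, kld_weight):
    """Total loss = reconstruction loss + kld_weight * KLD (keys are hypothetical)."""
    recon = mse(data_reconstructed, data)
    kld = kld_to_standard_normal(mu, logvar)
    return {"loss": recon + kld_weight * kld, "recon": recon, "KLD": kld}


# Usage: a latent that already matches N(0, 1) contributes no KLD penalty.
losses = vae_loss_sketch([1.0, 2.0], [1.0, 4.0], [0.0], [0.0], kld_weight=0.5)
```

A smaller `kld_weight` shifts the balance toward reconstruction quality at the cost of a less regularized latent space.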
- forward(data: Tensor) → List[Tuple[Tensor, Tensor, Tensor]]
Forward pass of the VAE
- Parameters:
data (torch.Tensor) – Data in the shape of (batch, channel, height, width) to pass through the VAE
- Returns:
Returns the reconstructed data, the mean of the latent space, and the log variance
- Return type:
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
- training_step(batch: dict, batch_idx) → Tensor
Training step of the VAE compatible with PyTorch Lightning
- Parameters:
batch (dict) – A batch of data read in using the DataLoader
batch_idx (int) – Batch number the model is on during training
- Returns:
Total training loss of this batch
- Return type:
torch.Tensor
- on_train_epoch_end() → None
Calculate and log the mean training losses for the epoch. Reset the loss accumulators for the next epoch.
- validation_step(batch: Tensor, batch_idx) → Tensor
Validation step of the VAE compatible with PyTorch Lightning. This is called after each epoch.
- Parameters:
batch (torch.Tensor) – A batch of data read in using the DataLoader
batch_idx (int) – Batch number the model is on during validation
- Returns:
Total validation loss of this batch
- Return type:
torch.Tensor
- configure_optimizers()
Configure the optimizer and the learning rate scheduler for the model. The settings follow NVIDIA's suggestions for learning rate and weight decay: for ResNet50, NVIDIA reports the best performance with CosineAnnealingLR and an initial learning rate of 0.256 for a batch size of 256, scaled linearly for other batch sizes. The weight decay is set to 1/32768 for all parameters except the batch normalization layers. For further information, see: https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch
- Returns:
Returns the optimizer and the learning rate scheduler
- Return type:
List
- Raises:
ValueError – If the scheduler is not implemented
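The linear scaling rule mentioned in the description of configure_optimizers can be written out as a short calculation. This is only a sketch of the arithmetic: the helper name `scaled_learning_rate` and the constant name `WEIGHT_DECAY` are hypothetical, and in the actual class the learning rate is supplied by the user through set_lr.

```python
def scaled_learning_rate(batch_size, base_lr=0.256, base_batch_size=256):
    """Scale the reference learning rate linearly with the batch size."""
    return base_lr * batch_size / base_batch_size


# Weight decay quoted in the description (applied to all parameters
# except batch normalization layers).
WEIGHT_DECAY = 1 / 32768

# Usage: a batch of 64 is a quarter of the reference batch of 256,
# so the learning rate is a quarter of 0.256.
lr_for_64 = scaled_learning_rate(64)
```

Under this rule, halving the batch size halves the suggested learning rate, and doubling it doubles the rate.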
- symmetrize(data_reconstructed: Tensor) → Tensor
Symmetrizes the reconstructed data so that the weights can learn other patterns. The loss is calculated only on the reconstruction faithfulness of the upper triangle of the distance map.
- Parameters:
data_reconstructed (torch.Tensor) – Reconstructed data; output of the decoder
- Returns:
Symmetric version of the reconstructed data
- Return type:
torch.Tensor
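The documentation does not say how the symmetrization is carried out. Since the loss only uses the upper triangle, one plausible reading is that the upper triangle is mirrored onto the lower triangle. The sketch below shows that reading on a nested list standing in for a single distance map; the function name `symmetrize_upper` is hypothetical, and averaging the matrix with its transpose would be another way to obtain a symmetric result.

```python
def symmetrize_upper(matrix):
    """Return a copy of a square matrix whose lower triangle mirrors the upper."""
    n = len(matrix)
    out = [row[:] for row in matrix]  # copy so the input is left untouched
    for i in range(n):
        for j in range(i + 1, n):
            out[j][i] = matrix[i][j]  # mirror entry (i, j) onto (j, i)
    return out


# Usage: values below the diagonal are replaced by their mirrored counterparts.
reconstructed = [
    [0.0, 1.0, 2.0],
    [9.0, 0.0, 3.0],
    [9.0, 9.0, 0.0],
]
symmetric = symmetrize_upper(reconstructed)
```

With this approach the decoder's lower-triangle outputs never affect the result, which is consistent with computing the loss on the upper triangle only.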