starling.models.diffusion.DiffusionModel
- class DiffusionModel
Bases: LightningModule
Denoising diffusion probabilistic model for latent space generation.
Implements the diffusion process described in:
- Sohl-Dickstein et al. (2015): Nonequilibrium Thermodynamics
- Ho et al. (2020): Denoising Diffusion Probabilistic Models
- Rombach et al. (2021): High-resolution image synthesis with latent diffusion
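For reference, the standard DDPM formulation of Ho et al. (2020) is summarized below, written for a latent z_0 produced by the encoder as in Rombach et al. (2021). This is the textbook form, not a transcription of starling's code. The β_t come from the schedule selected by beta_scheduler, and T corresponds to timesteps. Two points are assumptions drawn from the p_loss and forward descriptions rather than confirmed here: that the network ε_θ is trained to predict the added noise, and that c stands for whatever conditioning labels it receives.

```latex
\begin{aligned}
q(z_t \mid z_{t-1}) &= \mathcal{N}\!\left(z_t;\ \sqrt{1-\beta_t}\, z_{t-1},\ \beta_t I\right), \qquad t = 1, \dots, T, \\
\bar{\alpha}_t &= \prod_{s=1}^{t} (1-\beta_s), \\
z_t &= \sqrt{\bar{\alpha}_t}\, z_0 + \sqrt{1-\bar{\alpha}_t}\, \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I), \\
\mathcal{L}_{\text{simple}} &= \mathbb{E}_{z_0,\, t,\, \varepsilon}\!\left[\left\lVert \varepsilon - \varepsilon_\theta(z_t, t, c) \right\rVert_2^2\right].
\end{aligned}
```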
Methods
__init__ – A discrete-time denoising-diffusion model framework for latent space diffusion models.
add_module – Add a child module to the current module.
all_gather – Gather tensors or collections of tensors from multiple processes.
apply – Apply fn recursively to every submodule (as returned by .children()) as well as self.
backward – Called to perform backward on the loss returned in training_step().
bfloat16 – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers – Return an iterator over module buffers.
children – Return an iterator over immediate children modules.
clip_gradients – Handles gradient clipping internally.
compile – Compile this Module's forward using torch.compile().
compute_snr – Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
configure_callbacks – Configure model-specific callbacks.
configure_gradient_clipping – Perform gradient clipping for the optimizer parameters.
configure_model – Hook to create modules in a strategy and precision aware context.
configure_optimizers – Configure the optimizer and the learning rate scheduler for the model.
configure_sharded_model – Deprecated.
cpu
cuda – Moves all model parameters and buffers to the GPU.
double
eval – Set the module in evaluation mode.
extra_repr – Return the extra representation of the module.
float
forward – Forward pass of the model: calculates the loss based on the predicted noise and the actual noise.
freeze – Freeze all params for inference.
get_buffer – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state – Return any extra state to include in the module's state_dict.
get_parameter – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule – Return the submodule given by target if it exists, otherwise throw an error.
half
ipu – Move all model parameters and buffers to the IPU.
load_from_checkpoint – Primary way of loading a model from a checkpoint.
load_state_dict – Copy parameters and buffers from state_dict into this module and its descendants.
log – Log a key, value pair.
log_dict – Log a dictionary of values at once.
lr_scheduler_step – Override this method to adjust the default way the Trainer calls each scheduler.
lr_schedulers – Returns the learning rate scheduler(s) that are being used during training.
manual_backward – Call this directly from your training_step() when doing optimizations manually.
modules – Return an iterator over all modules in the network.
mtia – Move all model parameters and buffers to the MTIA.
named_buffers – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
on_after_backward – Called after loss.backward() and before optimizers are stepped.
on_after_batch_transfer – Override to alter or apply batch augmentations to your batch after it is transferred to the device.
on_before_backward – Called before loss.backward().
on_before_batch_transfer – Override to alter or apply batch augmentations to your batch before it is transferred to the device.
on_before_optimizer_step – Called before optimizer.step().
on_before_zero_grad – Called after training_step() and before optimizer.zero_grad().
on_fit_end – Called at the very end of fit.
on_fit_start – Called at the very beginning of fit.
on_load_checkpoint – Called by Lightning to restore your model.
on_predict_batch_end – Called in the predict loop after the batch.
on_predict_batch_start – Called in the predict loop before anything happens for that batch.
on_predict_end – Called at the end of predicting.
on_predict_epoch_end – Called at the end of predicting.
on_predict_epoch_start – Called at the beginning of predicting.
on_predict_model_eval – Called when the predict loop starts.
on_predict_start – Called at the beginning of predicting.
on_save_checkpoint – Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
on_test_batch_end – Called in the test loop after the batch.
on_test_batch_start – Called in the test loop before anything happens for that batch.
on_test_end – Called at the end of testing.
on_test_epoch_end – Called in the test loop at the very end of the epoch.
on_test_epoch_start – Called in the test loop at the very beginning of the epoch.
on_test_model_eval – Called when the test loop starts.
on_test_model_train – Called when the test loop ends.
on_test_start – Called at the beginning of testing.
on_train_batch_end – Called in the training loop after the batch.
on_train_batch_start – Called in the training loop before anything happens for that batch.
on_train_end – Called at the end of training before logger experiment is closed.
on_train_epoch_end – Called in the training loop at the very end of the epoch.
on_train_epoch_start – Called in the training loop at the very beginning of the epoch.
on_train_start – Called at the beginning of training after sanity check.
on_validation_batch_end – Called in the validation loop after the batch.
on_validation_batch_start – Called in the validation loop before anything happens for that batch.
on_validation_end – Called at the end of validation.
on_validation_epoch_end – Called in the validation loop at the very end of the epoch.
on_validation_epoch_start – Called in the validation loop at the very beginning of the epoch.
on_validation_model_eval – Called when the validation loop starts.
on_validation_model_train – Called when the validation loop ends.
on_validation_model_zero_grad – Called by the training loop to release gradients before entering the validation loop.
on_validation_start – Called at the beginning of validation.
optimizer_step – Override this method to adjust the default way the Trainer calls the optimizer.
optimizer_zero_grad – Override this method to change the default behaviour of optimizer.zero_grad().
optimizers – Returns the optimizer(s) that are being used during training.
p_loss – A function that runs the model and calculates the loss based on the predicted noise and the actual noise.
parameters – Return an iterator over module parameters.
predict_dataloader – An iterable or collection of iterables specifying prediction samples.
predict_step – Step function called during predict().
prepare_data – Use this to download and prepare data.
print – Prints only from process 0.
q_sample – Add noise to the x_start tensor according to timestep t.
register_backward_hook – Register a backward hook on the module.
register_buffer – Add a buffer to the module.
register_forward_hook – Register a forward hook on the module.
register_forward_pre_hook – Register a forward pre-hook on the module.
register_full_backward_hook – Register a backward hook on the module.
register_full_backward_pre_hook – Register a backward pre-hook on the module.
register_load_state_dict_post_hook – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook – Register a pre-hook to be run before module's load_state_dict() is called.
register_module – Alias for add_module().
register_parameter – Add a parameter to the module.
register_state_dict_post_hook – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook – Register a pre-hook for the state_dict() method.
requires_grad_ – Change if autograd should record operations on parameters in this module.
save_hyperparameters – Save arguments to hparams attribute.
sequence2labels – Converts sequences to labels based on user-defined models.
set_extra_state – Set extra state contained in the loaded state_dict.
set_submodule – Set the submodule given by target if it exists, otherwise throw an error.
setup – Called at the beginning of fit (train + validate), validate, test, or predict.
share_memory
state_dict – Return a dictionary containing references to the whole state of the module.
teardown – Called at the end of fit (train + validate), validate, test, or predict.
test_dataloader – An iterable or collection of iterables specifying test samples.
test_step – Operates on a single batch of data from the test set.
to – See torch.nn.Module.to().
to_empty – Move the parameters and buffers to the specified device without copying storage.
to_onnx – Saves the model in ONNX format.
to_torchscript – By default compiles the whole model to a torch.jit.ScriptModule.
toggle_optimizer – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
toggled_optimizer – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
train – Set the module in training mode.
train_dataloader – An iterable or collection of iterables specifying training samples.
training_step – Training step that computes diffusion loss on a batch.
transfer_batch_to_device – Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
type
unfreeze – Unfreeze all parameters for training.
untoggle_optimizer – Resets the state of required gradients that were toggled with toggle_optimizer().
val_dataloader – An iterable or collection of iterables specifying validation samples.
validation_step – Validation step that evaluates diffusion loss on a batch.
xpu – Move all model parameters and buffers to the XPU.
zero_grad – Reset gradients of all model parameters.
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
T_destination
automatic_optimization – If set to False you are responsible for calling .backward(), .step(), .zero_grad().
call_super_init
current_epoch – The current epoch in the Trainer, or 0 if not attached.
device
device_mesh – Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.
dtype
dump_patches
example_input_array – The example input array is a specification of what the module can consume in the forward() method.
fabric
global_rank – The index of the current process across all nodes and devices.
global_step – Total training batches seen across all epochs.
hparams – The collection of hyperparameters saved with save_hyperparameters().
hparams_initial – The collection of hyperparameters saved with save_hyperparameters().
local_rank – The index of the current process within a single node.
logger – Reference to the logger object in the Trainer.
loggers – Reference to the list of loggers in the Trainer.
on_gpu – Returns True if this model is currently located on a GPU.
strict_loading – Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).
trainer
training
- SCHEDULER_MAPPING = {'cosine': <function cosine_beta_schedule>, 'linear': <function linear_beta_schedule>, 'sigmoid': <function sigmoid_beta_schedule>}
- __init__(model: Module, sequence_encoder: Module, distance_map_encoder: Module, beta_scheduler: str = 'cosine', timesteps: int = 1000, set_lr: float = 0.0001, min_snr_loss: bool = False, min_snr_gamma: float = 5.0, config_scheduler: str = 'LinearWarmupCosineAnnealingLR') → None
A discrete-time denoising-diffusion model framework for latent space diffusion models. The model is based on the work of Sohl-Dickstein et al. [1], Ho et al. [2], and Rombach et al. [3].
References
1) Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N. & Ganguli, S. Deep Unsupervised Learning using Nonequilibrium Thermodynamics. in Proceedings of the 32nd International Conference on Machine Learning (eds. Bach, F. & Blei, D.) vol. 37 2256–2265 (PMLR, Lille, France, 07–09 Jul 2015).
2) Ho, J., Jain, A. & Abbeel, P. Denoising Diffusion Probabilistic Models. arXiv [cs.LG] (2020).
3) Rombach, R., Blattmann, A., Lorenz, D., Esser, P. & Ommer, B. High-resolution image synthesis with latent diffusion models. arXiv [cs.CV] (2021).
- Parameters:
model (nn.Module) – A neural network that takes the noised input, a timestep, and optionally labels to condition on, and outputs the predicted noise
distance_map_encoder (nn.Module) – A VAE model that takes the data (e.g., a distance map) and outputs a compressed representation of it (e.g., a latent space). The denoising-diffusion model is then trained to denoise this latent space.
image_size (int) – The size of the latent space (height and width)
beta_scheduler (str, optional) – The name of the beta scheduler to use, one of the keys of SCHEDULER_MAPPING (“cosine”, “linear”, or “sigmoid”), by default “cosine”
timesteps (int, optional) – The number of timesteps to run the diffusion process, by default 1000
schedule_fn_kwargs (Union[dict, None], optional) – Additional arguments to pass to the beta scheduler function, by default None
labels (str, optional) – The type of labels to condition the model on, by default “learned-embeddings”
set_lr (float, optional) – The initial learning rate for the optimizer, by default 1e-4
config_scheduler (str, optional) – The name of the learning rate scheduler to use, by default “LinearWarmupCosineAnnealingLR”
- Raises:
ValueError – If the beta scheduler is not implemented
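SCHEDULER_MAPPING maps the beta_scheduler string to a schedule function, and an unknown name raises ValueError. The sketch below assumes the common textbook definitions (linear from Ho et al. 2020, cosine from Nichol & Dhariwal 2021, and a sigmoid schedule used in several open-source DDPM codebases). The three function names are taken from SCHEDULER_MAPPING, but their bodies, their extra keyword arguments, and the helper get_betas are hypothetical and may differ from starling's implementation.

```python
import math

import torch


def linear_beta_schedule(timesteps: int) -> torch.Tensor:
    # Linear betas (Ho et al., 2020), rescaled so 1000 steps span [1e-4, 0.02].
    scale = 1000 / timesteps
    return torch.linspace(scale * 1e-4, scale * 0.02, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    # Cosine schedule (Nichol & Dhariwal, 2021): define alpha_bar(t) directly,
    # then recover each beta from the ratio of consecutive alpha_bar values.
    t = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return torch.clip(betas, 0, 0.999)


def sigmoid_beta_schedule(timesteps: int, start: float = -3.0, end: float = 3.0, tau: float = 1.0) -> torch.Tensor:
    # Sigmoid schedule: alpha_bar(t) decays from 1 to 0 along a reversed, rescaled sigmoid.
    t = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    v_start = torch.sigmoid(torch.tensor(start / tau, dtype=torch.float64))
    v_end = torch.sigmoid(torch.tensor(end / tau, dtype=torch.float64))
    alphas_cumprod = (v_end - torch.sigmoid((t * (end - start) + start) / tau)) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return torch.clip(betas, 0, 0.999)


SCHEDULER_MAPPING = {
    "cosine": cosine_beta_schedule,
    "linear": linear_beta_schedule,
    "sigmoid": sigmoid_beta_schedule,
}


def get_betas(beta_scheduler: str = "cosine", timesteps: int = 1000) -> torch.Tensor:
    # Hypothetical helper mirroring the documented ValueError for unknown schedulers.
    if beta_scheduler not in SCHEDULER_MAPPING:
        raise ValueError(f"Beta scheduler {beta_scheduler!r} is not implemented")
    return SCHEDULER_MAPPING[beta_scheduler](timesteps)
```

All three schedules clip betas to at most 0.999 so that the final step, where alpha_bar reaches (or nearly reaches) zero, does not produce a beta of exactly 1.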
- q_sample(x_start: Tensor, t: int, noise: Tensor = None) → Tensor
Add noise to the x_start tensor according to timestep t.
- Parameters:
x_start (torch.Tensor) – The starting image tensor
t (int) – The timestep of the denoising-diffusion process
noise (torch.Tensor, optional) – Sampled noise to add, by default None
- Returns:
Returns the input tensor noised according to timestep t
- Return type:
torch.Tensor
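q_sample draws from the closed-form forward marginal, x_t = sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise. A minimal sketch, assuming the usual DDPM parameterization and a hypothetical precomputed alphas_cumprod tensor. Unlike the documented t: int, the sketch accepts a per-sample tensor of timesteps, which is how training batches are typically noised; the helper extract is also hypothetical.

```python
import torch


def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
    # Gather a[t] for each batch element and reshape so it broadcasts over x.
    out = a.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))


def q_sample(x_start: torch.Tensor, t: torch.Tensor, alphas_cumprod: torch.Tensor,
             noise: torch.Tensor = None) -> torch.Tensor:
    # Sample fresh Gaussian noise if none is supplied.
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_ab = extract(alphas_cumprod.sqrt(), t, x_start.shape)
    sqrt_one_minus_ab = extract((1.0 - alphas_cumprod).sqrt(), t, x_start.shape)
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    return sqrt_ab * x_start + sqrt_one_minus_ab * noise
```

At small t the output stays close to x_start; as t approaches the last timestep the signal term shrinks toward zero and the output is dominated by noise.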
- sequence2labels(sequences: List, sequence_mask, ionic_strength) → Tensor
Converts sequences to labels based on user-defined models.
- Parameters:
sequences (List) – A list of sequences to convert to labels
- Returns:
Returns the labels for the decoder
- Return type:
torch.Tensor
- Raises:
ValueError – If the labels are not one of the three options
- p_loss(x_start: Tensor, t: int, labels: Tensor, mask: Tensor, ionic_strengths: Tensor, noise: Tensor = None) → Tensor
A function that runs the model and calculates the loss based on the predicted noise and the actual noise. The loss can either be L1 or L2.
- Parameters:
x_start (torch.Tensor) – The starting image tensor
t (int) – The timestep along the denoising-diffusion process
labels (torch.Tensor) – Labels to condition the model on
noise (torch.Tensor, optional) – Sampled noise from N(0,I), by default None
- Returns:
Returns the loss
- Return type:
torch.Tensor
- Raises:
ValueError – If the loss type is not one of the two options (l1, l2)
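p_loss noises x_start to timestep t, asks the model to predict the noise that was added, and compares the prediction with that noise using an L1 or L2 loss. A minimal sketch, assuming a noise-predicting model with a hypothetical call signature model(x_noisy, t, labels) and a hypothetical loss_type switch; the real method also receives mask and ionic_strengths, which are omitted here because this page does not document how they are used.

```python
import torch
import torch.nn.functional as F


def p_loss(model, x_start: torch.Tensor, t: torch.Tensor, alphas_cumprod: torch.Tensor,
           labels=None, noise: torch.Tensor = None, loss_type: str = "l2") -> torch.Tensor:
    # Sample noise from N(0, I) if none is supplied.
    if noise is None:
        noise = torch.randn_like(x_start)
    # Noise x_start to timestep t with the closed-form forward process.
    ab = alphas_cumprod[t].reshape(-1, *((1,) * (x_start.dim() - 1)))
    x_noisy = ab.sqrt() * x_start + (1.0 - ab).sqrt() * noise
    predicted_noise = model(x_noisy, t, labels)
    if loss_type == "l1":
        return F.l1_loss(predicted_noise, noise)
    if loss_type == "l2":
        return F.mse_loss(predicted_noise, noise)
    raise ValueError(f"Loss type {loss_type!r} is not one of the two options (l1, l2)")
```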
- forward(x: Tensor, labels: Tensor, mask, ionic_strengths) → Tensor
Forward pass of the model: calculates the loss based on the predicted noise and the actual noise.
- Parameters:
x (torch.Tensor) – The starting tensor to noise/denoise
labels (torch.Tensor) – Sequences to condition the model on
- Returns:
Returns the loss
- Return type:
torch.Tensor
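Because forward takes no timestep argument, it presumably samples one before computing the loss. This is an assumption not confirmed by this page. The usual DDPM choice is one timestep per example, drawn uniformly from {0, …, T−1}. A minimal sketch with a hypothetical loss_fn callable standing in for p_loss:

```python
import torch


def diffusion_forward(x: torch.Tensor, labels, loss_fn, num_timesteps: int = 1000) -> torch.Tensor:
    # Draw one timestep per batch element, uniformly from [0, num_timesteps).
    t = torch.randint(0, num_timesteps, (x.shape[0],), device=x.device)
    # Delegate the noising and loss computation to the p_loss-like callable.
    return loss_fn(x, t, labels)
```

Sampling t uniformly gives a Monte-Carlo estimate of the objective averaged over all timesteps, so each training step only needs a single forward pass of the network per example.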
- training_step(batch: Tensor, batch_idx: int) → Tensor
Training step that computes diffusion loss on a batch.
- validation_step(batch: Tensor, batch_idx: int) → Tensor
Validation step that evaluates diffusion loss on a batch.
- compute_snr(timesteps)
Computes the signal-to-noise ratio (SNR) for the given timesteps as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
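For the DDPM forward process, the signal-to-noise ratio at timestep t is SNR(t) = alpha_bar_t / (1 - alpha_bar_t), the squared ratio of the signal scale to the noise scale. The min_snr_loss and min_snr_gamma arguments of __init__ suggest Min-SNR-γ loss weighting (Hang et al., 2023), in which each noise-prediction loss term is weighted by min(SNR(t), γ) / SNR(t). The sketch below assumes exactly that; both helper names and the explicit alphas_cumprod argument are hypothetical.

```python
import torch


def compute_snr(alphas_cumprod: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    # SNR(t) = (sqrt(alpha_bar_t) / sqrt(1 - alpha_bar_t)) ** 2
    alpha = alphas_cumprod[timesteps].sqrt()
    sigma = (1.0 - alphas_cumprod[timesteps]).sqrt()
    return (alpha / sigma) ** 2


def min_snr_weights(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    # Min-SNR-gamma weights for a noise-prediction loss: min(SNR, gamma) / SNR.
    # Low-noise timesteps (large SNR) are down-weighted; noisy ones keep weight 1.
    return torch.clamp(snr, max=gamma) / snr
```

In a training step the per-sample loss would be multiplied by these weights before averaging, so that nearly noise-free timesteps do not dominate the gradient.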
- configure_optimizers()
Configure the optimizer and the learning rate scheduler for the model. The learning rate and weight decay follow NVIDIA's suggested settings: for ResNet50, NVIDIA reports the best performance with CosineAnnealingLR and an initial learning rate of 0.256 for a batch size of 256, scaled linearly up or down for other batch sizes. The weight decay is set to 1/32768 for all parameters except the batch normalization layers. For further information, see: https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch
- Returns:
Returns the optimizer and the learning rate scheduler
- Return type:
List
- Raises:
ValueError – If the scheduler is not implemented
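The description combines three ideas: linear learning-rate scaling with batch size, weight decay that skips batch-normalization parameters, and a cosine schedule (here with linear warmup, matching the LinearWarmupCosineAnnealingLR default in the signature). A minimal sketch using only torch.optim. The choice of AdamW, the helper names, warmup_steps and max_steps are all assumptions, and LinearWarmupCosineAnnealingLR is approximated with LambdaLR rather than reproduced.

```python
import math

import torch
from torch import nn


def scaled_lr(base_lr: float = 0.256, batch_size: int = 256, base_batch_size: int = 256) -> float:
    # Linear scaling rule: the learning rate grows or shrinks with the batch size.
    return base_lr * batch_size / base_batch_size


def param_groups(model: nn.Module, weight_decay: float = 1 / 32768):
    # Apply weight decay to every parameter except those of batch-norm layers.
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    decay, no_decay = [], []
    for module in model.modules():
        params = [p for p in module.parameters(recurse=False) if p.requires_grad]
        (no_decay if isinstance(module, bn_types) else decay).extend(params)
    groups = [{"params": decay, "weight_decay": weight_decay},
              {"params": no_decay, "weight_decay": 0.0}]
    return [g for g in groups if g["params"]]


def warmup_cosine(warmup_steps: int, max_steps: int):
    # LambdaLR multiplier: linear warmup to 1, then cosine decay to 0 at max_steps.
    def fn(step: int) -> float:
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    return fn


def configure_optimizers(model: nn.Module, lr: float = 1e-4,
                         scheduler: str = "LinearWarmupCosineAnnealingLR",
                         warmup_steps: int = 100, max_steps: int = 10_000):
    optimizer = torch.optim.AdamW(param_groups(model), lr=lr)
    if scheduler == "LinearWarmupCosineAnnealingLR":
        sched = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_cosine(warmup_steps, max_steps))
    elif scheduler == "CosineAnnealingLR":
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps)
    else:
        raise ValueError(f"Scheduler {scheduler!r} is not implemented")
    # Lightning accepts a list of optimizers and a list of schedulers.
    return [optimizer], [sched]
```

For example, under the linear scaling rule a batch size of 512 would start from 0.256 * 512 / 256 = 0.512.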