starling.models.continuous_diffusion.ContinuousDiffusion

class ContinuousDiffusion

Bases: LightningModule

Methods

__init__

add_module

Add a child module to the current module.

all_gather

Gather tensors or collections of tensors from multiple processes.

apply

Apply fn recursively to every submodule (as returned by .children()) as well as self.

backward

Called to perform backward on the loss returned in training_step().

bfloat16

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers

Return an iterator over module buffers.

children

Return an iterator over immediate children modules.

clip_gradients

Handles gradient clipping internally.

compile

Compile this Module's forward using torch.compile().

configure_callbacks

Configure model-specific callbacks.

configure_gradient_clipping

Perform gradient clipping for the optimizer parameters.

configure_model

Hook to create modules in a strategy and precision aware context.

configure_optimizers

Configure the optimizer and the learning rate scheduler for the model.

configure_sharded_model

Deprecated.

cpu

See torch.nn.Module.cpu().

cuda

Moves all model parameters and buffers to the GPU.

double

See torch.nn.Module.double().

eval

Set the module in evaluation mode.

extra_repr

Return the extra representation of the module.

float

See torch.nn.Module.float().

forward

Forward pass that samples random timesteps and calculates loss.

freeze

Freeze all params for inference.

get_buffer

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state

Return any extra state to include in the module's state_dict.

get_parameter

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule

Return the submodule given by target if it exists, otherwise throw an error.

half

See torch.nn.Module.half().

ipu

Move all model parameters and buffers to the IPU.

load_from_checkpoint

Primary way of loading a model from a checkpoint.

load_state_dict

Copy parameters and buffers from state_dict into this module and its descendants.

log

Log a key, value pair.

log_dict

Log a dictionary of values at once.

lr_scheduler_step

Override this method to adjust the default way the Trainer calls each scheduler.

lr_schedulers

Returns the learning rate scheduler(s) that are being used during training.

manual_backward

Call this directly from your training_step() when doing optimizations manually.

modules

Return an iterator over all modules in the network.

mtia

Move all model parameters and buffers to the MTIA.

named_buffers

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

on_after_backward

Called after loss.backward() and before optimizers are stepped.

on_after_batch_transfer

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

on_before_backward

Called before loss.backward().

on_before_batch_transfer

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_before_optimizer_step

Called before optimizer.step().

on_before_zero_grad

Called after training_step() and before optimizer.zero_grad().

on_fit_end

Called at the very end of fit.

on_fit_start

Called at the very beginning of fit.

on_load_checkpoint

Called by Lightning to restore your model.

on_predict_batch_end

Called in the predict loop after the batch.

on_predict_batch_start

Called in the predict loop before anything happens for that batch.

on_predict_end

Called at the end of predicting.

on_predict_epoch_end

Called at the end of predicting.

on_predict_epoch_start

Called at the beginning of predicting.

on_predict_model_eval

Called when the predict loop starts.

on_predict_start

Called at the beginning of predicting.

on_save_checkpoint

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

on_test_batch_end

Called in the test loop after the batch.

on_test_batch_start

Called in the test loop before anything happens for that batch.

on_test_end

Called at the end of testing.

on_test_epoch_end

Called in the test loop at the very end of the epoch.

on_test_epoch_start

Called in the test loop at the very beginning of the epoch.

on_test_model_eval

Called when the test loop starts.

on_test_model_train

Called when the test loop ends.

on_test_start

Called at the beginning of testing.

on_train_batch_end

Called in the training loop after the batch.

on_train_batch_start

Called in the training loop before anything happens for that batch.

on_train_end

Called at the end of training before logger experiment is closed.

on_train_epoch_end

Called in the training loop at the very end of the epoch.

on_train_epoch_start

Called in the training loop at the very beginning of the epoch.

on_train_start

Called at the beginning of training after sanity check.

on_validation_batch_end

Called in the validation loop after the batch.

on_validation_batch_start

Called in the validation loop before anything happens for that batch.

on_validation_end

Called at the end of validation.

on_validation_epoch_end

Called in the validation loop at the very end of the epoch.

on_validation_epoch_start

Called in the validation loop at the very beginning of the epoch.

on_validation_model_eval

Called when the validation loop starts.

on_validation_model_train

Called when the validation loop ends.

on_validation_model_zero_grad

Called by the training loop to release gradients before entering the validation loop.

on_validation_start

Called at the beginning of validation.

optimizer_step

Override this method to adjust the default way the Trainer calls the optimizer.

optimizer_zero_grad

Override this method to change the default behaviour of optimizer.zero_grad().

optimizers

Returns the optimizer(s) that are being used during training.

p_losses

Calculate model loss based on predicted vs actual noise.

parameters

Return an iterator over module parameters.

predict_dataloader

An iterable or collection of iterables specifying prediction samples.

predict_step

Step function called during predict().

prepare_data

Use this to download and prepare data.

print

Prints only from process 0.

q_sample

random_times

register_backward_hook

Register a backward hook on the module.

register_buffer

Add a buffer to the module.

register_forward_hook

Register a forward hook on the module.

register_forward_pre_hook

Register a forward pre-hook on the module.

register_full_backward_hook

Register a backward hook on the module.

register_full_backward_pre_hook

Register a backward pre-hook on the module.

register_load_state_dict_post_hook

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook

Register a pre-hook to be run before module's load_state_dict() is called.

register_module

Alias for add_module().

register_parameter

Add a parameter to the module.

register_state_dict_post_hook

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook

Register a pre-hook for the state_dict() method.

requires_grad_

Change if autograd should record operations on parameters in this module.

save_hyperparameters

Save arguments to hparams attribute.

sequence2labels

Converts sequences to labels based on user-defined models.

set_extra_state

Set extra state contained in the loaded state_dict.

set_submodule

Set the submodule given by target if it exists, otherwise throw an error.

setup

Called at the beginning of fit (train + validate), validate, test, or predict.

share_memory

See torch.Tensor.share_memory_().

state_dict

Return a dictionary containing references to the whole state of the module.

teardown

Called at the end of fit (train + validate), validate, test, or predict.

test_dataloader

An iterable or collection of iterables specifying test samples.

test_step

Operates on a single batch of data from the test set.

to

See torch.nn.Module.to().

to_empty

Move the parameters and buffers to the specified device without copying storage.

to_onnx

Saves the model in ONNX format.

to_torchscript

By default compiles the whole model to a torch.jit.ScriptModule.

toggle_optimizer

Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.

toggled_optimizer

Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.

train

Set the module in training mode.

train_dataloader

An iterable or collection of iterables specifying training samples.

training_step

Training step that encodes inputs and calculates diffusion loss.

transfer_batch_to_device

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

type

See torch.nn.Module.type().

unfreeze

Unfreeze all parameters for training.

untoggle_optimizer

Resets the state of required gradients that were toggled with toggle_optimizer().

val_dataloader

An iterable or collection of iterables specifying validation samples.

validation_step

Validation step that evaluates diffusion loss on a batch.

xpu

Move all model parameters and buffers to the XPU.

zero_grad

Reset gradients of all model parameters.

Attributes

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

T_destination

automatic_optimization

If set to False, you are responsible for calling .backward(), .step(), and .zero_grad().

call_super_init

current_epoch

The current epoch in the Trainer, or 0 if not attached.

device

device_mesh

Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.

dtype

dump_patches

example_input_array

The example input array is a specification of what the module can consume in the forward() method.

fabric

global_rank

The index of the current process across all nodes and devices.

global_step

Total training batches seen across all epochs.

hparams

The collection of hyperparameters saved with save_hyperparameters().

hparams_initial

The collection of hyperparameters saved with save_hyperparameters().

local_rank

The index of the current process within a single node.

logger

Reference to the logger object in the Trainer.

loggers

Reference to the list of loggers in the Trainer.

on_gpu

Returns True if this model is currently located on a GPU.

strict_loading

Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).

trainer

training

__init__(model, set_lr, config_scheduler, noise_schedule='karras', min_snr_loss_weight=False, min_snr_gamma=5)
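
The default noise_schedule='karras' suggests the noise-level spacing of Karras et al. (2022). The actual schedule used by ContinuousDiffusion is not documented here, so the following is only a sketch of how such a schedule can be written for a continuous time t in [0, 1]: interpolate linearly in sigma**(1/rho) space and map back. SIGMA_MIN = 0.002, SIGMA_MAX = 80 and RHO = 7 are the paper's defaults, not values read from this class, and karras_sigma and karras_log_snr are hypothetical helper names. The log-SNR conversion assumes SNR = 1 / sigma**2.

```python
import math

# Defaults from Karras et al. (2022); assumptions, not read from ContinuousDiffusion.
SIGMA_MIN = 0.002
SIGMA_MAX = 80.0
RHO = 7.0


def karras_sigma(t: float, sigma_min: float = SIGMA_MIN,
                 sigma_max: float = SIGMA_MAX, rho: float = RHO) -> float:
    """Noise level at continuous time t in [0, 1]; t = 0 is the least noisy."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    lo = sigma_min ** (1.0 / rho)
    hi = sigma_max ** (1.0 / rho)
    # Linear interpolation in sigma**(1/rho) space spends more of the time
    # axis on small noise levels than a linear interpolation in sigma would.
    return (lo + t * (hi - lo)) ** rho


def karras_log_snr(t: float) -> float:
    """Log-SNR at time t, assuming x_t = x_0 + sigma(t) * eps, i.e. SNR = 1 / sigma**2."""
    return -2.0 * math.log(karras_sigma(t))


if __name__ == "__main__":
    for t in (0.0, 0.5, 1.0):
        print(f"t={t:.1f}  sigma={karras_sigma(t):.4g}  log_snr={karras_log_snr(t):.3f}")
```

Expressing the schedule as a log-SNR lets one time variable drive the noising step regardless of how the noised sample is parameterized.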

property device

sequence2labels(sequences: List) -> Tensor

Converts sequences to labels based on user-defined models.

Parameters:

sequences (List) – A list of sequences to convert to labels

Returns:

Returns the labels for the decoder

Return type:

torch.Tensor

Raises:

ValueError – If the labels are not one of the three options
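
The three label options are not documented on this page, so the sketch below illustrates only one common choice, integer tokenization with right-padding. It assumes, purely for illustration, protein sequences over the 20 standard amino acids; AMINO_ACIDS, TOKEN_IDS, PAD_ID and encode_sequences are hypothetical names. The real method returns a torch.Tensor, whereas this sketch returns nested lists so that it runs without PyTorch.

```python
from typing import Dict, List

# Hypothetical vocabulary: 20 standard amino acids, id 0 reserved for padding.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
TOKEN_IDS: Dict[str, int] = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}
PAD_ID = 0


def encode_sequences(sequences: List[str], max_len: int) -> List[List[int]]:
    """Map each sequence to a fixed-length list of integer token ids."""
    encoded = []
    for seq in sequences:
        if len(seq) > max_len:
            raise ValueError(f"sequence longer than max_len={max_len}: {seq!r}")
        ids = []
        for residue in seq.upper():
            if residue not in TOKEN_IDS:
                raise ValueError(f"unsupported residue {residue!r}")
            ids.append(TOKEN_IDS[residue])
        # Right-pad so that every row has the same length (a rectangular batch).
        encoded.append(ids + [PAD_ID] * (max_len - len(ids)))
    return encoded


if __name__ == "__main__":
    print(encode_sequences(["AC", "Y"], max_len=3))
```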

q_sample(x_start, times, masks=None, noise=None)
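
q_sample has no description on this page. Assuming it follows the usual continuous-time Gaussian diffusion convention, it draws a noised sample x_t = alpha(t) * x_start + sigma(t) * noise, where alpha and sigma are derived from the log signal-to-noise ratio of the noise schedule at time t. The sketch below uses the variance-preserving parameterization alpha**2 = sigmoid(log_snr), sigma**2 = sigmoid(-log_snr); it works on one example stored as a flat list, takes log_snr directly instead of times, and leaves out masks, whose semantics are not documented. log_snr_to_alpha_sigma is a hypothetical helper.

```python
import math
import random
from typing import List, Optional, Tuple


def _sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def log_snr_to_alpha_sigma(log_snr: float) -> Tuple[float, float]:
    """Scales with alpha**2 + sigma**2 == 1 and alpha**2 / sigma**2 == exp(log_snr)."""
    return math.sqrt(_sigmoid(log_snr)), math.sqrt(_sigmoid(-log_snr))


def q_sample(x_start: List[float], log_snr: float,
             noise: Optional[List[float]] = None) -> Tuple[List[float], List[float]]:
    """Noise one example: x_t = alpha * x_start + sigma * noise, with noise ~ N(0, I)."""
    if noise is None:
        noise = [random.gauss(0.0, 1.0) for _ in x_start]
    if len(noise) != len(x_start):
        raise ValueError("noise must have the same shape as x_start")
    alpha, sigma = log_snr_to_alpha_sigma(log_snr)
    x_noised = [alpha * x + sigma * n for x, n in zip(x_start, noise)]
    # The noise is returned as well: it is the regression target for the model.
    return x_noised, noise
```

High log-SNR values (early times) keep x_t close to x_start; low values make it almost pure noise.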

random_times(batch_size)

p_losses(x_start: Tensor, t: Tensor, labels: Tensor = None, noise: Tensor = None, masks: Tensor = None) -> Tensor

Calculate model loss based on predicted vs actual noise.

Parameters:
  • x_start (torch.Tensor) – The clean input tensor, which is noised and then denoised

  • t (torch.Tensor) – Timesteps along the denoising-diffusion process

  • labels (torch.Tensor, optional) – Condition labels for the model

  • noise (torch.Tensor, optional) – Optional pre-defined noise, otherwise sampled from N(0,I)

  • masks (torch.Tensor, optional) – Optional masks for conditional generation

Returns:

Mean squared error between the predicted and the actual noise

Return type:

torch.Tensor
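
When min_snr_loss_weight=True, the constructor's min_snr_gamma presumably enables Min-SNR-γ loss weighting (Hang et al., 2023). For noise prediction, this scales each example's MSE by min(SNR(t), γ) / SNR(t), so that nearly clean, high-SNR timesteps do not dominate training. The sketch below is not the library's implementation: it assumes the variance-preserving noising step alpha**2 = sigmoid(log_snr), sigma**2 = sigmoid(-log_snr), uses a hypothetical predict_noise callable in place of the wrapped model, takes log-SNR values instead of times, and omits labels and masks.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]


def min_snr_weight(log_snr: float, gamma: Optional[float]) -> float:
    """Min-SNR-gamma weight for noise prediction: min(SNR, gamma) / SNR; 1.0 when disabled."""
    if gamma is None:
        return 1.0
    snr = math.exp(log_snr)
    return min(snr, gamma) / snr


def p_losses(x_start: List[Vector], log_snrs: List[float], noise: List[Vector],
             predict_noise: Callable[[Vector, float], Vector],
             min_snr_gamma: Optional[float] = None) -> float:
    """Batch mean of the (optionally min-SNR weighted) per-example noise MSE."""
    if not x_start:
        raise ValueError("empty batch")
    if not len(x_start) == len(log_snrs) == len(noise):
        raise ValueError("x_start, log_snrs and noise must have the same batch size")
    losses = []
    for x0, log_snr, eps in zip(x_start, log_snrs, noise):
        alpha = math.sqrt(1.0 / (1.0 + math.exp(-log_snr)))  # sqrt(sigmoid(log_snr))
        sigma = math.sqrt(1.0 / (1.0 + math.exp(log_snr)))   # sqrt(sigmoid(-log_snr))
        x_noised = [alpha * x + sigma * e for x, e in zip(x0, eps)]
        pred = predict_noise(x_noised, log_snr)
        mse = sum((p - e) ** 2 for p, e in zip(pred, eps)) / len(eps)
        losses.append(min_snr_weight(log_snr, min_snr_gamma) * mse)
    return sum(losses) / len(losses)
```

Because the weight is min(SNR, γ) / SNR, examples with SNR ≤ γ keep weight 1 and only the cleaner timesteps are down-weighted.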

forward(x: Tensor, labels: Tensor, masks: Tensor = None) -> Tensor

Forward pass that samples random timesteps and calculates loss.

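Parameters:
  • x (torch.Tensor) – Input data tensor

  • labels (torch.Tensor) – Condition labels for the model

  • masks (torch.Tensor, optional) – Optional masks for conditional generation

Returns: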

Loss value

Return type:

torch.Tensor
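
The documented behaviour of forward (sample random timesteps, then calculate the loss) can be sketched as follows, assuming, as the method names suggest, that it draws one time per example with random_times and passes the times to a loss such as p_losses. The loss_fn and rng parameters are hypothetical additions that make the sketch runnable on its own, and masks is omitted.

```python
import random
from typing import Callable, List, Optional, Sequence


def random_times(batch_size: int, rng: Optional[random.Random] = None) -> List[float]:
    """Draw one continuous time per example, uniformly from [0, 1)."""
    rng = rng if rng is not None else random.Random()
    return [rng.random() for _ in range(batch_size)]


def forward(x: Sequence, labels: Sequence,
            loss_fn: Callable[[Sequence, List[float], Sequence], float],
            rng: Optional[random.Random] = None) -> float:
    """Sample a timestep for every example in the batch, then compute the loss."""
    if len(x) != len(labels):
        raise ValueError("x and labels must have the same batch size")
    times = random_times(len(x), rng)
    return loss_fn(x, times, labels)


if __name__ == "__main__":
    # Toy loss that just averages the sampled times, to show the data flow.
    toy_loss = lambda x, times, labels: sum(times) / len(times)
    print(forward([[0.1], [0.2], [0.3]], ["a", "b", "c"], toy_loss, random.Random(0)))
```

Sampling a fresh time per example, rather than one per batch, exposes each batch to a spread of noise levels.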

training_step(batch: Tensor, batch_idx: int) -> Tensor

Training step that encodes inputs and calculates diffusion loss.

Parameters:
  • batch (torch.Tensor) – Batch containing data and sequence labels

  • batch_idx (int) – Index of the current batch

Returns:

Training loss

Return type:

torch.Tensor

validation_step(batch: Tensor, batch_idx: int) -> Tensor

Validation step that evaluates diffusion loss on a batch.

configure_optimizers()

Configure the optimizer and the learning rate scheduler for the model. This method uses NVIDIA's suggested settings for the learning rate and weight decay. For ResNet-50, NVIDIA reports the best performance with CosineAnnealingLR and an initial learning rate of 0.256 for a batch size of 256, scaled linearly for other batch sizes. Weight decay is set to 1/32768 (about 3.05e-5) for all parameters except those of the batch-normalization layers. For further information, see https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch

Returns:

Returns the optimizer and the learning rate scheduler

Return type:

List

Raises:

ValueError – If the scheduler is not implemented
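
The settings described for configure_optimizers can be sketched without PyTorch as three small helpers: the linear learning-rate scaling rule, the closed-form schedule given in the CosineAnnealingLR documentation (without restarts), and parameter groups that exempt batch-norm parameters from weight decay. The constants come from the description above; the helper names and the name-based batch-norm test are hypothetical. A real implementation would pass parameter tensors, not names, to a torch.optim optimizer, which accepts groups in this {"params": ..., "weight_decay": ...} form.

```python
import math
from typing import Callable, Dict, Iterable, List

BASE_LR = 0.256           # NVIDIA's ResNet-50 learning rate for a batch size of 256
BASE_BATCH_SIZE = 256
WEIGHT_DECAY = 1 / 32768  # == 3.0517578125e-05


def scaled_lr(batch_size: int) -> float:
    """Linear scaling rule: the learning rate is proportional to the batch size."""
    return BASE_LR * batch_size / BASE_BATCH_SIZE


def cosine_annealing_lr(step: int, total_steps: int, base_lr: float,
                        eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing from base_lr down to eta_min over total_steps."""
    progress = min(step, total_steps) / total_steps
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress)) / 2.0


def _is_batch_norm(name: str) -> bool:
    # Hypothetical naming convention: batch-norm modules are named "bn", "bn1", ...
    return any(part.startswith("bn") for part in name.split("."))


def split_weight_decay(param_names: Iterable[str],
                       is_batch_norm: Callable[[str], bool] = _is_batch_norm) -> List[Dict]:
    """Build parameter groups that exempt batch-norm parameters from weight decay."""
    decay, no_decay = [], []
    for name in param_names:
        (no_decay if is_batch_norm(name) else decay).append(name)
    return [
        {"params": decay, "weight_decay": WEIGHT_DECAY},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

For example, a batch size of 512 gives scaled_lr(512) = 0.512, which the cosine schedule then anneals towards eta_min.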