starling.models.diffusion.extract

extract(constants: Tensor, timestamps: Tensor, shape: int) → Tensor

Extract the values of constants at the indices given by timestamps.

Parameters:
  • constants (torch.Tensor) – The tensor to extract values from.

  • timestamps (torch.Tensor) – A 1D tensor containing the indices for extraction.

  • shape (int) – The desired shape of the output tensor.

Returns:

A tensor containing the extracted values.

Return type:

torch.Tensor
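The following is a minimal sketch of the gather-and-reshape pattern that helpers named extract commonly use in diffusion codebases. It is a hypothetical re-implementation, not starling's source, and the name extract_sketch is illustrative only. It assumes that constants is a 1D per-step schedule (for example, betas), that timestamps is an int64 tensor holding one step index per batch element, and that shape is the full shape of the tensor the result will be combined with (for example, x.shape). Note that this last assumption differs from the int annotation in the signature above.

```python
import torch


def extract_sketch(constants: torch.Tensor, timestamps: torch.Tensor, shape) -> torch.Tensor:
    """Hypothetical version of a common diffusion `extract` helper.

    Assumes `constants` is 1D (one value per diffusion step) and `shape` is
    the shape of the tensor the result will be broadcast against.
    """
    batch_size = timestamps.shape[0]
    # Select one schedule value per batch element: out[i] = constants[timestamps[i]].
    out = constants.gather(-1, timestamps.to(constants.device))
    # Append singleton dims so the result broadcasts over e.g. (B, C, H, W).
    return out.reshape(batch_size, *((1,) * (len(shape) - 1))).to(timestamps.device)


# Usage: a linear beta schedule with 1000 steps and a batch of 4 images.
betas = torch.linspace(1e-4, 0.02, steps=1000)
x = torch.randn(4, 3, 8, 8)
t = torch.tensor([0, 10, 500, 999])

coef = extract_sketch(betas, t, x.shape)
print(coef.shape)  # torch.Size([4, 1, 1, 1])
scaled = coef * x  # one coefficient per sample, broadcast to (4, 3, 8, 8)
```

Reshaping to (batch, 1, 1, ...) is what lets a single per-sample coefficient multiply an entire image without an explicit loop or expand call.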