Source code for starling.inference.generation

import gc
import os
import time
from datetime import datetime

import numpy as np
import torch
from soursop.sstrajectory import SSTrajectory
from tqdm.auto import tqdm

from starling import configs
from starling.data.tokenizer import StarlingTokenizer
from starling.inference.model_loading import ModelManager
from starling.samplers.ddim_sampler import DDIMSampler
from starling.samplers.ddpm_sampler import DDPMSampler
from starling.samplers.plms_sampler import PLMSSampler
from starling.structure.coordinates import (
    create_ca_topology_from_coords,
    generate_3d_coordinates_from_distances,
)
from starling.structure.ensemble import Ensemble

# Module-level ModelManager instance. It is created once, when this module
# is first imported (e.g. by ensemble_generation), and is used as the default
# ``model_manager`` argument of the backend functions in this module so that
# repeated calls share the same manager.
model_manager = ModelManager()


def symmetrize_distance_map(dist_map):
    """
    Build a symmetric distance map from the upper triangle of ``dist_map``.

    Every entry strictly above the diagonal is kept and mirrored onto the
    corresponding entry below the diagonal; the diagonal itself is set to
    zero. The input tensor is not modified.

    Parameters
    ----------
    dist_map : torch.Tensor
        A 2D distance map. A 3D tensor is also accepted, in which case
        ``squeeze(0)`` is applied first (i.e. a leading axis of size 1 is
        dropped).

    Returns
    -------
    torch.Tensor
        The symmetrized distance map, moved to the CPU.
    """
    # Drop a leading singleton axis so we work on a plain 2D map.
    if dist_map.dim() == 3:
        dist_map = dist_map.squeeze(0)

    # True strictly above the diagonal, False on and below it.
    strictly_upper = torch.ones_like(dist_map).triu(diagonal=1).bool()

    # Keep the upper-triangle values where they are and take the transposed
    # (i.e. mirrored upper-triangle) values everywhere else.
    mirrored = torch.where(strictly_upper, dist_map, dist_map.T)

    # A residue is at distance zero from itself.
    mirrored.fill_diagonal_(0)

    return mirrored.cpu()
def sequence_encoder_backend(
    sequence_dict,
    device,
    batch_size,
    ionic_strength,
    aggregate=True,
    output_directory=None,
    model_manager=model_manager,
    encoder_path=None,
    ddpm_path=None,
    pretokenized: bool = False,
    bucket: bool = False,
    bucket_size: int = 32,
    free_cuda_cache: bool = False,
    return_on_cpu: bool = True,
):
    """
    Generate embeddings for sequences and optionally save them to disk.

    Parameters
    ----------
    sequence_dict : dict
        Dictionary of sequence names to sequences
    device : str or torch.device
        Device to use for computation
    batch_size : int
        Batch size for processing. Must be >= 1.
    ionic_strength : float
        Ionic strength [mM] to condition the model
    aggregate : bool, default True
        If True, each per-sequence embedding is averaged over its first
        axis (the token axis of the un-padded slice) before being
        returned/saved. If False, the un-padded per-token embedding is
        kept.
    output_directory : str, optional
        If provided, embeddings will be saved to this directory with
        sequence name as filename
    model_manager : ModelManager
        Model manager instance
    encoder_path : str, optional
        Custom encoder path
    ddpm_path : str, optional
        Custom diffusion model path
    pretokenized : bool, default False
        If True, values of sequence_dict are assumed to already be iterable
        collections of integer token ids (lists/tuples/torch tensors).
        Skips tokenization.
    bucket : bool, default False
        If True, sequences are grouped into coarse length buckets (multiple
        of bucket_size) before being ordered for batching.
    bucket_size : int, default 32
        Length resolution for bucketing when bucket=True. Sequences with
        lengths that fall into the same bucket ( (L//bucket_size) ) are
        grouped together.
    free_cuda_cache : bool, default False
        If True and running on CUDA, calls torch.cuda.empty_cache() after
        each batch.
    return_on_cpu : bool, default True
        If True, embeddings are transferred to CPU before being returned or
        saved. If False, embeddings remain on the original device (e.g.,
        GPU), which can be useful when performing downstream tensor
        operations on the same device.

    Returns
    -------
    dict or None
        If output_directory is None, returns dictionary name -> tensor:
        the un-padded slice ``batch_embeddings[i, :L_i]`` when
        aggregate=False, or its mean over axis 0 when aggregate=True.
        Otherwise returns None (embeddings written to disk as <name>.pt).

    Raises
    ------
    ValueError
        If batch_size is smaller than 1.
    """
    # Fail early (before any model is loaded) on an unusable batch size.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer. Got {batch_size}.")

    tokenizer = None if pretokenized else StarlingTokenizer()

    _, diffusion = model_manager.get_models(
        device=device, encoder_path=encoder_path, ddpm_path=ddpm_path
    )

    ionic_strength = torch.tensor([ionic_strength], device=device).unsqueeze(0)

    # Prepare output handling
    if output_directory is not None:
        os.makedirs(output_directory, exist_ok=True)
        print(f"Saving embeddings to: {os.path.abspath(output_directory)}")
        embedding_dict = None
    else:
        embedding_dict = {}

    # Normalize sequences -> (name, token_list)
    prepared = []
    for name, seq in sequence_dict.items():
        if pretokenized:
            # Accept list/tuple/torch.Tensor of ints
            tokens = seq.tolist() if isinstance(seq, torch.Tensor) else list(seq)
        else:
            tokens = tokenizer.encode(seq)
        prepared.append((name, tokens))

    if bucket:
        # NOTE(review): buckets are visited longest-first and each bucket is
        # sorted longest-first, which yields the same ordering as the plain
        # descending-length sort in the else-branch; batches below are also
        # cut purely by batch_size and are not aligned to bucket boundaries.
        # Confirm whether bucket-aligned batching was intended.
        bucketed = {}
        for name, toks in prepared:
            key = (len(toks) // bucket_size) * bucket_size
            bucketed.setdefault(key, []).append((name, toks))
        ordered = []
        for key in sorted(bucketed.keys(), reverse=True):
            ordered.extend(
                sorted(bucketed[key], key=lambda x: len(x[1]), reverse=True)
            )
        prepared = ordered
    else:
        # Just sort by length descending
        prepared.sort(key=lambda x: len(x[1]), reverse=True)

    names = [n for n, _ in prepared]
    seqs = [t for _, t in prepared]
    total = len(seqs)
    lengths = [len(t) for t in seqs]

    # str() so that both "cuda:0" and torch.device("cuda:0") are recognised.
    empty_cache_each_batch = (
        free_cuda_cache
        and torch.cuda.is_available()
        and str(device).startswith("cuda")
    )

    _inference_ctx = (
        torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad
    )

    with _inference_ctx():
        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            batch_sequences = seqs[start:end]
            batch_names = names[start:end]
            batch_lengths = lengths[start:end]
            current_bs = end - start
            # Sequences are ordered longest-first, so the first entry of a
            # batch is its longest one.
            max_length = batch_lengths[0]

            # Assemble the zero-padded batch on the CPU and move it to the
            # target device in a single transfer per tensor.
            sequence_tensor = torch.zeros((current_bs, max_length), dtype=torch.long)
            attention_mask = torch.zeros((current_bs, max_length), dtype=torch.bool)
            for i, (seq_tokens, length_i) in enumerate(
                zip(batch_sequences, batch_lengths)
            ):
                sequence_tensor[i, :length_i] = torch.as_tensor(
                    seq_tokens, dtype=torch.long
                )
                attention_mask[i, :length_i] = True
            sequence_tensor = sequence_tensor.to(device)
            attention_mask = attention_mask.to(device)

            batch_embeddings = diffusion.sequence2labels(
                sequences=sequence_tensor,
                sequence_mask=attention_mask,
                ionic_strength=ionic_strength,
            )

            # Transfer to CPU if requested
            if return_on_cpu:
                batch_embeddings = batch_embeddings.cpu()

            # Process results
            for i, (name, length_i) in enumerate(zip(batch_names, batch_lengths)):
                emb = batch_embeddings[i, :length_i]
                if aggregate:
                    emb = emb.mean(dim=0)
                else:
                    # The slice is a view into the padded batch tensor.
                    # Clone it so the stored/saved embedding owns only its
                    # own data: a view would keep the whole batch alive and
                    # torch.save would serialize the full batch storage.
                    emb = emb.clone()

                if output_directory is not None:
                    torch.save(emb, os.path.join(output_directory, f"{name}.pt"))
                else:
                    embedding_dict[name] = emb

            del batch_embeddings, sequence_tensor, attention_mask
            if empty_cache_each_batch:
                torch.cuda.empty_cache()

    return embedding_dict
def ensemble_encoder_backend(
    ensemble,
    device,
    batch_size,
    output_directory=None,
    model_manager=model_manager,
    encoder_path=None,
    ddpm_path=None,
):
    """
    Encode an ensemble of 2D maps into latent space with the encoder model.

    The input is zero-padded (bottom/right) up to 384 x 384 where needed,
    given a channel axis, pushed through ``encoder_model.encode(...).mode()``
    in batches, and the per-batch results are concatenated along the first
    axis.

    Parameters
    ----------
    ensemble : np.ndarray
        3D array (batch, height, width).
    device : str
        Device the batches are moved to before encoding.
    batch_size : int
        Number of ensemble members encoded per call. Must be >= 1.
    output_directory : str, optional
        Currently unused by this function.
    model_manager : ModelManager
        Model manager instance used to obtain the encoder.
    encoder_path : str, optional
        Custom encoder path, forwarded to ``model_manager.get_models``.
    ddpm_path : str, optional
        Custom diffusion model path, forwarded to
        ``model_manager.get_models``.

    Returns
    -------
    np.ndarray
        Concatenated latent representations, one entry along the first axis
        per ensemble member. Size-1 axes other than the first are removed.

    Raises
    ------
    AssertionError
        If ``ensemble`` is not a 3D numpy array.
    ValueError
        If batch_size is smaller than 1.
    """
    # Validate inputs before paying for model loading.
    assert isinstance(ensemble, np.ndarray), (
        "ensemble must be a numpy array. If you have a torch tensor, convert it to numpy first."
    )
    assert ensemble.ndim == 3, "ensemble must be a 3D array (batch, height, width)."
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer. Got {batch_size}.")

    encoder_model, _ = model_manager.get_models(
        device=device, encoder_path=encoder_path, ddpm_path=ddpm_path
    )

    # Side length the maps are padded to before encoding.
    target_size = 384
    h_pad = max(0, target_size - ensemble.shape[1])  # vertical padding (bottom)
    w_pad = max(0, target_size - ensemble.shape[2])  # horizontal padding (right)
    # Pad whenever EITHER spatial axis is short (checking only the width
    # would leave a too-short height unpadded).
    if h_pad or w_pad:
        ensemble = np.pad(
            ensemble,
            pad_width=((0, 0), (0, h_pad), (0, w_pad)),
            mode="constant",
            constant_values=0,
        )

    # Add a channel dimension: (N, H, W) -> (N, 1, H, W)
    ensemble = torch.from_numpy(ensemble).unsqueeze(1)

    latent_spaces = []
    # No gradients are needed for encoding; avoid building the autograd graph.
    with torch.no_grad():
        # A single loop covers the full batches and the final partial batch.
        for start_idx in range(0, ensemble.shape[0], batch_size):
            batch_ensemble = ensemble[start_idx : start_idx + batch_size]
            latent_space = encoder_model.encode(batch_ensemble.to(device)).mode()
            latent_space = latent_space.detach()

            # Drop singleton axes but always keep the leading axis. A bare
            # squeeze() would also remove that axis for a one-sample batch
            # and break the concatenation below.
            # (assumes the leading axis of the encoder output is the batch
            # axis — TODO confirm against the encoder model)
            kept_dims = [size for size in latent_space.shape[1:] if size != 1]
            latent_space = latent_space.reshape(latent_space.shape[0], *kept_dims)

            latent_spaces.append(latent_space.cpu().numpy())

    return np.concatenate(latent_spaces)
def generate_backend(
    sequence_dict,
    conformations,
    device,
    steps,
    sampler,
    return_structures,
    batch_size,
    num_cpus_mds,
    num_mds_init,
    output_directory,
    return_data,
    verbose,
    show_progress_bar,
    show_per_step_progress_bar,
    pdb_trajectory,
    model_manager=model_manager,
    ionic_strength=150,
    constraint=None,
    encoder_path=None,
    ddpm_path=None,
):
    """
    Backend function for generating the distance maps using STARLING.

    NOTE - this function does VERY little sanity checking; to actually
    perform predictions use starling.frontend.ensemble_generation. This
    is NOT a user facing function!

    Parameters
    ---------------
    sequence_dict : dict
        A dictionary with the sequence names as the key and the
        sequences as the values. These names will be used to write
        any output files (if writing is requested).

    conformations : int
        The number of conformations (distance maps) to generate per
        sequence. They are sampled in chunks of ``batch_size`` plus one
        final smaller chunk if ``conformations`` is not a multiple of
        ``batch_size``.

    device : str
        The device to use for predictions.

    steps : int
        The number of sampling steps. Passed to the PLMS and DDIM
        samplers; the DDPM sampler is constructed without it.

    sampler : str
        Which sampler to use: one of 'plms', 'ddim' or 'ddpm'
        (case-insensitive).

    return_structures : bool
        Whether to reconstruct 3D structures from the distance maps.

    batch_size : int
        The batch size to use for sampling.

    num_cpus_mds : int
        The number of CPUs to use for MDS. There is no point specifying
        more than the default number of MDS runs performed (defined in
        configs)

    num_mds_init : int
        Forwarded to generate_3d_coordinates_from_distances
        (presumably the number of MDS initializations - confirm there).

    output_directory : str or None
        If None, no output is saved. If not None, the ensemble for each
        sequence is saved via ``Ensemble.save`` under
        <output_directory>/<sequence_name>, and if
        return_structures=True the trajectory is saved via
        ``Ensemble.save_trajectory`` with the filename prefix
        <output_directory>/<sequence_name>_STARLING. The exact file
        names/extensions are decided by those methods.

    return_data : bool
        If True, the returned dictionary holds an Ensemble per sequence
        regardless of output_directory. If False, ensembles are
        discarded after (optionally) being saved and the returned
        dictionary is empty; this saves memory when predicting a large
        set of sequences.

    verbose : bool
        Whether to print verbose output.

    show_progress_bar : bool
        Whether to show a progress bar over sequences.

    show_per_step_progress_bar : bool
        Whether to show a progress bar per sampling step (forwarded to
        the sampler).

    pdb_trajectory : bool
        Whether to save the trajectory as a PDB file (forwarded to
        ``Ensemble.save_trajectory``).

    model_manager : ModelManager
        A ModelManager object to manage loaded models. This lets us
        avoid loading the model iteratively when calling generate
        multiple times in a single session. Default is the module-level
        model_manager.

    ionic_strength : float
        Ionic strength passed to the sampler. Default is 150.

    constraint : object, optional
        Forwarded unchanged to ``sampler.sample``. Default is None.

    encoder_path : str, optional
        Path to a custom encoder model checkpoint file to use instead of
        the default. Default is None.

    ddpm_path : str, optional
        Path to a custom diffusion model checkpoint file to use instead
        of the default. Default is None.

    Returns
    ---------------
    dict
        A dict with the sequence names as the key and a
        starling Ensemble object for each sequence as values. The dict
        is empty when return_data is False.

    Raises
    ---------------
    ValueError
        If ``sampler`` is not one of 'plms', 'ddim' or 'ddpm'.
    """
    overall_start_time = time.time()

    # get models. This will only load once even if we call this
    # function multiple times.
    encoder_model, diffusion = model_manager.get_models(
        device=device, encoder_path=encoder_path, ddpm_path=ddpm_path
    )

    # Construct a sampler
    sampler_name = sampler.lower()
    if sampler_name == "plms":
        print("Using PLMS sampler")
        sampler_obj = PLMSSampler(
            ddpm_model=diffusion,
            encoder_model=encoder_model,
            n_steps=steps,
            ionic_strength=ionic_strength,
        )
    elif sampler_name == "ddim":
        print("Using DDIM sampler")
        sampler_obj = DDIMSampler(
            ddpm_model=diffusion,
            encoder_model=encoder_model,
            n_steps=steps,
            ionic_strength=ionic_strength,
        )
    elif sampler_name == "ddpm":
        print("Using DDPM sampler")
        sampler_obj = DDPMSampler(
            ddpm_model=diffusion,
            encoder_model=encoder_model,
            ionic_strength=ionic_strength,
        )
    else:
        raise ValueError(
            f"Error: sampler must be one of 'plms', 'ddim', or 'ddpm'. Got {sampler}."
        )

    # Sizes of the sampling calls made per sequence: as many full batches
    # as fit, followed by one smaller batch for any remainder.
    num_batches = conformations // batch_size
    remaining_samples = conformations % batch_size
    batch_sizes = [batch_size] * num_batches
    if remaining_samples > 0:
        batch_sizes.append(remaining_samples)
    real_batch_count = len(batch_sizes)

    # dictionary to hold the Ensemble objects, if they are to be returned.
    output_dict = {}

    # see if a progress bar is wanted. If it is, set it up.
    # position here is 0, so it will be the first progress bar
    pbar = None
    if show_progress_bar:
        pbar = tqdm(
            total=len(sequence_dict),
            position=0,
            desc="Progress through sequences",
            leave=True,
        )

    try:
        # iterate over sequence_dict
        for num, (seq_name, sequence) in enumerate(sequence_dict.items()):
            ## -----------------------------------------
            ## Start of prediction cycle for this sequence
            start_time_prediction = time.time()

            seq_len = len(sequence)

            # one stacked tensor of symmetrized distance maps per batch
            starling_dm = []
            for batch_index, n_samples in enumerate(batch_sizes, start=1):
                distance_maps = sampler_obj.sample(
                    n_samples,
                    labels=sequence,
                    show_per_step_progress_bar=show_per_step_progress_bar,
                    batch_count=batch_index,
                    max_batch_count=real_batch_count,
                    constraint=constraint,
                )
                # crop each map to the sequence length, then symmetrize
                starling_dm.append(
                    torch.stack(
                        [
                            symmetrize_distance_map(dm[:, :seq_len, :seq_len])
                            for dm in distance_maps
                        ]
                    )
                )

            # concatenate symmetrized distance maps.
            sym_distance_maps = torch.cat(starling_dm, dim=0)

            end_time_prediction = time.time()

            # Structure generation timing. The end time starts out equal to
            # the start time (zero elapsed) and is only moved forward if
            # structures are actually generated.
            start_time_structure_generation = time.time()
            end_time_structure_generation = start_time_structure_generation

            # do ensemble reconstruction if requested
            if return_structures:
                coordinates = generate_3d_coordinates_from_distances(
                    device,
                    batch_size,
                    num_cpus_mds,
                    num_mds_init,
                    sym_distance_maps,
                    progress_bar=show_progress_bar,
                )

                # make traj as an sstrajectory object and extract out the
                # ssprotein object
                ssprotein = SSTrajectory(
                    TRJ=create_ca_topology_from_coords(sequence, coordinates)
                ).proteinTrajectoryList[0]

                end_time_structure_generation = time.time()
            else:
                # if no structures are requested, set ssprotein to None
                ssprotein = None

            # pull the distance maps out of the tensor and convert to numpy
            final_distance_maps = sym_distance_maps.detach().cpu().numpy()

            # create Ensemble object. Note if the ssprotein argument is None
            # this is expected and will initialize the ensemble without
            # structures
            E = Ensemble(final_distance_maps, sequence, ssprot_ensemble=ssprotein)

            # if we are saving things, save as we progress through so we
            # generate structures/DMs in situ
            if output_directory is not None:
                # num == 0 just means we are on the first sequence.
                if verbose and num == 0:
                    print(f"Saving results to: {os.path.abspath(output_directory)}")

                # if we're saving structures do that first;
                if return_structures:
                    E.save_trajectory(
                        filename_prefix=os.path.join(
                            output_directory, seq_name + "_STARLING"
                        ),
                        pdb_trajectory=pdb_trajectory,
                    )

                # save full ensemble
                E.save(os.path.join(output_directory, f"{seq_name}"))

            ## End of prediction cycle for this sequence
            ## -----------------------------------------

            # if we are returning data, add the data to the output_dict
            if return_data:
                output_dict[seq_name] = E
            # if not, force cleanup of things to save memory
            else:
                del E
                del final_distance_maps
                gc.collect()

            # update progress bar if we have one.
            if pbar is not None:
                pbar.update(1)

            if verbose:
                elapsed_time_structure_generation = (
                    end_time_structure_generation - start_time_structure_generation
                )
                elapsed_time_prediction = end_time_prediction - start_time_prediction
                total_time = elapsed_time_structure_generation + elapsed_time_prediction
                n_conformers = len(sym_distance_maps)

                # Guard the percentage computation against a zero total
                # (possible with a coarse clock).
                if total_time > 0:
                    pct_prediction = round(
                        100 * (elapsed_time_prediction / total_time), 2
                    )
                    pct_structure = round(
                        100 * (elapsed_time_structure_generation / total_time), 2
                    )
                else:
                    pct_prediction = 0.0
                    pct_structure = 0.0

                print(
                    f"\n\n##### SUMMARY OF SEQUENCE PREDICTION ({num + 1}/{len(sequence_dict)}) #####"
                )
                print(f"Sequence name : {seq_name}")
                print(f"Sequence length : {len(sequence)}")
                print(f"Number of conformers : {n_conformers}")
                print(f"Number of steps : {steps}")
                print(
                    f"Total time for prediction : {round(elapsed_time_prediction, 2)}s ({pct_prediction}% of time)"
                )
                print(
                    f"Total time for structure generation : {round(elapsed_time_structure_generation, 2)}s ({pct_structure}% of time)"
                )
                print(f"Time per conformer : {total_time / n_conformers}s")
                print("\n")
    finally:
        # make sure we close the progress bar if we used one, even if a
        # prediction raised part-way through
        if pbar is not None:
            pbar.close()

    if verbose:
        # Convert total time to hours, minutes, and seconds
        overall_time = time.time() - overall_start_time
        total_hours = round(overall_time // 3600, 2)
        total_minutes = round((overall_time % 3600) // 60, 2)
        total_seconds = round(overall_time % 60, 2)

        print("-------------------------------------------------------")
        print(f"Summary of all predictions ({len(sequence_dict)} sequences)")
        print("-------------------------------------------------------")
        print(
            f"\nTotal time (all sequences, all I/O) : {total_hours} hrs {total_minutes} mins {total_seconds} secs"
        )
        current_datetime = datetime.now()
        formatted_datetime = current_datetime.strftime("%Y-%m-%d %H:%M:%S")
        print("STARLING predictions completed at:", formatted_datetime)
        print("")

    return output_dict