# Source code for starling.configs

import importlib.util
import os

from starling.utilities import fix_ref_to_home

# Stand-alone default parameters.
# NB: any of these can be overridden by adding a configs.py file to
# ~/.starling_weights/ (see load_user_config).
DEFAULT_MODEL_DIR = os.path.expanduser(os.path.join("~/", ".starling_weights"))

# Earlier checkpoint filenames, kept for reference:
# DEFAULT_ENCODE_WEIGHTS = "model-kernel-epoch=99-epoch_val_loss=1.72.ckpt"
# DEFAULT_DDPM_WEIGHTS = "model-kernel-epoch=47-epoch_val_loss=0.03.ckpt"

# Checkpoint filenames: joined onto DEFAULT_MODEL_DIR for the local path and
# also used as the GitHub release asset names.
DEFAULT_ENCODE_WEIGHTS = "STARLING_v2.0.0_ViT_VAE_2025_10_14.ckpt"
DEFAULT_DDPM_WEIGHTS = "STARLING_v2.0.0_ViT_DDPM_2025_10_14.ckpt"
DEFAULT_NUMBER_CONFS = 400
DEFAULT_BATCH_SIZE = 100
DEFAULT_STEPS = 30
DEFAULT_MDS_NUM_INIT = 4
DEFAULT_STRUCTURE_GEN = "mds"
CONVERT_ANGSTROM_TO_NM = 10  # 1 nm = 10 Angstrom
MAX_SEQUENCE_LENGTH = 380  # set longest sequence the model can work on
DEFAULT_IONIC_STRENGTH = 150  # default ionic strength in mM
DEFAULT_SAMPLER = "ddim"  # default sampler for diffusion model

# Model compilation settings; compilation is disabled by default.
# NOTE(review): "options" looks like the keyword arguments for torch.compile —
# confirm at the call site.
TORCH_COMPILATION = dict(
    enabled=False,
    options=dict(
        mode="default",  # one of "default", "reduce-overhead", "max-autotune"
        fullgraph=True,  # compile the full graph rather than allowing graph breaks
        backend="inductor",  # default PyTorch compiler backend
        dynamic=None,  # dynamic-shape handling; None = library default
    ),
)


# Label dimension used for the UNet; presumably this must match the DDPM
# checkpoint being loaded — confirm.
# NOTE(review): earlier notes here attributed BOTH 512 and 384 to the same
# checkpoint (model-kernel-epoch=47-epoch_val_loss=0.03.ckpt), which cannot both
# be true; one entry presumably referred to a different checkpoint. Confirm which
# older checkpoint requires 384 before loading it.
UNET_LABELS_DIM = 512

# Optional per-user override file consumed by load_user_config():
# ~/.starling_weights/configs.py
USER_CONFIG_PATH = os.path.expanduser(os.path.join("~/", ".starling_weights", "configs.py"))


##
## The code block below lets us override default values with those set in a configs.py
## file in the ~/.starling_weights directory
##


def load_user_config():
    """Override this module's defaults with values from the user's configs.py.

    If ``USER_CONFIG_PATH`` is an existing file, it is executed as a standalone
    module named ``user_config``. Every UPPER_CASE name it defines that already
    exists in this module's globals replaces the current value, and each
    replacement is reported on stdout.

    Names that do not already exist here are ignored. So are lower-/mixed-case
    names, so helper imports or functions in the user's file (e.g. ``import os``)
    cannot clobber this module's own imports and functions.

    Raises:
        ImportError: if no import spec/loader can be created for the file.
        Any exception raised while executing the user's file propagates.
    """
    if not os.path.isfile(USER_CONFIG_PATH):
        return

    spec = importlib.util.spec_from_file_location("user_config", USER_CONFIG_PATH)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load Starling user config from {USER_CONFIG_PATH}")
    user_config = importlib.util.module_from_spec(spec)
    # Executes arbitrary Python from the user's home directory; treat it as
    # trusted local configuration, never as external input.
    spec.loader.exec_module(user_config)

    module_globals = globals()
    for key, value in vars(user_config).items():
        # Only existing public UPPER_CASE constants are overridable.
        if key.startswith("__") or not key.isupper() or key not in module_globals:
            continue
        old_value = module_globals[key]
        module_globals[key] = value
        print(f"[Starling Config] Overriding {key}: {old_value} -> {value}")
# Load user-defined config if available. This must run before the derived
# values below are computed so that overridden inputs (e.g. DEFAULT_MODEL_DIR)
# feed into them.
load_user_config()

### Derived default values
# Default local paths to the model weights.
# NOTE(review): both names are unconditionally reassigned below (environment
# variable or GitHub Releases URL), so these local paths never survive import —
# confirm whether that is intended.
DEFAULT_ENCODER_WEIGHTS_PATH = fix_ref_to_home(
    os.path.join(DEFAULT_MODEL_DIR, DEFAULT_ENCODE_WEIGHTS)
)
DEFAULT_DDPM_WEIGHTS_PATH = fix_ref_to_home(
    os.path.join(DEFAULT_MODEL_DIR, DEFAULT_DDPM_WEIGHTS)
)

# GitHub Releases URLs for model weights
GITHUB_ENCODER_URL = (
    f"https://github.com/idptools/starling/releases/download/v2.0.0/{DEFAULT_ENCODE_WEIGHTS}"
)
GITHUB_DDPM_URL = f"https://github.com/idptools/starling/releases/download/v2.0.0/{DEFAULT_DDPM_WEIGHTS}"

# Effective weight locations: an explicit STARLING_ENCODER_PATH /
# STARLING_DDPM_PATH environment variable wins, otherwise the GitHub Releases URL.
DEFAULT_ENCODER_WEIGHTS_PATH = os.environ.get(
    "STARLING_ENCODER_PATH", GITHUB_ENCODER_URL
)
DEFAULT_DDPM_WEIGHTS_PATH = os.environ.get("STARLING_DDPM_PATH", GITHUB_DDPM_URL)

# Number of CPUs to use for MDS: DEFAULT_MDS_NUM_INIT, capped at the machine's
# CPU count. os.cpu_count() returns None when the count cannot be determined,
# which would make min() raise TypeError, so fall back to 1 in that case.
DEFAULT_CPU_COUNT_MDS = min(DEFAULT_MDS_NUM_INIT, os.cpu_count() or 1)

# define valid amino acids (one-letter codes)
VALID_AA = "ACDEFGHIKLMNPQRSTVWY"

# define conversion dictionaries for AAs
AA_THREE_TO_ONE = {
    "ALA": "A",
    "CYS": "C",
    "ASP": "D",
    "GLU": "E",
    "PHE": "F",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LYS": "K",
    "LEU": "L",
    "MET": "M",
    "ASN": "N",
    "PRO": "P",
    "GLN": "Q",
    "ARG": "R",
    "SER": "S",
    "THR": "T",
    "VAL": "V",
    "TRP": "W",
    "TYR": "Y",
}
# Inverse mapping (one-letter -> three-letter). A comprehension avoids leaking
# a loop variable into the module namespace.
AA_ONE_TO_THREE = {one: three for three, one in AA_THREE_TO_ONE.items()}

# ---------------------------------------------------------------------------
# Search (FAISS + SQLite) default configuration & lazy fetch
# ---------------------------------------------------------------------------
# Directory for cached search artifacts (separate from model weights to allow lighter syncs)
DEFAULT_SEARCH_DIR = os.path.expanduser(os.path.join("~", ".starling_search"))

# Default artifact filenames. These are defined after load_user_config() has
# run, so a user configs.py cannot override them; use the STARLING_*
# environment variables below to point at other files.
DEFAULT_FAISS_INDEX_NAME = (
    "ensemble_search_gpu_nlist_32768_m_64_nbits_8_use_opq_True_compressed_False.faiss"
)
DEFAULT_SEQSTORE_NAME = DEFAULT_FAISS_INDEX_NAME + ".seqs.sqlite"
DEFAULT_MANIFEST_NAME = DEFAULT_FAISS_INDEX_NAME + ".manifest.json"

# Environment variable overrides (paths OR HTTP(S) URLs)
# NOTE(review): ensure_search_artifacts() passes these to _download_if_missing()
# as local *destination* paths, so a URL value would be treated as a filename
# there; presumably URL values are handled by other callers — confirm.
ENV_FAISS_INDEX_PATH = os.environ.get("STARLING_FAISS_INDEX_PATH")
ENV_SEQSTORE_PATH = os.environ.get("STARLING_SEQSTORE_PATH")
ENV_MANIFEST_PATH = os.environ.get("STARLING_FAISS_MANIFEST_PATH")

# Download sources (each overridable via environment variable)
ZENODO_FAISS_INDEX_URL = os.environ.get(
    "STARLING_ZENODO_FAISS_URL",
    "https://zenodo.org/records/17342150/files/ensemble_search_gpu_nlist_32768_m_64_nbits_8_use_opq_True_compressed_False.faiss?download=1",
)
ZENODO_SEQSTORE_URL = os.environ.get(
    "STARLING_ZENODO_SEQSTORE_URL",
    "https://zenodo.org/records/17342150/files/ensemble_search_gpu_nlist_32768_m_64_nbits_8_use_opq_True_compressed_False.faiss.seqs.sqlite?download=1",
)
ZENODO_MANIFEST_URL = os.environ.get(
    "STARLING_ZENODO_MANIFEST_URL",
    "https://zenodo.org/records/17342150/files/ensemble_search_gpu_nlist_32768_m_64_nbits_8_use_opq_True_compressed_False.faiss.manifest.json?download=1",
)

# Resolved local cache paths (before existence check)
DEFAULT_FAISS_INDEX_PATH = ENV_FAISS_INDEX_PATH or os.path.join(
    DEFAULT_SEARCH_DIR, DEFAULT_FAISS_INDEX_NAME
)
DEFAULT_SEQSTORE_DB_PATH = ENV_SEQSTORE_PATH or os.path.join(
    DEFAULT_SEARCH_DIR, DEFAULT_SEQSTORE_NAME
)
DEFAULT_FAISS_MANIFEST_PATH = ENV_MANIFEST_PATH or os.path.join(
    DEFAULT_SEARCH_DIR, DEFAULT_MANIFEST_NAME
)

# Expected MD5 digests of the published artifacts. A non-empty environment
# variable replaces each default; an optional "md5:" prefix is accepted.
FAISS_INDEX_MD5 = (
    os.environ.get("STARLING_FAISS_INDEX_MD5") or "e4a72e12b2f9cdabd8ec4f8207f3d28d"
)
SEQSTORE_MD5 = (
    os.environ.get("STARLING_SEQSTORE_MD5") or "ade24690e7962768eee1acbb4f95904c"
)
MANIFEST_MD5 = (
    os.environ.get("STARLING_FAISS_MANIFEST_MD5") or "f0057554e3303b3f2e7b4e2fd3aad70a"
)


def _md5_file(path: str) -> str:
    """Return the lowercase hex MD5 digest of the file at ``path``.

    The file is read in 1 MiB chunks, so large artifacts are not loaded into
    memory at once.
    """
    import hashlib

    # MD5 serves only as an integrity checksum here; usedforsecurity=False
    # keeps it usable on FIPS-restricted OpenSSL builds (Python >= 3.9).
    h = hashlib.md5(usedforsecurity=False)
    with open(path, "rb") as f:
        while chunk := f.read(1 << 20):
            h.update(chunk)
    return h.hexdigest()


def _normalize_expected_md5(expected: str) -> str:
    """Normalize an expected MD5 digest for comparison with ``_md5_file``.

    Strips surrounding whitespace, lower-cases, and removes an optional
    ``md5:`` prefix. Returns ``""`` (meaning "no checksum") for empty or
    ``None`` input.
    """
    if not expected:
        return ""
    digest = expected.strip().lower()
    if digest.startswith("md5:"):
        digest = digest.split(":", 1)[1].strip()
    return digest


def _download_if_missing(url: str, dest: str, expected_checksum: str = "") -> None:
    """Download ``url`` to ``dest`` unless a valid copy is already present.

    Data is written to ``dest + ".part"`` and moved onto ``dest`` with
    ``os.replace`` only after the transfer completes. An interrupted transfer
    therefore never leaves a partial file at ``dest``.

    Behavior:
      * Returns immediately if ``url`` is empty or contains ``"PLACEHOLDER"``.
      * Skips the download if ``dest`` exists and either no checksum is given
        or its MD5 matches ``expected_checksum``.
      * Resumes a leftover ``.part`` file with an HTTP ``Range`` request.
        - If the server answers 416, the partial file is deleted and the
          download restarts.
        - If the server ignores the range (non-206 reply), the file is
          rewritten from scratch.
      * If an error interrupts the transfer, the ``.part`` file is kept so the
        next call can resume.
      * On MD5 mismatch the ``.part`` file is deleted and ``dest`` is left
        untouched. If the checksum cannot be computed (``OSError``), the file
        is published anyway.

    Args:
        url: Source URL.
        dest: Final local path.
        expected_checksum: Optional MD5 hex digest (``"md5:"`` prefix allowed).

    Raises:
        urllib.error.HTTPError / urllib.error.URLError: if the request fails,
            other than the handled 416-on-resume case.
        OSError: on local filesystem errors while writing or publishing.
    """
    if not url or "PLACEHOLDER" in url:
        return

    expected_md5 = _normalize_expected_md5(expected_checksum)
    if os.path.exists(dest):
        if not expected_md5:
            return  # present and nothing to verify against
        try:
            if _md5_file(dest) == expected_md5:
                return  # present and verified
        except OSError:
            pass  # unreadable: fall through and fetch a fresh copy

    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    tmp = dest + ".part"
    resume_bytes = os.path.getsize(tmp) if os.path.exists(tmp) else 0

    from urllib import error, request

    while True:
        headers = {}
        if resume_bytes:
            headers["Range"] = f"bytes={resume_bytes}-"
        req = request.Request(url, headers=headers)
        try:
            resp = request.urlopen(req)
            break
        except error.HTTPError as e:
            # 416 = Range Not Satisfiable: the partial file cannot be resumed
            # (e.g. it is already as long as the remote file); restart cleanly.
            if resume_bytes and e.code == 416:
                try:
                    os.remove(tmp)
                except OSError:
                    pass
                resume_bytes = 0
                continue
            raise

    is_partial = getattr(resp, "status", None) == 206
    total_size = resp.getheader("Content-Length")
    if total_size is not None:
        total_size = int(total_size)
        if is_partial:
            # A 206 reply's Content-Length covers only the remaining bytes.
            total_size += resume_bytes
    if resume_bytes and not is_partial:
        # Server ignored the Range header and is sending the whole file.
        resume_bytes = 0
    mode = "ab" if resume_bytes else "wb"

    print(
        f"[Starling Search] Downloading {url} -> {dest}"
        + (" (resuming)" if resume_bytes else "")
    )
    chunk_size = 4 << 20  # 4 MiB
    from tqdm import tqdm

    progress = tqdm(
        total=total_size,
        initial=resume_bytes,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        desc=os.path.basename(dest),
    )
    try:
        with open(tmp, mode) as f:
            while chunk := resp.read(chunk_size):
                f.write(chunk)
                progress.update(len(chunk))
    finally:
        # Close the HTTP response even if closing the progress bar fails.
        try:
            progress.close()
        finally:
            resp.close()

    # Hash verify before publish
    if expected_md5:
        try:
            got = _md5_file(tmp)
        except OSError as e:
            # Best effort: if hashing itself fails, still publish the file.
            print(f"[Starling Search] Hash check failed: {e}")
        else:
            if got != expected_md5:
                print(
                    f"[Starling Search] MD5 mismatch (expected {expected_md5} got {got}); discarding"
                )
                try:
                    os.remove(tmp)
                except OSError:
                    pass
                return
    os.replace(tmp, dest)
def ensure_search_artifacts(download: bool = True) -> tuple[str, str, str]:
    """Make sure the search artifacts exist locally, fetching them if allowed.

    When ``download`` is true, the FAISS index, the SQLite sequence store and
    the manifest are each passed, in that order, to ``_download_if_missing``
    together with their configured URL and expected MD5. When ``download`` is
    false, nothing is fetched.

    Returns:
        ``(faiss_index_path, seqstore_db_path, manifest_path)``: the configured
        local paths, returned whether or not the files actually exist.
    """
    if download:
        pending = (
            (ZENODO_FAISS_INDEX_URL, DEFAULT_FAISS_INDEX_PATH, FAISS_INDEX_MD5),
            (ZENODO_SEQSTORE_URL, DEFAULT_SEQSTORE_DB_PATH, SEQSTORE_MD5),
            (ZENODO_MANIFEST_URL, DEFAULT_FAISS_MANIFEST_PATH, MANIFEST_MD5),
        )
        for source_url, target_path, checksum in pending:
            _download_if_missing(source_url, target_path, checksum)

    resolved_paths = (
        DEFAULT_FAISS_INDEX_PATH,
        DEFAULT_SEQSTORE_DB_PATH,
        DEFAULT_FAISS_MANIFEST_PATH,
    )
    return resolved_paths