#
# Core utilities for the package. This should not importa anything from within starling
# to avoid circular imports.
#
import gzip
import lzma
import os
import pickle
import platform
import warnings
import numpy as np
import torch
# code that allows access to the data directory
_ROOT = os.path.abspath(os.path.dirname(__file__))
[docs]
def get_data(path):
return os.path.join(_ROOT, "data", path)
[docs]
def fix_ref_to_home(input_path):
"""
Function to fix the path to the home directory.
Parameters
---------------
path : str
The path to fix.
Returns
---------------
str: The fixed path.
"""
if input_path.startswith("~"):
return os.path.expanduser(input_path)
return input_path
[docs]
def check_file_exists(input_path):
"""
Function to check if a file exists.
Parameters
---------------
path : str
The path to check.
Returns
---------------
bool: True if the file exists, False otherwise.
"""
return os.path.exists(input_path) and os.path.isfile(input_path)
[docs]
def remove_extension(input_path, verbose=True):
"""
Function to remove the extension from a file.
Parameters
---------------
path : str
The path to remove the extension from.
verbose : str
Define how noisy to be...
Returns
---------------
str
The path with the extension removed.
"""
new_filename = os.path.splitext(input_path)[0]
# added this in so we don't silently edit away a filename with a period
# that would be invisible...
if verbose:
if input_path != new_filename:
print(
f"Warning: removed file extension from input file name.\nWas: {input_path}\nNow: {new_filename}"
)
return new_filename
[docs]
def parse_output_path(args):
"""
Parse the output path from the command line arguments.
Parameters
----------
args : argparse.Namespace
The command line arguments.
Returns
-------
str
The output path and filename without an extension.
"""
# get the filename (+extension) if the input file, without
# any path info
input_filename = os.path.basename(args.input_file)
# if no output is specified, use the current directory;
# this is the default behavior
if args.output == ".":
outname = input_filename
# if input was provided for the output
else:
# if we were passed a path
if os.path.isdir(args.output):
outname = os.path.join(args.output, input_filename)
else:
outname = args.output
# remove the extension
outname = os.path.splitext(outname)[0]
return outname
[docs]
def get_macOS_version():
"""
Function to check the macOS version.
Returns
---------------
int
Returns -1 or the major macOS version.
"""
# get the macOS version
macos_version = platform.mac_ver()[0]
if len(macos_version) == 0:
return -1
else:
try:
return int(macos_version.split(".")[0])
except ValueError:
return -1
[docs]
def check_device(use_device, default_device="gpu"):
"""
Function to check the device was correctly set.
Parameters
---------------
use_device : str
Identifier for the device to be used for predictions.
Possible inputs: 'cpu', 'mps', 'cuda', 'cuda:int', where the int corresponds to
the index of a specific cuda-enabled GPU. If 'cuda' is specified and
cuda.is_available() returns False, this will raise an Exception
If 'mps' is specified and mps is not available, an exception will be raised.
default_device : str
The default device to use if device=None.
If device=None and default_device != 'cpu' and default_device is
not available, device_string will be returned as 'cpu'.
Default is 'gpu'. This checks first for cuda and then for mps
because STARLING is faster on both than it is on CPU, so we should
use the fastest device available.
Options are 'cpu' or 'gpu'
Returns
---------------
torch.device: A PyTorch device object representing the device to use.
"""
# Helper function to get CUDA device string (e.g., 'cuda:0', 'cuda:1')
def get_cuda_device(cuda_str):
if cuda_str == "cuda":
return torch.device("cuda")
else:
device_index = int(cuda_str.split(":")[1])
num_devices = torch.cuda.device_count()
if device_index >= num_devices:
raise ValueError(
f"{cuda_str} specified, but only {num_devices} CUDA-enabled GPUs are available. "
f"Valid device indices are from 0 to {num_devices - 1}."
)
return torch.device(cuda_str)
# If `use_device` is None, fall back to `default_device`
if use_device is None:
default_device = default_device.lower()
if default_device == "cpu":
return torch.device("cpu")
elif default_device == "gpu":
if torch.cuda.is_available():
return torch.device("cuda")
elif torch.backends.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
else:
raise ValueError("Default device can only be set to 'cpu' or 'gpu'.")
# if a device is passed as torch.device, change to string so we
# can make lowercase str for handling different device types.
if isinstance(use_device, torch.device):
use_device = str(use_device)
# Ensure `use_device` is a string
if not isinstance(use_device, str):
raise ValueError(
"Device must be type torch.device or string, valid options are: 'cpu', 'mps', 'cuda', or 'cuda:int'."
)
# make lower case to make checks easier.
use_device = use_device.lower()
# Handle specific device strings
if use_device == "cpu":
return torch.device("cpu")
elif use_device == "mps":
if torch.backends.mps.is_available():
return torch.device("mps")
else:
raise ValueError(
"MPS was specified, but MPS is not available. Make sure you're running on an Apple device with MPS support."
)
elif use_device.startswith("cuda"):
if not torch.cuda.is_available():
raise ValueError(
f"{use_device} was specified, but torch.cuda.is_available() returned False."
)
return get_cuda_device(use_device)
else:
raise ValueError(
"Device must be 'cpu', 'mps', 'cuda', or 'cuda:int' (where int is a valid GPU index)."
)
# This should never be reached
raise RuntimeError(
"Unexpected state in the check_device function. Please raise an issue on GitHub."
)
[docs]
def write_starling_ensemble(
ensemble_object,
filename,
compress=False,
reduce_precision=None,
compression_algorithm="lzma",
verbose=True,
):
"""
Function to write the STARLING ensemble to a file in the STARLING
format (.starling). This is actially just a dictionary with the
amino acid sequence, the distance maps, and the SSProtein object
if available.
Parameters
---------------
ensemble_object : starling.structure.Ensemble
The STARLING ensemble object to save to a file.
filename : str
The filename to save the ensemble to; note this should
not include a file extenison and if it does this will
be removed
compress : bool
Whether to compress the file or not. Default is False.
reduce_precision : bool
Whether to reduce the precision of the distance map to a
single decimal point and cast to float16 if possible.
Default is None. Sets to False if compression is False,
but True if compression is True.
compression_algorithm : str
The compression algorithm to use. Options are 'gzip' and 'lzma'.
`lzma` gives better compression if reduce_precision is set to True,
but actually 'gzip' is better if reduce_precision is False. 'lzma'
is also slower than 'gzip'. Default is 'lzma'.
verbose : bool
Flag to define how noisy we should be
"""
# set reduce_precision to mirror compress if not set
if reduce_precision is None:
reduce_precision = compress
# build_the save dictionary
save_dict = {
"sequence": ensemble_object.sequence,
"distance_maps": ensemble_object._Ensemble__distance_maps,
"traj": ensemble_object._Ensemble__trajectory,
"DEFAULT_ENCODER_WEIGHTS_PATH": ensemble_object._Ensemble__metadata[
"DEFAULT_ENCODER_WEIGHTS_PATH"
],
"DEFAULT_DDPM_WEIGHTS_PATH": ensemble_object._Ensemble__metadata[
"DEFAULT_DDPM_WEIGHTS_PATH"
],
"VERSION": ensemble_object._Ensemble__metadata["VERSION"],
"DATE": ensemble_object._Ensemble__metadata["DATE"],
}
# if we wish to reduce the precision of the distance map to a single decimal point
if reduce_precision:
# run here so we catch warnings if they happen
with warnings.catch_warnings(record=True) as w:
# cast to float16 array and round to 1 decimal place
tmp = np.round(ensemble_object._Ensemble__distance_maps, decimals=1).astype(
"float16"
)
# check if a RuntimeWarning was raised
for warning in w:
# IF yes, then we actually don't cast because we can't do so faithfully
if issubclass(warning.category, RuntimeWarning):
print(
"Warning: Could not reduce precision of distance maps to float16. Saving as float64."
)
tmp = np.round(ensemble_object._Ensemble__distance_maps, decimals=1)
# update the save dictionary one way or another.
save_dict["distance_maps"] = tmp
# Remove the extension if it exists
filename = remove_extension(filename, verbose=verbose)
# add starling extension
filename = filename + ".starling"
# If we want to compress the file
if compress:
if compression_algorithm == "gzip":
with gzip.open(filename + ".gzip", "wb") as file:
pickle.dump(save_dict, file)
elif compression_algorithm == "lzma":
with lzma.open(filename + ".xz", "wb") as file:
pickle.dump(save_dict, file)
else:
raise ValueError(
f"Compression algorithm {compression_algorithm} is not supported. Supported algorithms are 'gzip' and 'lzma'."
)
else:
with open(filename, "wb") as file:
pickle.dump(save_dict, file)
[docs]
def read_starling_ensemble(filename):
"""
Function to read a STARLING ensemble from a file in the STARLING
format (.starling) or a compressed starling file (.gzip, .xz).
The .starling file is actially just a dictionary with the
amino acid sequence, the distance maps, some metadata and the
SSProtein object if available.
Note this determines the file type based on the extension and
will raise an error if the file is not a .starling file, a
a .starling.gzip or .starling.xz file.
Parameters
---------------
filename : str
The filename to read the ensemble from; note this should
not include a file extenison and if it does this will
be removed
Returns
---------------
starling.structure.Ensemble: The STARLING ensemble object
"""
if filename.endswith(".starling.gzip"):
try:
with gzip.open(filename, "rb") as file:
return_dict = pickle.load(file)
except Exception:
raise ValueError(
f"Could not read the file {filename}. Please check the path and try again."
)
elif filename.endswith(".starling.xz"):
try:
with lzma.open(filename, "rb") as file:
return_dict = pickle.load(file)
except Exception:
raise ValueError(
f"Could not read the file {filename}. Please check the path and try again."
)
elif filename.endswith(".starling"):
try:
with open(filename, "rb") as file:
return_dict = pickle.load(file)
except Exception:
raise ValueError(
f"Could not read the file {filename}. Please check the path and try again."
)
else:
raise ValueError(
f"File {filename} does not have the extension of .starling.gzip, .starling.xz or .starling."
)
# NOTE: We if the distance maps are float16, we cast them to float32
if return_dict["distance_maps"].dtype == "float16":
return_dict["distance_maps"] = return_dict["distance_maps"].astype("float32")
return return_dict
[docs]
def symmetrize_distance_maps(dist_maps):
"""
Symmetrizes a stack of distance maps along an axis by reflecting the upper triangle
onto the lower triangle and setting the diagonal values to zero.
Parameters
----------
dist_maps : np.ndarray
A 3D NumPy array of shape (N, M, M), where N is the number of distance maps.
Returns
-------
np.ndarray
A symmetrized stack of distance maps.
"""
# Ensure input is a 3D array
assert dist_maps.ndim == 3, "Input must be a 3D array of shape (N, M, M)."
assert dist_maps.shape[1] == dist_maps.shape[2], "Each distance map must be square."
# Create masks for the upper triangle
M = dist_maps.shape[1]
mask_upper_triangle = np.triu_indices(M, k=1)
# Reflect the upper triangle onto the lower triangle
dist_maps[:, mask_upper_triangle[1], mask_upper_triangle[0]] = dist_maps[
:, mask_upper_triangle[0], mask_upper_triangle[1]
]
# Set diagonal values to zero
np.einsum("nii->ni", dist_maps)[:] = 0 # Efficient diagonal zeroing
return dist_maps
[docs]
def get_off_diagonals(
distance_map, min_separation=1, max_separation=4, return_mean=False
):
"""
Function to calculate the the off-diagonal elements of a matrix.
This is useful for error checking as we can KNOW the max distance between
any pair of residues base on the contour length of the protein, allowing
us to identify conformations that are simply impossible.
Parameters
---------------
distance_map : np.ndarray
The distance map to check for errors.
min_separation : int
The minimum sequence separation to check across.
max_separation : int
The maxiumum sequence separation to check across.
Returns
---------------
np.ndarray
All off-diagonal elements of the distance map from
min_separation to max_separation away from the true
diagonal.
"""
# check we have a square...
if distance_map.shape[0] != distance_map.shape[1]:
raise ValueError("Input matrix must be square.")
# get sequence length
n = distance_map.shape[0]
# build a single list with all the diagonal vals
values = []
for d in range(min_separation, max_separation + 1):
values.extend(distance_map.diagonal(d)) # Upper diagonals
values.extend(distance_map.diagonal(-d)) # Lower diagonals
if return_mean:
return np.mean(values)
else:
return np.array(values)
[docs]
def check_distance_map_for_error(
distance_map, min_separation=1, max_separation=None, max_bond_length=4.81
):
"""
Check a distance map for physically impossible inter-residue distances.
Two residues separated by ``|i - j|`` positions in the sequence are
connected by ``|i - j|`` bonds, so the largest distance they can possibly
be apart is ``|i - j| * max_bond_length`` (a fully extended chain). Any
measured distance that exceeds this bound is physically impossible and
flags the conformation as erroneous.
Crucially, the bound is applied *per residue pair* using each pair's own
sequence separation. A single global threshold (as used previously) either
misses short-range errors -- a sequence-adjacent pair could be ~4 x too
far apart without being caught -- or falsely flags valid long-range pairs.
The bound is a hard physical maximum, so it never produces false positives
(assuming no bond exceeds ``max_bond_length``). For large separations it
becomes loose and therefore less sensitive, but it remains correct, which
is why all pairs can safely be checked at once.
Parameters
---------------
distance_map : np.ndarray
The ``(n, n)`` distance map to check for errors.
min_separation : int
The minimum sequence separation ``|i - j|`` to check across. Default
is 1, which skips only the zero diagonal.
max_separation : int or None
The maximum sequence separation ``|i - j|`` to check across. If None
(default) every pair of residues is checked.
max_bond_length : float
Maximum physical length of a single bond in Angstroms, including an
error term. The Mpipi bond length is 3.81 A; the default of 4.81 A adds
a +1 A per-bond error margin to minimise the risk of false positives.
Returns
---------------
bool
Returns True if any residue pair is further apart than physically
possible, and False otherwise.
"""
if distance_map.shape[0] != distance_map.shape[1]:
raise ValueError("Input matrix must be square.")
n = distance_map.shape[0]
# sequence separation |i - j| for every residue pair
sep = np.abs(np.subtract.outer(np.arange(n), np.arange(n)))
# maximum physically possible distance for each pair (fully extended chain)
max_possible_dist = sep * max_bond_length
# restrict to the requested separation window
mask = sep >= min_separation
if max_separation is not None:
mask &= sep <= max_separation
# flag if ANY in-window pair exceeds its own physical maximum
return bool(np.any((distance_map > max_possible_dist) & mask))
[docs]
def helix_dm(L: int, rise_per_res=1.5, angle_per_res_deg=100, radius=2.3):
"""
Generate an L x L distance matrix for an alpha-helix made of L residues.
Parameters:
- L: number of residues
- rise_per_res: Ã…ngstroms rise per residue along the helix axis
- angle_per_res_deg: degrees of rotation per residue
- radius: radius of the helix in Ã…ngstroms
Returns:
- D: (L, L) distance matrix (float32)
"""
angles = np.deg2rad(np.arange(L) * angle_per_res_deg) # shape (L,)
z = np.arange(L) * rise_per_res
x = radius * np.cos(angles)
y = radius * np.sin(angles)
coords = np.stack([x, y, z], axis=1) # shape (L, 3)
# Compute pairwise Euclidean distances
diff = coords[:, None, :] - coords[None, :, :] # (L, L, 3)
D = np.linalg.norm(diff, axis=-1).astype(np.float32) # (L, L)
return D