Source code for starling.structure.bme_utils

from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np

# ----------------------------------------------------------------------
# Constants
# ----------------------------------------------------------------------
# NOTE(review): DEFAULT_THETA, LAMBDA_INIT_SCALE and MIN_WEIGHT_THRESHOLD are
# not referenced anywhere in the visible part of this module; presumably they
# are consumed by the BME fitting code -- confirm before changing or removing.
DEFAULT_THETA = 0.5
# Default for theta_scan(max_iterations=...), which forwards it to BME.fit.
DEFAULT_MAX_ITERATIONS = 50000
# Default for theta_scan(optimizer=...), which forwards it to BME.fit.
DEFAULT_OPTIMIZER = "L-BFGS-B"
LAMBDA_INIT_SCALE = 1e-3
MIN_WEIGHT_THRESHOLD = 1e-50

# Valid constraint types (checked by ExperimentalObservable.__post_init__
# after lower-casing and stripping the user-supplied string).
VALID_CONSTRAINTS = {"equality", "upper", "lower"}


# ----------------------------------------------------------------------
# ExperimentalObservable
# ----------------------------------------------------------------------
@dataclass
class ExperimentalObservable:
    """
    Container for experimental observable data.

    Parameters
    ----------
    value : float
        The experimental value of the observable.
    uncertainty : float
        The experimental uncertainty (standard deviation). Must be strictly
        positive; zero, negative and NaN values are rejected.
    constraint : str, optional
        Type of constraint. Must be one of:
        - "equality" (default): Observable should match value ± uncertainty
        - "upper": Observable should not exceed value
        - "lower": Observable should not fall below value
        Matching ignores case and surrounding whitespace; the stored value
        is normalized to the lowercase, stripped form.
    name : str, optional
        Optional name/description of the observable.

    Raises
    ------
    ValueError
        If `uncertainty` is not strictly positive (including NaN) or if
        `constraint` is not one of the valid constraint types.
    TypeError
        If `constraint` is not a string.
    """

    value: float
    uncertainty: float
    constraint: str = "equality"
    name: Optional[str] = None

    def __post_init__(self):
        """Validate the observable data and normalize `constraint`."""
        # `not (x > 0)` instead of `x <= 0`: every ordering comparison with
        # NaN is False, so `x <= 0` would silently accept a NaN uncertainty.
        if not (self.uncertainty > 0):
            raise ValueError(f"Uncertainty must be positive, got {self.uncertainty}")

        # Validate constraint
        if not isinstance(self.constraint, str):
            raise TypeError(
                "constraint must be a string ('equality', 'upper', or 'lower'), "
                f"got {type(self.constraint).__name__}"
            )

        constraint_lower = self.constraint.lower().strip()
        if constraint_lower not in VALID_CONSTRAINTS:
            raise ValueError(
                f"Invalid constraint: '{self.constraint}'. "
                "Must be 'equality', 'upper', or 'lower'"
            )

        # Normalize to lowercase
        self.constraint = constraint_lower

    def get_bounds(self) -> Tuple[Optional[float], Optional[float]]:
        """
        Get the optimization bounds for the Lagrange multiplier.

        These bounds ensure the constraint type is enforced during optimization:
        - "equality": No bounds (lambda can be any value)
        - "upper": lambda >= 0 (positive lambda pushes observable down)
        - "lower": lambda <= 0 (negative lambda pushes observable up)

        Returns
        -------
        tuple
            ``(lower_bound, upper_bound)``; ``None`` means unbounded on that
            side.
        """
        if self.constraint == "equality":
            return (None, None)
        if self.constraint == "upper":
            return (0.0, None)
        # Any remaining (validated) constraint is "lower".
        return (None, 0.0)
# ---------------------------------------------------------------------- # BMEResult # ----------------------------------------------------------------------
@dataclass
class BMEResult:
    """
    Container for BME optimization results.

    Besides holding the fitted weights and fit statistics, the class offers
    quality-control helpers (`diagnostics`, `print_diagnostics`) and exposes
    the relative entropy implied by `phi` through `kl_divergence`.
    """

    weights: np.ndarray
    initial_weights: np.ndarray
    lambdas: np.ndarray
    chi_squared_initial: float
    chi_squared_final: float
    phi: float
    n_iterations: int
    success: bool
    message: str
    theta: float
    observables: List[ExperimentalObservable]
    calculated_values: np.ndarray
    metadata: dict

    def __str__(self):
        outcome = "SUCCESS" if self.success else "FAILED"
        summary_lines = [
            f"BME Result [{outcome}]",
            f" Chi-squared initial: {self.chi_squared_initial:.4f}",
            f" Chi-squared final: {self.chi_squared_final:.4f}",
            f" phi (effective fraction): {self.phi:.4f}",
            f" Iterations: {self.n_iterations}",
            f" Theta: {self.theta}",
        ]
        return "\n".join(summary_lines)

    def __repr__(self):
        # The human-readable summary doubles as the repr.
        return str(self)

    @property
    def kl_divergence(self) -> float:
        """
        KL divergence (relative entropy) derived from phi as ``-log(phi)``.

        Returns ``inf`` whenever phi is not strictly positive.
        """
        if not (self.phi > 0):
            return np.inf
        return -np.log(self.phi)

    # ----- Integrated QC helpers -------------------------------------------------

    def diagnostics(self, warn_threshold: float = 0.5) -> dict:
        """
        Diagnose BME reweighting results and identify potential issues.

        Parameters
        ----------
        warn_threshold : float, optional
            A warning is recorded when phi falls below this value.

        Returns
        -------
        dict
            Keys: ``neff_entropy``, ``neff_entropy_fraction``, ``neff_renyi2``,
            ``neff_renyi2_fraction``, ``weight_min``, ``weight_max``,
            ``weight_std``, ``weight_range_orders``, ``chi2_improvement``,
            ``chi2_improvement_pct``, ``warnings`` (list of str) and
            ``status`` ("OK" when no warnings were raised, else "WARNING").

        Notes
        -----
        Two notions of effective sample size are reported:

        - neff_entropy (N_eff^(S)): ``N * phi``, the entropy-based measure.
        - neff_renyi2 (N_eff^(2)): ``1 / sum_i w_i^2`` (Rényi-2 /
          participation ratio), which reacts more strongly to a handful of
          large weights. Assuming Φ = exp(-D_KL) relative to uniform initial
          weights, it never exceeds neff_entropy.
        """
        n_frames = len(self.weights)

        # ---- Effective sample sizes ---------------------------------------------
        entropy_fraction = float(self.phi)
        participation = 1.0 / np.sum(self.weights**2)
        participation_fraction = float(participation / n_frames)

        # ---- Weight distribution statistics -------------------------------------
        w_min = float(self.weights.min())
        w_max = float(self.weights.max())
        w_std = float(self.weights.std())
        # A non-positive minimum has no finite log-ratio to the maximum.
        range_orders = float(np.log10(w_max / w_min)) if w_min > 0 else np.inf

        # ---- Chi-squared improvement --------------------------------------------
        chi2_gain = self.chi_squared_initial - self.chi_squared_final
        if self.chi_squared_initial > 0:
            chi2_gain_pct = (chi2_gain / self.chi_squared_initial) * 100.0
        else:
            chi2_gain_pct = np.nan

        # ---- Checks / warnings ---------------------------------------------------
        issues: List[str] = []

        # Main diversity check uses the entropy-based fraction (Φ)
        if self.phi < warn_threshold:
            issues.append(
                f"Low Phi ({self.phi:.3f} < {warn_threshold}): "
                "Significant loss of ensemble diversity. "
                "Consider increasing theta or loosening observable uncertainties."
            )

        # Secondary check: very concentrated weights (using Renyi-2 Neff)
        if participation < 0.1 * n_frames:
            issues.append(
                "Low effective sample size (1/Σw²) "
                f"({participation:.1f} / {n_frames}): "
                f"Only ~{participation_fraction * 100:.1f}% of frames "
                "are effectively used (participation ratio)."
            )

        if range_orders > 3:
            issues.append(
                "Large weight range "
                f"({range_orders:.1f} orders of magnitude): "
                "A few frames dominate the reweighted ensemble."
            )

        if self.chi_squared_final > 2 * len(self.observables):
            issues.append(
                "High final Chi-squared "
                f"({self.chi_squared_final:.2f}): Poor fit to experimental data. "
                "Observables may be incompatible with ensemble."
            )

        return {
            "neff_entropy": n_frames * entropy_fraction,
            "neff_entropy_fraction": entropy_fraction,
            "neff_renyi2": float(participation),
            "neff_renyi2_fraction": participation_fraction,
            "weight_min": w_min,
            "weight_max": w_max,
            "weight_std": w_std,
            "weight_range_orders": range_orders,
            "chi2_improvement": chi2_gain,
            "chi2_improvement_pct": chi2_gain_pct,
            "warnings": issues,
            "status": "WARNING" if issues else "OK",
        }

    def print_diagnostics(
        self,
        warn_threshold: float = 0.5,
    ):
        """
        Print a formatted diagnostic report for this BME result.

        Parameters
        ----------
        warn_threshold : float, optional
            Forwarded to `diagnostics`.
        """
        report = self.diagnostics(warn_threshold)
        n_frames = len(self.weights)
        rule = "=" * 60

        # Fixed column widths keep the label/value pairs aligned.
        key_w = 32
        val_w = 14

        def row(label, value, spec, suffix=""):
            print(f" {label:<{key_w}} {value:>{val_w}{spec}}{suffix}")

        print("\n" + rule)
        print("BME DIAGNOSTIC REPORT")
        print(rule)

        print(f"\nOptimization Status: {self.message}")
        print(f"Success: {self.success}")
        print(f"Iterations: {self.n_iterations}")

        print("\nChi-squared:")
        row("Initial", self.chi_squared_initial, ".4f")
        row("Final", self.chi_squared_final, ".4f")
        row(
            "Improvement",
            report["chi2_improvement"],
            ".4f",
            f" ({report['chi2_improvement_pct']:.1f}%)",
        )

        print("\nEnsemble Diversity:")
        row("Phi (Φ, entropy fraction)", self.phi, ".4f")
        row("N_eff^(S) (entropy-based)", report["neff_entropy"], ".1f", f" / {n_frames}")
        row("N_eff^(2) (1/Σw², Renyi-2)", report["neff_renyi2"], ".1f", f" / {n_frames}")
        row("Theta (θ)", self.theta, ".4f")

        print("\nWeight Distribution:")
        row("Min", report["weight_min"], ".2e")
        row("Max", report["weight_max"], ".2e")
        row("Std Dev", report["weight_std"], ".2e")
        row("Range (orders of magnitude)", report["weight_range_orders"], ".1f")

        issues = report["warnings"]
        if issues:
            print(f"\n WARNINGS ({len(issues)}):")
            for number, text in enumerate(issues, 1):
                print(f" {number}. {text}")
        else:
            print(f"\n✓ Status: {report['status']} - No issues detected")

        print("\nNotes on effective sample size:")
        print(" - N_eff^(S): entropy-based (from Φ = exp(-D_KL));")
        print(" this matches the usual BME definition.")
        print(" - N_eff^(2): 1/Σw² (Rényi-2 / participation ratio);")
        print(" more sensitive to a few large weights, so it is always")
        print(" ≤ N_eff^(S) for the same weights.")
        print(rule)
# ---------------------------------------------------------------------- # Theta scan support # ----------------------------------------------------------------------
@dataclass
class ThetaScanResult:
    """
    Container for theta scan results.

    The array fields are index-aligned with `results`: entry ``i`` of each
    array describes the fit stored in ``results[i]``. `optimal_idx` points at
    the selected entry and `method` holds the label of the knee-finding
    method that selected it.
    """

    theta_values: np.ndarray
    chi_squared_values: np.ndarray
    phi_values: np.ndarray
    kl_divergence_values: np.ndarray
    results: List[BMEResult]
    optimal_theta: float
    optimal_idx: int
    method: str

    # ----- QC / visualization integration -------------------------------

    def plot(
        self,
        figsize: Tuple[int, int] = (10, 4),
        save_path: Optional[str] = None,
        show: bool = True,
    ):
        """
        Plot theta-scan diagnostics.

        Panels (1 x 2):
            [0] Chi squared vs phi (effective fraction), colored by
                log10(theta), with the optimal theta marked by a star
            [1] Sorted weight distribution at optimal theta (log y-axis)
                against the uniform weight 1/N

        Parameters
        ----------
        figsize : tuple of int, optional
            Figure size passed to ``plt.subplots``.
        save_path : str or None, optional
            If given, the figure is saved there (dpi=150, tight bounding box).
        show : bool, optional
            If True call ``plt.show()``; otherwise the figure is closed so
            that inline backends do not render it automatically.

        Returns
        -------
        matplotlib.figure.Figure
            The created figure.
        """
        fig, axes = plt.subplots(1, 2, figsize=figsize)

        chi2 = self.chi_squared_values
        # phi is our definition of the effective fraction
        phi = self.phi_values
        opt_idx = self.optimal_idx

        # ------------------------------------------------------------------
        # Panel 1: Chi squared vs phi (L-curve style)
        # ------------------------------------------------------------------
        ax1 = axes[0]
        sc = ax1.scatter(
            phi,
            chi2,
            c=np.log10(self.theta_values),
            s=50,
            alpha=0.7,
        )
        ax1.plot(phi, chi2, "k-", alpha=0.3, linewidth=1)
        ax1.scatter(
            phi[opt_idx],
            chi2[opt_idx],
            color="red",
            s=200,
            marker="*",
            edgecolors="black",
            linewidths=2,
            label=f"Optimal θ={self.optimal_theta:.3f}",
            zorder=5,
        )
        ax1.set_xlabel(r"$N_{eff}$ (effective fraction)", fontsize=11)
        ax1.set_ylabel(r"$\chi^2$", fontsize=11)
        ax1.legend(frameon=False)
        cbar = plt.colorbar(sc, ax=ax1)
        cbar.set_label(r"$\log_{10}(\theta)$", fontsize=10)

        # ------------------------------------------------------------------
        # Panel 2: weight distribution at the optimal theta
        # ------------------------------------------------------------------
        ax2 = axes[1]
        optimal_result = self.results[opt_idx]
        sorted_weights = np.sort(optimal_result.weights)[::-1]
        ax2.plot(
            sorted_weights, "o-", markersize=3, color="darkred", label="BME weights"
        )
        ax2.axhline(
            1.0 / len(sorted_weights),
            color="blue",
            linestyle="--",
            label="Uniform weight",
        )
        ax2.set_xlabel("Conformation (sorted by weight)", fontsize=11)
        ax2.set_ylabel("Weight", fontsize=11)
        ax2.set_yscale("log")
        ax2.legend(frameon=False)

        plt.tight_layout()

        if save_path is not None:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")

        if show:
            plt.show()
        else:
            # Prevent auto-rendering in Jupyter inline backend
            plt.close(fig)

        return fig

    def print_summary(self, n_show: int = 5):
        """
        Print a formatted summary of theta scan results.

        Parameters
        ----------
        n_show : int, optional
            Currently unused. It is kept for backward compatibility; it used
            to control how many theta values a per-theta sample table
            listed, and that table is no longer printed.
        """
        rule = "=" * 60
        opt_idx = self.optimal_idx

        print("\n" + rule)
        print("THETA SCAN SUMMARY")
        print(rule)

        print(
            f"\nScan range: {self.theta_values[0]:.4f} to {self.theta_values[-1]:.4f}"
        )
        print(f"Number of points: {len(self.theta_values)}")
        print(f"Method: {self.method}\n")

        print(f"RECOMMENDED THETA: {self.optimal_theta:.4f}")
        print(f"Chi squared: {self.chi_squared_values[opt_idx]:.4f}")
        print(f"N_eff: {self.phi_values[opt_idx]:.4f}")
        print(f"Relative Entropy: {self.kl_divergence_values[opt_idx]:.4f}")

        print(
            " • L-curve shows trade-off between fit quality (Chi squared) "
            "and effective sample size (N_eff)"
        )
        print(" • Optimal theta balances these competing objectives")
        print(" • N_eff < 0.5 indicates significant loss of ensemble diversity")
        print(" • Consider increasing theta if N_eff is too low at optimum")
        print(rule)
# ---------------------------------------------------------------------- # Theta scan driver + knee finding # ----------------------------------------------------------------------
def theta_scan(
    observables: List[ExperimentalObservable],
    calculated_values: np.ndarray,
    theta_range: Union[Tuple[float, float], np.ndarray] = (0.01, 10.0),
    n_points: int = 15,
    log_scale: bool = True,
    max_iterations: int = DEFAULT_MAX_ITERATIONS,
    optimizer: str = DEFAULT_OPTIMIZER,
    verbose: bool = False,
    progress_callback: Optional[Callable[[int, int, float], None]] = None,
    initial_weights: Optional[np.ndarray] = None,
    method: str = "perpendicular",
) -> ThetaScanResult:
    """
    Scan a range of regularization parameters (theta) for BME reweighting.

    One independent BME fit is run per theta; the optimal theta is then
    chosen by a knee search on the (chi-squared, KL divergence) curve.

    Parameters
    ----------
    observables : List[ExperimentalObservable]
        Experimental observables to use for reweighting.
    calculated_values : np.ndarray
        Calculated values from the ensemble. Shape typically
        (n_models, n_observables).
    theta_range : tuple or np.ndarray, optional
        If a tuple or list is given, its first two entries are read as
        (min, max) and a grid of `n_points` thetas is created; any other
        array-like is used directly as the theta grid. Default (0.01, 10.0).
    n_points : int, optional
        Number of theta samples when `theta_range` is a tuple. Default 15.
    log_scale : bool, optional
        If True and `theta_range` is a tuple, sample theta logarithmically.
        Default True.
    max_iterations : int, optional
        Maximum optimizer iterations forwarded to BME.fit. Default from module.
    optimizer : str, optional
        Optimizer name forwarded to BME.fit. Default from module.
    verbose : bool, optional
        If True, print progress messages. Default False.
    progress_callback : callable, optional
        Optional callback called as progress_callback(current_index, total,
        theta) before each fit; current_index is 1-based.
    initial_weights : np.ndarray or None, optional
        Optional initial weight vector for BME.
    method : str, optional
        Method used to pick optimal theta ('perpendicular' or 'curvature').

    Returns
    -------
    ThetaScanResult
        Object containing theta grid, metric arrays, individual BMEResult
        objects, and the selected optimal theta and index. Its `method`
        field holds the descriptive label returned by `find_optimal_theta`.

    Raises
    ------
    ValueError
        If `method` is not 'perpendicular' or 'curvature', if n_points < 1
        when a tuple theta_range is provided, if log_scale is True and the
        theta_range tuple contains non-positive endpoints, or if an explicit
        theta array is empty.
    TypeError
        If `progress_callback` is neither None nor callable.
    """
    # Fail fast: reject a bad `method` before any (potentially expensive)
    # fit runs, instead of only when the knee search is finally reached.
    if method not in ("curvature", "perpendicular"):
        raise ValueError(
            f"Unknown method: {method}, must be 'curvature' or 'perpendicular'"
        )

    range_is_pair = isinstance(theta_range, (tuple, list))
    if range_is_pair:
        if n_points < 1:
            raise ValueError(
                f"n_points must be >= 1 when theta_range is a tuple/list. Received: {n_points}"
            )
        if log_scale and (theta_range[0] <= 0 or theta_range[1] <= 0):
            raise ValueError(
                "theta_range endpoints must be positive when log_scale=True, Received: "
                f"{theta_range}"
            )

    if progress_callback is not None and not callable(progress_callback):
        raise TypeError("progress_callback must be callable or None")

    # Generate theta values
    if range_is_pair:
        if log_scale:
            theta_values = np.logspace(
                np.log10(theta_range[0]), np.log10(theta_range[1]), n_points
            )
        else:
            theta_values = np.linspace(theta_range[0], theta_range[1], n_points)
    else:
        theta_values = np.asarray(theta_range)
        if theta_values.size == 0:
            # Without this the knee search fails later with an opaque numpy
            # "zero-size array to reduction operation" ValueError.
            raise ValueError("theta_range must contain at least one theta value")

    # Imported here rather than at module level, as in the original code
    # (presumably to avoid a circular import with starling.structure.bme).
    from starling.structure.bme import BME

    n_thetas = len(theta_values)
    chi_squared_vals: List[float] = []
    phi_vals: List[float] = []
    kl_vals: List[float] = []
    bme_results: List[BMEResult] = []

    for i, theta in enumerate(theta_values):
        if progress_callback is not None:
            progress_callback(i + 1, n_thetas, theta)

        if verbose:
            print(f"Processing theta {i + 1}/{n_thetas}: {theta:.4f}")

        # A fresh BME instance per theta keeps the fits independent.
        bme = BME(
            observables=observables,
            calculated_values=calculated_values,
            initial_weights=initial_weights,
        )

        # Single-theta fit, no auto scan:
        result = bme.fit(
            max_iterations=max_iterations,
            optimizer=optimizer,
            verbose=False,
            theta=float(theta),
            auto_theta=False,
        )

        bme_results.append(result)
        chi_squared_vals.append(result.chi_squared_final)
        phi_vals.append(result.phi)
        kl_vals.append(result.kl_divergence)

    chi_squared_vals_arr = np.array(chi_squared_vals)
    phi_vals_arr = np.array(phi_vals)
    kl_vals_arr = np.array(kl_vals)

    # Keep the returned label separate from the `method` argument so the
    # caller-supplied selector is not silently overwritten.
    optimal_idx, method_label = find_optimal_theta(
        chi_squared_vals_arr, kl_vals_arr, method=method
    )
    optimal_theta = float(theta_values[optimal_idx])

    return ThetaScanResult(
        theta_values=theta_values,
        chi_squared_values=chi_squared_vals_arr,
        phi_values=phi_vals_arr,
        kl_divergence_values=kl_vals_arr,
        results=bme_results,
        optimal_theta=optimal_theta,
        optimal_idx=optimal_idx,
        method=method_label,
    )
def find_optimal_theta(
    chi_squared_values: np.ndarray,
    kl_divergence_values: np.ndarray,
    method: str = "perpendicular",
) -> Tuple[int, str]:
    """
    Find optimal theta value using L-curve analysis.

    Parameters
    ----------
    chi_squared_values : np.ndarray
        Chi-squared value of each scanned theta (x-axis of the L-curve).
    kl_divergence_values : np.ndarray
        KL divergence of each scanned theta (y-axis of the L-curve).
    method : str, optional
        Knee-finding strategy: 'perpendicular' (default) or 'curvature'.

    Returns
    -------
    tuple of (int, str)
        Index of the selected point and a human-readable label of the
        method used ("Perpendicular distance" or "Menger curvature").

    Raises
    ------
    ValueError
        If `method` is not one of the supported strategies.
    """
    # Guard clause: reject unsupported selectors up front.
    if method not in ("curvature", "perpendicular"):
        raise ValueError(
            f"Unknown method: {method}, must be 'curvature' or 'perpendicular'"
        )

    if method == "perpendicular":
        knee = _find_knee_perpendicular(chi_squared_values, kl_divergence_values)
        return knee, "Perpendicular distance"

    knee = _find_knee_curvature(chi_squared_values, kl_divergence_values)
    return knee, "Menger curvature"
def _find_knee_curvature(x_values: np.ndarray, y_values: np.ndarray) -> int: """ Find knee point using Menger curvature (3-point formula). """ x_norm = (x_values - x_values.min()) / (x_values.max() - x_values.min() + 1e-10) y_norm = (y_values - y_values.min()) / (y_values.max() - y_values.min() + 1e-10) n_points = len(x_norm) curvature = np.zeros(n_points) for i in range(1, n_points - 1): p0 = np.array([x_norm[i - 1], y_norm[i - 1]]) p1 = np.array([x_norm[i], y_norm[i]]) p2 = np.array([x_norm[i + 1], y_norm[i + 1]]) v1 = p1 - p0 v2 = p2 - p1 area = abs(v1[0] * v2[1] - v1[1] * v2[0]) / 2.0 a = np.linalg.norm(p2 - p1) b = np.linalg.norm(p0 - p2) c = np.linalg.norm(p1 - p0) if a * b * c > 1e-10: curvature[i] = 4 * area / (a * b * c) curvature[0] = curvature[1] curvature[-1] = curvature[-2] return int(np.argmax(curvature)) def _find_knee_perpendicular(x_values: np.ndarray, y_values: np.ndarray) -> int: """ Find knee point using perpendicular distance from line connecting endpoints. """ x_norm = (x_values - x_values.min()) / (x_values.max() - x_values.min() + 1e-10) y_norm = (y_values - y_values.min()) / (y_values.max() - y_values.min() + 1e-10) p1 = np.array([x_norm[0], y_norm[0]]) p2 = np.array([x_norm[-1], y_norm[-1]]) line_vec = p2 - p1 line_len = np.linalg.norm(line_vec) if line_len < 1e-10: return len(x_values) // 2 line_unit = line_vec / line_len distances = [] for i in range(len(x_norm)): point = np.array([x_norm[i], y_norm[i]]) vec = point - p1 proj_length = np.dot(vec, line_unit) proj_point = p1 + proj_length * line_unit perp_dist = np.linalg.norm(point - proj_point) distances.append(perp_dist) return int(np.argmax(distances))