import time
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from IPython import embed
from scipy.spatial.distance import pdist, squareform
from sklearn.manifold import MDS
from tqdm import trange
from starling import configs
from starling.frontend.ensemble_generation import generate
from starling.inference.evaluate_vae import get_errors
from starling.structure.coordinates import (
create_ca_topology_from_coords,
distance_matrix_to_3d_structure_gd,
distance_matrix_to_3d_structure_mds,
distance_matrix_to_3d_structure_torch_mds,
)
from starling.utilities import get_data
def visualize_comparison(original_dms, scipy_coords, torch_coords):
    """
    Plot signed residual heatmaps for up to three conformers and both MDS methods.

    For conformer ``i`` the distance matrix implied by each method's coordinates is
    rebuilt with ``squareform(pdist(coords))`` and subtracted from
    ``original_dms[i]`` (residual = original - reconstructed). Row 0 shows the Scipy
    MDS residuals, row 1 the Torch SMACOF residuals. Every panel uses the same
    symmetric color range ``[-m, m]``, where ``m`` is the largest absolute residual
    over all panels, so colors are directly comparable across panels.

    Args:
        original_dms: Indexable sequence of reference distance matrices, one per
            conformer (assumed square - TODO confirm against callers).
        scipy_coords: Coordinates from Scipy MDS, aligned with ``original_dms``;
            each entry must be accepted by ``scipy.spatial.distance.pdist``.
        torch_coords: Coordinates from Torch SMACOF, aligned the same way.

    Returns:
        matplotlib figure with a 2 x N grid of heatmaps, where
        ``N = min(3, len(original_dms))``.

    Raises:
        ValueError: if ``original_dms`` is empty.
    """
    n_panels = min(3, len(original_dms))
    if n_panels == 0:
        raise ValueError("visualize_comparison requires at least one conformer")

    # squeeze=False keeps ``axes`` 2-D even when only one column is drawn.
    fig, axes = plt.subplots(2, n_panels, figsize=(7, 5), squeeze=False)

    # Compute each residual matrix once; it is needed both for the shared color
    # scale and for the plots themselves.
    residuals = []
    for i in range(n_panels):
        scipy_diff = original_dms[i] - squareform(pdist(scipy_coords[i]))
        torch_diff = original_dms[i] - squareform(pdist(torch_coords[i]))
        residuals.append((scipy_diff, torch_diff))

    # Symmetric limits so a zero residual sits at the colormap midpoint. Fall back
    # to 1.0 for a perfect reconstruction, where the range would collapse to 0.
    abs_max = max(float(np.max(np.abs(diff))) for pair in residuals for diff in pair)
    abs_max = abs_max or 1.0

    # Residuals are signed and centred on 0, so use a diverging colormap: zero maps
    # to a neutral color and over/under-estimates get opposite hues.
    cmap = "RdBu_r"

    row_labels = ("Scipy MDS", "Torch SMACOF")
    for col, pair in enumerate(residuals):
        for row, (label, diff) in enumerate(zip(row_labels, pair)):
            ax = axes[row, col]
            sns.heatmap(
                diff,
                ax=ax,
                cmap=cmap,
                center=0,
                vmin=-abs_max,
                vmax=abs_max,
                cbar_kws={"shrink": 0.8},
                square=True,
            )
            rmse = np.sqrt(np.mean(diff**2))
            ax.set_title(
                f"{label} Conformer {col + 1}\nRMSE: {rmse:.2e}",
                pad=10,
                fontsize=9,
            )
            # Remove tick labels for a cleaner look.
            ax.set_xticks([])
            ax.set_yticks([])

    plt.tight_layout()
    return fig
def plot_error_distributions(original_dms, scipy_coords, torch_coords):
    """
    Plot overlaid histograms of distance residuals for Scipy MDS and Torch SMACOF.

    For every conformer the distance matrix implied by each method's coordinates is
    rebuilt with ``squareform(pdist(coords))``. Every entry of
    ``original - reconstructed`` (diagonal included) is pooled per method. The two
    pools are drawn as density-normalised histograms, and the legend reports each
    pool's mean residual.

    Args:
        original_dms: Sequence of reference distance matrices, one per conformer.
        scipy_coords: Scipy MDS coordinates aligned with ``original_dms``.
        torch_coords: Torch SMACOF coordinates aligned with ``original_dms``.

    Returns:
        matplotlib figure containing the histogram axes.
    """

    def _pooled_residuals(coords_list):
        # Gather residuals as NumPy arrays and concatenate once. Extending a Python
        # list element by element would box every matrix entry (tens of millions
        # for a full ensemble) into a separate Python object.
        parts = [
            (dm - squareform(pdist(coords_list[i]))).ravel()
            for i, dm in enumerate(original_dms)
        ]
        return np.concatenate(parts) if parts else np.empty(0)

    scipy_errors = _pooled_residuals(scipy_coords)
    torch_errors = _pooled_residuals(torch_coords)

    fig, ax = plt.subplots(figsize=(8, 5))
    # Modern blue and green
    colors = ["#3498db", "#2ecc71"]
    hist_style = dict(
        bins=50, alpha=0.6, density=True, edgecolor="white", linewidth=0.5
    )

    # Labels carry the mean directly, so each legend entry is bound to its own
    # histogram instead of relying on artist order via ``ax.legend(labels)``.
    ax.hist(
        scipy_errors,
        label=f"Scipy MDS (μ={np.mean(scipy_errors):.2e})",
        color=colors[0],
        **hist_style,
    )
    ax.hist(
        torch_errors,
        label=f"Torch SMACOF (μ={np.mean(torch_errors):.2e})",
        color=colors[1],
        **hist_style,
    )

    # The x values are individual signed residuals (original - reconstructed),
    # not means, so label the axis accordingly.
    ax.set_xlabel("Distance Error (Å)", fontsize=10)
    ax.set_ylabel("Density", fontsize=10)
    ax.set_title("Distribution of Errors", fontsize=12, pad=15)

    # Clean up the plot
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(False)

    ax.legend(frameon=True, framealpha=0.9, edgecolor="none")
    return fig
def benchmark_methods(
    dms, n_iter=100, tol=1e-4, n_repeats=3, batch_size=None, verbose=False
):
    """
    Time Torch SMACOF and Scipy MDS reconstructions of the same distance matrices.

    Each repeat first runs ``distance_matrix_to_3d_structure_torch_mds`` on the whole
    ``dms`` collection in a single call. It then runs
    ``distance_matrix_to_3d_structure_mds`` on every matrix individually. Wall-clock
    time for each method is recorded with ``time.perf_counter``.

    Args:
        dms: Collection of distance matrices, passed as-is to the torch solver and
            iterated for the Scipy solver.
        n_iter: Iteration budget forwarded to the torch solver.
        tol: Convergence tolerance forwarded to the torch solver.
        n_repeats: Number of timing repeats; must be at least 1.
        batch_size: Currently unused; accepted for backward compatibility.
        verbose: If True, print each repeat's timings.

    Returns:
        dict with ``"scipy_times"`` and ``"torch_times"`` (np.ndarray of length
        ``n_repeats``, seconds), plus ``"scipy_coords"`` (list) and
        ``"torch_coords"``, both taken from the last repeat.

    Raises:
        ValueError: if ``n_repeats < 1`` (no coordinates would exist to return).
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be at least 1, got {n_repeats}")

    def _sync():
        # If the torch solver leaves CUDA work queued, wait for it before reading
        # the clock so the GPU time is actually measured. No-op without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    scipy_times = []
    torch_times = []
    for _ in range(n_repeats):
        # Torch SMACOF: one call over the whole collection.
        _sync()
        start_time = time.perf_counter()
        torch_coords, _ = distance_matrix_to_3d_structure_torch_mds(
            dms,
            n_iter=n_iter,
            tol=tol,
        )
        _sync()
        torch_times.append(time.perf_counter() - start_time)
        if verbose:
            print("torch", torch_times[-1])

        # Scipy MDS: one call per distance matrix.
        start_time = time.perf_counter()
        scipy_mds_coords = [distance_matrix_to_3d_structure_mds(dm) for dm in dms]
        scipy_times.append(time.perf_counter() - start_time)
        if verbose:
            print("scipy", scipy_times[-1])

    return {
        "scipy_times": np.array(scipy_times),
        "torch_times": np.array(torch_times),
        "scipy_coords": scipy_mds_coords,
        "torch_coords": torch_coords,
    }
def plot_timing_comparison(timing_results, n_conformers):
    """
    Plot per-repeat timing distributions and mean per-conformer cost of both methods.

    Left panel: violin plot of per-repeat wall-clock times, falling back to a strip
    of points when a KDE cannot be estimated. A text box shows the mean times, the
    Torch speedup (mean Scipy / mean Torch) and the conformer count. Right panel:
    bar chart of each method's mean time divided by ``n_conformers``.

    Args:
        timing_results: Mapping with array-like ``"scipy_times"`` and
            ``"torch_times"`` entries in seconds (as produced by
            ``benchmark_methods``).
        n_conformers: Number of conformers processed in each timed run.

    Returns:
        matplotlib figure
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    # Modern blue and green
    colors = ["#3498db", "#2ecc71"]
    mean_color = "#e74c3c"

    scipy_times = np.asarray(timing_results["scipy_times"], dtype=float)
    torch_times = np.asarray(timing_results["torch_times"], dtype=float)
    samples = [scipy_times, torch_times]

    # violinplot's Gaussian KDE needs more than one distinct value per group. With
    # a single repeat (or identical timings) it raises, so plot raw points and a
    # mean marker instead.
    if all(s.size > 1 and np.ptp(s) > 0 for s in samples):
        parts = ax1.violinplot(samples, showmeans=True)
        for body, color in zip(parts["bodies"], colors):
            body.set_facecolor(color)
            body.set_edgecolor(color)
            body.set_alpha(0.7)
        parts["cmeans"].set_color(mean_color)
    else:
        for pos, (s, color) in enumerate(zip(samples, colors), start=1):
            ax1.scatter(np.full(s.size, pos), s, color=color, alpha=0.7)
            ax1.hlines(s.mean(), pos - 0.2, pos + 0.2, color=mean_color)

    ax1.set_xticks([1, 2])
    ax1.set_xticklabels(["Scipy MDS", "Torch SMACOF"])
    ax1.set_ylabel("Time (seconds)")
    ax1.set_title("Computation Time Distribution", pad=15)
    ax1.spines["top"].set_visible(False)
    ax1.spines["right"].set_visible(False)

    # Means are computed once and reused for both the summary text and the bars.
    mean_scipy = np.mean(scipy_times)
    mean_torch = np.mean(torch_times)
    speedup_torch = mean_scipy / mean_torch
    text = (
        f"Mean Times:\nScipy: {mean_scipy:.3f}s\ntorch: {mean_torch:.3f}s\n\n"
        f"Speedup (torch): {speedup_torch:.2f}x\n"
        f"Conformers: {n_conformers}"
    )
    ax1.text(
        0.95,
        0.95,
        text,
        transform=ax1.transAxes,
        verticalalignment="top",
        horizontalalignment="right",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )

    times_per_conformer = [mean_scipy / n_conformers, mean_torch / n_conformers]
    bars = ax2.bar(
        ["Scipy MDS", "Torch SMACOF"],
        times_per_conformer,
        color=colors,
        alpha=0.8,
        width=0.6,
    )
    ax2.set_ylabel("Time per Conformer (seconds)")
    ax2.set_title("Average Computation Time", pad=15)
    ax2.spines["top"].set_visible(False)
    ax2.spines["right"].set_visible(False)
    ax2.grid(False)

    for bar in bars:
        height = bar.get_height()
        ax2.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            # Three significant digits: per-conformer times are often well below
            # 1 ms, which a fixed ".3f" format would render as "0.000s".
            f"{height:.3g}s",
            ha="center",
            va="bottom",
            fontsize=9,
        )

    plt.tight_layout()
    return fig
def run_benchmark_comparison(dms, n_iter=100, tol=1e-4, n_repeats=5):
    """
    Benchmark both MDS methods on ``dms``, build the comparison figures, and save them.

    The benchmark runs via ``benchmark_methods``. Three figures are then created:
    residual heatmaps for the first three conformers, the pooled residual
    histogram, and the timing comparison. Each figure is written to the current
    working directory as both PDF and PNG at 300 dpi.

    Returns:
        tuple ``(heatmap_fig, dist_fig, timing_fig, timing_results)``, where
        ``timing_results`` is the dict returned by ``benchmark_methods``.
    """
    results = benchmark_methods(dms, n_iter=n_iter, tol=tol, n_repeats=n_repeats)
    scipy_coords = results["scipy_coords"]
    torch_coords = results["torch_coords"]

    heatmap_fig = visualize_comparison(dms[:3], scipy_coords[:3], torch_coords[:3])
    dist_fig = plot_error_distributions(dms, scipy_coords, torch_coords)
    timing_fig = plot_timing_comparison(results, len(dms))

    # Each figure is saved as PDF first, then PNG.
    outputs = (
        (
            heatmap_fig,
            (
                "mds_comparison_sample_heatmaps.pdf",
                "mds_comparison_sample_heatmaps.png",
            ),
        ),
        (
            dist_fig,
            (
                "mds_mean_error_distribution.pdf",
                "mds_mean_error_distribution.png",
            ),
        ),
        (
            timing_fig,
            (
                "mds_timing_comparison.pdf",
                "mds_timing_comparison.png",
            ),
        ),
    )
    for figure, paths in outputs:
        for path in paths:
            figure.savefig(path, dpi=300, bbox_inches="tight")

    return heatmap_fig, dist_fig, timing_fig, results
if __name__ == "__main__":
    # Distance maps for a 380-residue PLKE repeat; expected shape (200, 384, 384).
    sequence = "PLKE" * 95
    ensemble = generate(sequence)
    distance_maps = ensemble["sequence_1"].distance_maps()
    heatmap_fig, dist_fig, timing_fig, timing_results = run_benchmark_comparison(
        distance_maps,
        n_iter=300,
        tol=1e-4,
        n_repeats=10,
    )