import argparse
import glob
import os
import hydra
import pytorch_lightning as pl
import torch
import wandb
import yaml
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from starling.models.vae import VAE
[docs]
@rank_zero_only
def wandb_init(project: str = "starling", id=None):
wandb.init(project=project, resume="allow", id=id)
[docs]
def setup_directories(output_path):
"""Create necessary directories and save the configuration file."""
os.makedirs(output_path, exist_ok=True)
[docs]
def save_config(config, output_path):
"""Save the configuration to a YAML file."""
with open(f"{output_path}/config.yaml", "w") as f:
yaml.dump(config, f)
[docs]
def setup_checkpoints(output_path):
"""Set up model checkpoint callbacks."""
checkpoint_callback = ModelCheckpoint(
monitor="epoch_val_loss",
dirpath=output_path,
filename="model-kernel-{epoch:02d}-{epoch_val_loss:.2f}",
save_top_k=-1,
mode="min",
)
save_last_checkpoint = ModelCheckpoint(
dirpath=output_path,
filename="last",
)
return checkpoint_callback, save_last_checkpoint
[docs]
def get_checkpoint_path(output_path):
"""Determine the checkpoint path to resume training if available."""
checkpoint_pattern = os.path.join(output_path, "last.ckpt")
checkpoint_files = glob.glob(checkpoint_pattern)
return "last" if checkpoint_files else None
[docs]
def setup_data_module(cfg, effective_batch_size=None):
"""Set up the data module for VAE training."""
if cfg.dataloader.type == "h5":
dataloader_config = cfg.dataloader.h5
dataset = instantiate(dataloader_config)
dataset.setup(stage="fit")
elif cfg.dataloader.type == "tar":
from starling.data.VAE_loader_tar import VAEdataloader
dataset = VAEdataloader(
config=cfg.dataloader.tar, effective_batch_size=effective_batch_size
)
dataset.setup(stage="fit")
else:
raise ValueError(f"Unsupported dataloader type: {cfg.dataloader.type}")
return dataset
[docs]
def setup_vae_model(cfg):
"""Set up the VAE model, with support for resuming from checkpoint or fine-tuning with custom args."""
model_path = cfg.trainer.get("checkpoint", None)
if cfg.trainer.get("fine_tune", False) and model_path:
print(f"Fine-tuning VAE from checkpoint: {model_path}")
# First instantiate with custom arguments
vae = instantiate(cfg.vae_model)
# Load state dict from checkpoint
checkpoint = torch.load(model_path, map_location="cpu")
state_dict = checkpoint["state_dict"]
# Load only the weights, ignoring missing or extra keys
vae.load_state_dict(state_dict, strict=True)
print("Loaded checkpoint weights with custom model configuration")
else:
vae = instantiate(cfg.vae_model)
return vae
@hydra.main(
version_base=None,
config_path=os.path.join(os.path.dirname(__file__), "../configs"),
config_name="vae_configs",
)
def train_vae(cfg: DictConfig):
"""Train a VAE model using the configuration specified by Hydra.
Supports:
- Training from scratch
- Resuming training from a checkpoint
- Fine-tuning from a checkpoint
Args:
cfg: The configuration object loaded by Hydra
"""
# Setup directories and save config
output_path = cfg.trainer.output_path
os.makedirs(output_path, exist_ok=True)
# Save the config for reference
OmegaConf.save(cfg, f"{output_path}/config.yaml")
# Initialize WandB
wandb_init(cfg.trainer.project_name, id=cfg.trainer.get("wandb_id", None))
"""Set up model checkpoint callbacks."""
checkpoint_callback = ModelCheckpoint(
monitor="epoch_val_loss",
dirpath=output_path,
filename="model-kernel-{epoch:02d}-{epoch_val_loss:.2f}",
save_top_k=1,
mode="min",
)
save_last_checkpoint = ModelCheckpoint(
dirpath=output_path,
filename="last",
)
# Determine if we should resume from checkpoint
ckpt_path = cfg.trainer.get("checkpoint", None)
if not cfg.trainer.get("fine_tune", False) and ckpt_path is None:
# Try to find the last checkpoint if not explicitly provided
last_checkpoint = os.path.join(output_path, "last.ckpt")
if os.path.exists(last_checkpoint):
ckpt_path = last_checkpoint
print(f"Resuming from last checkpoint: {ckpt_path}")
# Setup data module
if cfg.dataloader.type == "tar":
effective_batch_size = (
cfg.trainer.cuda * cfg.trainer.num_nodes * cfg.dataloader.tar.batch_size
)
else:
effective_batch_size = None
dataset = setup_data_module(cfg, effective_batch_size=effective_batch_size)
# Setup VAE model
vae = setup_vae_model(cfg)
# Save model architecture
with open(f"{output_path}/model_architecture.txt", "w") as f:
f.write(str(vae))
# Setup logger
wandb_logger = WandbLogger(project=cfg.trainer.project_name)
wandb_logger.watch(vae)
# Setup trainer with all callbacks
trainer = pl.Trainer(
accelerator="auto",
devices=cfg.trainer.cuda,
num_nodes=cfg.trainer.num_nodes,
max_epochs=cfg.trainer.num_epochs,
callbacks=[
checkpoint_callback,
save_last_checkpoint,
LearningRateMonitor(logging_interval="step"),
],
precision=cfg.trainer.precision,
logger=wandb_logger,
)
# Start training (don't pass checkpoint path when fine-tuning)
trainer.fit(
vae,
dataset,
ckpt_path=None if cfg.trainer.get("fine_tune", False) else ckpt_path,
)
# Detach WandB logging
wandb_logger.experiment.unwatch(vae)
wandb.finish()
if __name__ == "__main__":
train_vae()