120 lines
3.9 KiB
Python
120 lines
3.9 KiB
Python
"""
|
|
This script offers general functionality required in multiple places.
|
|
"""
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import os
|
|
import logging
|
|
import glob
|
|
import wandb
|
|
from dotenv import load_dotenv
|
|
import sys
|
|
from config_loader import models_path, configs_path
|
|
import yaml
|
|
load_dotenv()
|
|
|
|
|
|
logging.basicConfig()
|
|
logging.getLogger().setLevel(logging.INFO)
|
|
|
|
|
|
def load_best_model_data(sweep_id, weights):
    """
    Determine the checkpoint with the lowest validation loss.

    If ``sweep_id`` is provided, the best run is queried from the wandb API
    and the single matching checkpoint file is located in the weights
    directory. Otherwise the best model is determined from the
    ``val_loss=<float>.ckpt`` values embedded in the checkpoint filenames
    in the weights directory.

    :param sweep_id: wandb sweep id, or None to fall back to filename parsing
    :param weights: weights sub-directory (relative to models_path) to search
    :return: tuple of (best_checkpoint_path, run_id)
    :raises ValueError: if no checkpoint (or no unique checkpoint) can be found
    """
    if sweep_id is not None:
        # Resolve the best run of the sweep via the wandb API.
        wandb_project_name = os.environ.get("WANDB_PROJECT")
        wandb_user = os.environ.get("WANDB_USER")
        api = wandb.Api()
        sweep = api.sweep(f"{wandb_user}/{wandb_project_name}/{sweep_id}")

        # Get best run parameters
        best_run = sweep.best_run()
        run_id = best_run.id
        matching_models = glob.glob(f"{models_path}/{weights}/*run={run_id}*ckpt")
        if len(matching_models) != 1:
            raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id)
        best_checkpoint_path = matching_models[0]

    else:
        checkpoints_path = f"{models_path}/{weights}/*ckpt"
        logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")

        checkpoints = glob.glob(checkpoints_path)
        # Fail fast with a clear message instead of letting np.argmin crash
        # on an empty sequence.
        if not checkpoints:
            raise ValueError(f"No checkpoints found matching: {checkpoints_path}")

        # Filenames are expected to end with "...val_loss=<float>.ckpt";
        # parse the float between the marker and the extension.
        loss_marker = "val_loss="
        suffix_len = len(".ckpt")
        val_losses = []
        for ckpt in checkpoints:
            i = ckpt.index(loss_marker) + len(loss_marker)
            val_losses.append(float(ckpt[i:-suffix_len]))

        best_checkpoint_path = checkpoints[np.argmin(val_losses)]
        # The run id sits between "run=" and "-epoch=" in the filename.
        run_marker = "run="
        run_id_st = best_checkpoint_path.index(run_marker) + len(run_marker)
        run_id_end = best_checkpoint_path.index("-epoch=")
        run_id = best_checkpoint_path[run_id_st:run_id_end]

    return best_checkpoint_path, run_id
|
|
|
|
|
|
def load_best_model(model_cls, weights, version):
    """
    Determines the model with lowest validation loss from the csv logs and loads it

    :param model_cls: Class of the lightning module to load
    :param weights: Path to weights as in cmd arguments
    :param version: String of version file
    :return: Instance of lightning module that was loaded from the best checkpoint
    """
    log_dir = weights / version
    metrics = pd.read_csv(log_dir / "metrics.csv")

    # Row with the smallest (non-NaN) validation loss.
    best_row = metrics.iloc[np.nanargmin(metrics["val_loss"])]

    # For default checkpoint filename, see https://github.com/Lightning-AI/lightning/pull/11805
    # and https://github.com/Lightning-AI/lightning/issues/16636.
    # For example, 'epoch=0-step=1.ckpt' means the 1st step has finish, but the 1st epoch hasn't
    ckpt_name = f"epoch={best_row['epoch']:.0f}-step={best_row['step'] + 1:.0f}.ckpt"

    # For default save path of checkpoints, see https://github.com/Lightning-AI/lightning/pull/12372
    return model_cls.load_from_checkpoint(log_dir / "checkpoints" / ckpt_name)
|
|
|
|
|
|
# Number of dataloader workers to use when a configuration does not override it.
# Taken from the BENCHMARK_DEFAULT_WORKERS environment variable, falling back to 12.
default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", None)
if default_workers is not None:
    default_workers = int(default_workers)
else:
    logging.warning(
        "BENCHMARK_DEFAULT_WORKERS not set. "
        "Will use 12 workers if not specified otherwise in configuration."
    )
    default_workers = 12
|
|
|
|
|
|
def load_sweep_config(sweep_fname):
    """
    Loads sweep config from yaml file

    :param sweep_fname: sweep yaml file, expected to be in configs_path
    :return: Dictionary containing sweep config
    """
    sweep_config_path = f"{configs_path}/{sweep_fname}"

    try:
        # Keep only the line that can raise FileNotFoundError inside the try.
        file = open(sweep_config_path, "r")
    except FileNotFoundError:
        logging.error(f"Could not find sweep config file: {sweep_fname}. "
                      f"Please make sure the file exists and is in {configs_path} directory.")
        sys.exit(1)

    with file:
        return yaml.load(file, Loader=yaml.FullLoader)
|