# platform-demo-scripts/scripts/util.py
"""
This script offers general functionality required in multiple places.
"""
import numpy as np
import pandas as pd
import os
import logging
import glob
import wandb
from dotenv import load_dotenv
import sys
from config_loader import models_path, configs_path
import yaml
load_dotenv()
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
def load_best_model_data(sweep_id, weights):
    """
    Determine the checkpoint of the model with the lowest validation loss.

    If ``sweep_id`` is provided, the best run is queried from the wandb API and
    the single checkpoint matching that run id is returned. Otherwise the best
    checkpoint is determined by parsing the ``val_loss=`` value embedded in the
    checkpoint filenames found in the weights directory.

    :param sweep_id: wandb sweep id, or None to fall back to filename parsing
    :param weights: sub-directory of ``models_path`` containing the checkpoints
    :return: tuple ``(best_checkpoint_path, run_id)``
    :raises ValueError: if no checkpoint (or more than one candidate) is found
    """
    if sweep_id is not None:
        wandb_project_name = os.environ.get("WANDB_PROJECT")
        wandb_user = os.environ.get("WANDB_USER")
        api = wandb.Api()
        sweep = api.sweep(f"{wandb_user}/{wandb_project_name}/{sweep_id}")
        # Get best run parameters
        best_run = sweep.best_run()
        run_id = best_run.id
        matching_models = glob.glob(f"{models_path}/{weights}/*run={run_id}*ckpt")
        if len(matching_models) != 1:
            raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id)
        best_checkpoint_path = matching_models[0]
    else:
        checkpoints_path = f"{models_path}/{weights}/*ckpt"
        logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
        checkpoints = glob.glob(checkpoints_path)
        if not checkpoints:
            # np.argmin on an empty sequence raises a cryptic error; fail clearly instead.
            raise ValueError(f"No checkpoints found matching: {checkpoints_path}")
        val_losses = []
        for ckpt in checkpoints:
            # Filenames are expected to embed '...val_loss=<float>.ckpt';
            # slice out the float between the marker and the extension.
            i = ckpt.index("val_loss=") + len("val_loss=")
            val_losses.append(float(ckpt[i:-len(".ckpt")]))
        best_checkpoint_path = checkpoints[np.argmin(val_losses)]
        # Filenames also embed 'run=<id>-epoch=...'; recover the run id from there.
        run_id_st = best_checkpoint_path.index("run=") + len("run=")
        run_id_end = best_checkpoint_path.index("-epoch=")
        run_id = best_checkpoint_path[run_id_st:run_id_end]
    return best_checkpoint_path, run_id
def load_best_model(model_cls, weights, version):
    """
    Determines the model with lowest validation loss from the csv logs and loads it
    :param model_cls: Class of the lightning module to load
    :param weights: Path to weights as in cmd arguments
    :param version: String of version file
    :return: Instance of lightning module that was loaded from the best checkpoint
    """
    log_dir = weights / version
    metrics = pd.read_csv(log_dir / "metrics.csv")
    # nanargmin skips NaN rows (lightning writes partial rows per logging step)
    best_row = metrics.iloc[np.nanargmin(metrics["val_loss"])]
    # For default checkpoint filename, see https://github.com/Lightning-AI/lightning/pull/11805
    # and https://github.com/Lightning-AI/lightning/issues/16636.
    # For example, 'epoch=0-step=1.ckpt' means the 1st step has finish, but the 1st epoch hasn't
    checkpoint = f"epoch={best_row['epoch']:.0f}-step={best_row['step']+1:.0f}.ckpt"
    # For default save path of checkpoints, see https://github.com/Lightning-AI/lightning/pull/12372
    checkpoint_path = log_dir / "checkpoints" / checkpoint
    return model_cls.load_from_checkpoint(checkpoint_path)
# Number of dataloader workers used when a run configuration does not set one.
# Overridable through the BENCHMARK_DEFAULT_WORKERS environment variable.
_workers_env = os.getenv("BENCHMARK_DEFAULT_WORKERS", None)
if _workers_env is None:
    logging.warning(
        "BENCHMARK_DEFAULT_WORKERS not set. "
        "Will use 12 workers if not specified otherwise in configuration."
    )
    default_workers = 12
else:
    default_workers = int(_workers_env)
def load_sweep_config(sweep_fname):
    """
    Loads a wandb sweep config from a yaml file.

    :param sweep_fname: sweep yaml file name, expected to be in configs_path
    :return: Dictionary containing the sweep config
    :raises SystemExit: exits with status 1 if the file does not exist
    """
    sweep_config_path = f"{configs_path}/{sweep_fname}"
    try:
        with open(sweep_config_path, "r") as file:
            # safe_load builds only plain python objects, which is all a sweep
            # config needs, and avoids constructing arbitrary tagged objects.
            sweep_config = yaml.safe_load(file)
    except FileNotFoundError:
        logging.error(f"Could not find sweep config file: {sweep_fname}. "
                      f"Please make sure the file exists and is in {configs_path} directory.")
        sys.exit(1)
    return sweep_config