"""
This script offers general functionality required in multiple places.
"""

import glob
import logging
import os
import sys

import numpy as np
import pandas as pd
import wandb
import yaml
from dotenv import load_dotenv

from config_loader import models_path, configs_path

load_dotenv()
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)


def _val_loss_from_path(ckpt):
    """
    Parses the validation loss that is embedded in a checkpoint path.

    :param ckpt: Checkpoint path containing ``val_loss=<float>`` directly
        followed by the ``.ckpt`` suffix
    :return: The validation loss as float
    :raises ValueError: If the path has no ``val_loss=`` token or the value
        is not a valid float
    """
    token = "val_loss="
    start = ckpt.find(token)
    if start == -1:
        raise ValueError(f"Checkpoint path does not contain '{token}': {ckpt}")
    # The value spans from the end of the token up to the trailing ".ckpt".
    return float(ckpt[start + len(token):-len(".ckpt")])


def load_best_model_data(sweep_id, weights):
    """
    Determines the checkpoint of the best model and the id of the run that
    produced it.

    If sweep_id is provided, the best run is the one reported by the W&B sweep
    (WANDB_USER and WANDB_PROJECT are read from the environment) and exactly one
    checkpoint of that run must exist in the weights directory.
    If sweep_id is not provided, the best model is determined based on the
    validation loss in checkpoint filenames in the weights directory.

    :param sweep_id: Id of the W&B sweep, or None to select by filename
    :param weights: Name of the directory below models_path that holds the
        checkpoints
    :return: Tuple (best_checkpoint_path, run_id); run_id is parsed from the
        ``run=<id>-epoch=`` part of the checkpoint path
    :raises ValueError: If no unique/best checkpoint can be determined or the
        checkpoint path does not follow the expected naming scheme
    """
    if sweep_id is not None:
        wandb_project_name = os.environ.get("WANDB_PROJECT")
        wandb_user = os.environ.get("WANDB_USER")
        api = wandb.Api()
        sweep = api.sweep(f"{wandb_user}/{wandb_project_name}/{sweep_id}")

        # Get best run parameters
        best_run = sweep.best_run()
        # NOTE(review): best_run() presumably yields None for a sweep without
        # runs; fail with a clear message instead of an AttributeError.
        if best_run is None:
            raise ValueError(f"Sweep {sweep_id} does not contain any runs")
        run_id = best_run.id

        matching_models = glob.glob(f"{models_path}/{weights}/*run={run_id}*ckpt")
        if len(matching_models) != 1:
            raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id)
        best_checkpoint_path = matching_models[0]
    else:
        checkpoints_path = f"{models_path}/{weights}/*ckpt"
        logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
        checkpoints = glob.glob(checkpoints_path)
        if not checkpoints:
            # np.argmin on an empty list would raise an opaque numpy error.
            raise ValueError(f"No checkpoints found matching: {checkpoints_path}")

        val_losses = [_val_loss_from_path(ckpt) for ckpt in checkpoints]
        best_checkpoint_path = checkpoints[int(np.argmin(val_losses))]

    # The run id sits between "run=" and "-epoch=" in the checkpoint path.
    run_id_st = best_checkpoint_path.index("run=") + len("run=")
    run_id_end = best_checkpoint_path.index("-epoch=")
    run_id = best_checkpoint_path[run_id_st:run_id_end]

    return best_checkpoint_path, run_id


def load_best_model(model_cls, weights, version):
    """
    Determines the model with lowest validation loss from the csv logs and loads it

    :param model_cls: Class of the lightning module to load
    :param weights: Path to weights as in cmd arguments (must support the ``/``
        operator, e.g. a pathlib.Path)
    :param version: String of version file
    :return: Instance of lightning module that was loaded from the best checkpoint
    """
    metrics = pd.read_csv(weights / version / "metrics.csv")

    # Rows without a validation loss are NaN and are skipped by nanargmin.
    idx = np.nanargmin(metrics["val_loss"])
    min_row = metrics.iloc[idx]

    # For default checkpoint filename, see https://github.com/Lightning-AI/lightning/pull/11805
    # and https://github.com/Lightning-AI/lightning/issues/16636.
    # For example, 'epoch=0-step=1.ckpt' means the 1st step has finished, but the 1st epoch hasn't
    checkpoint = f"epoch={min_row['epoch']:.0f}-step={min_row['step']+1:.0f}.ckpt"

    # For default save path of checkpoints, see https://github.com/Lightning-AI/lightning/pull/12372
    checkpoint_path = weights / version / "checkpoints" / checkpoint

    return model_cls.load_from_checkpoint(checkpoint_path)


# Default number of workers, taken from the environment with a fallback of 12.
default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", None)
if default_workers is None:
    logging.warning(
        "BENCHMARK_DEFAULT_WORKERS not set. "
        "Will use 12 workers if not specified otherwise in configuration."
    )
    default_workers = 12
else:
    default_workers = int(default_workers)


def load_sweep_config(sweep_fname):
    """
    Loads sweep config from yaml file

    Logs an error and terminates the process with exit code 1 if the file
    does not exist.

    :param sweep_fname: sweep yaml file, expected to be in configs_path
    :return: Dictionary containing sweep config
    """
    sweep_config_path = f"{configs_path}/{sweep_fname}"

    try:
        with open(sweep_config_path, "r") as file:
            # NOTE(review): FullLoader can construct more than plain data types.
            # Sweep configs are expected to be trusted project files; consider
            # yaml.safe_load if they may ever come from an untrusted source.
            sweep_config = yaml.load(file, Loader=yaml.FullLoader)
    except FileNotFoundError:
        logging.error(f"Could not find sweep config file: {sweep_fname}. "
                      f"Please make sure the file exists and is in {configs_path} directory.")
        sys.exit(1)

    return sweep_config