"""
-----------------
Copyright © 2023 ACK Cyfronet AGH, Poland.
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
-----------------

This script runs the pipeline for the training and evaluation of the models.
"""
import glob
import json
import logging
import os
import sys

import numpy as np
import pandas as pd
import wandb
import yaml
from dotenv import load_dotenv

from config_loader import models_path, configs_path, config_path

load_dotenv()
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

# Tags embedded in checkpoint filenames, e.g.
# "...run=<id>-epoch=<n>-val_loss=<float>.ckpt" — TODO confirm exact pattern
# against the training callback that writes the checkpoints.
_VAL_LOSS_TAG = "val_loss="
_RUN_TAG = "run="


def load_best_model_data(sweep_id, weights):
    """
    Determines the model with the lowest validation loss.

    If sweep_id is not provided the best model is determined based on the
    validation loss in checkpoint filenames in weights directory.

    :param sweep_id: W&B sweep id, or None to fall back to filename scanning
    :param weights: subdirectory of models_path to scan when sweep_id is None
    :return: tuple (best_checkpoint_path, run_id)
    :raises ValueError: when a unique best checkpoint cannot be determined
    """
    if sweep_id is not None:
        wandb_project_name = os.environ.get("WANDB_PROJECT")
        wandb_user = os.environ.get("WANDB_USER")
        api = wandb.Api()
        sweep = api.sweep(f"{wandb_user}/{wandb_project_name}/{sweep_id}")

        # Get best run parameters (W&B ranks runs by the sweep's metric).
        best_run = sweep.best_run()
        run_id = best_run.id
        run = api.run(f"{wandb_user}/{wandb_project_name}/runs/{run_id}")
        dataset = run.config["dataset_name"]
        # NOTE(review): model_name is indexed with [0], so it is presumably
        # stored as a list in the run config — confirm against the sweep setup.
        model = run.config["model_name"][0]
        experiment = f"{dataset}_{model}"

        checkpoints_path = f"{models_path}/{experiment}/*run={run_id}*ckpt"
        logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
        matching_models = glob.glob(checkpoints_path)
        if len(matching_models) != 1:
            raise ValueError("Unable to determine the best checkpoint for run_id: "
                             + run_id)
        best_checkpoint_path = matching_models[0]
    else:
        checkpoints_path = f"{models_path}/{weights}/*ckpt"
        logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
        checkpoints = glob.glob(checkpoints_path)
        # Fail with a clear message instead of letting np.argmin crash
        # on an empty sequence.
        if not checkpoints:
            raise ValueError(f"No checkpoints found in: {checkpoints_path}")
        val_losses = []
        for ckpt in checkpoints:
            i = ckpt.index(_VAL_LOSS_TAG)
            # Strip the "val_loss=" tag and the trailing ".ckpt" (5 chars)
            # to recover the numeric validation loss from the filename.
            val_losses.append(float(ckpt[i + len(_VAL_LOSS_TAG):-5]))
        best_checkpoint_path = checkpoints[np.argmin(val_losses)]

    # Extract the run id embedded between "run=" and "-epoch=" in the filename.
    run_id_st = best_checkpoint_path.index(_RUN_TAG) + len(_RUN_TAG)
    run_id_end = best_checkpoint_path.index("-epoch=")
    run_id = best_checkpoint_path[run_id_st:run_id_end]

    return best_checkpoint_path, run_id


# Number of dataloader workers used when a run config does not specify one.
default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", None)
if default_workers is None:
    logging.warning(
        "BENCHMARK_DEFAULT_WORKERS not set. "
        "Will use 12 workers if not specified otherwise in configuration."
    )
    default_workers = 12
else:
    default_workers = int(default_workers)


def load_sweep_config(sweep_fname):
    """
    Loads sweep config from yaml file

    :param sweep_fname: sweep yaml file, expected to be in configs_path
    :return: Dictionary containing sweep config
    """
    sweep_config_path = f"{configs_path}/{sweep_fname}"
    try:
        with open(sweep_config_path, "r") as file:
            # safe_load is the recommended loader for trusted plain-YAML
            # config files; it rejects arbitrary Python object tags.
            sweep_config = yaml.safe_load(file)
    except FileNotFoundError:
        logging.error(f"Could not find sweep config file: {sweep_fname}. "
                      f"Please make sure the file exists and is in {configs_path} directory.")
        sys.exit(1)
    return sweep_config


def log_metrics(results_file):
    """
    Pushes per-run MAE metrics from a results csv into the matching
    W&B run summaries.

    :param results_file: csv file with calculated metrics; must contain a
        "version" column holding W&B run ids, plus metric columns
    :return: None
    """
    api = wandb.Api()
    wandb_project_name = os.environ.get("WANDB_PROJECT")
    wandb_user = os.environ.get("WANDB_USER")

    results = pd.read_csv(results_file)
    for run_id in results["version"].unique():
        try:
            run = api.run(f"{wandb_user}/{wandb_project_name}/{run_id}")
            metrics_to_log = {}
            run_results = results[results["version"] == run_id]
            # Only columns whose name mentions 'mae' are treated as metrics.
            for col in run_results.columns:
                if 'mae' in col:
                    metrics_to_log[col] = run_results[col].values[0]
                    run.summary[col] = run_results[col].values[0]
            run.summary.update()
            logging.info(f"Logged metrics for run: {run_id}, {metrics_to_log}")
        except Exception as e:
            # Best-effort per run: a failure for one run id must not stop
            # logging for the remaining runs. Use logging (not print) for
            # consistency with the rest of this script.
            logging.error(f"An error occurred: {e}, {type(e).__name__}, {e.args}")


def set_dataset(dataset_name):
    """
    Sets the dataset name in the config file

    :param dataset_name:
    :return:
    """
    with open(config_path, "r+") as f:
        config = json.load(f)
        config["dataset_name"] = dataset_name
        config["data_path"] = f"datasets/{dataset_name}/seisbench_format/"
        f.seek(0)  # rewind
        json.dump(config, f, indent=4)
        # Remove any leftover bytes if the new JSON is shorter than the old.
        f.truncate()