platform-demo-scripts/scripts/util.py

158 lines
4.7 KiB
Python
Raw Normal View History

"""
-----------------
Copyright © 2023 ACK Cyfronet AGH, Poland.
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
-----------------
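This module provides helper utilities for the pipeline that trains and evaluates the models: selecting the best checkpoint, loading sweep configs, logging metrics to Weights & Biases, and switching the configured dataset.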
"""
import numpy as np
import pandas as pd
import os
import logging
import glob
import json
import wandb
from dotenv import load_dotenv
import sys
from config_loader import models_path, configs_path, config_path
import yaml
# Load variables from a .env file into the process environment (python-dotenv),
# so values this module reads via os.environ / os.getenv (WANDB_PROJECT,
# WANDB_USER, BENCHMARK_DEFAULT_WORKERS) can be provided there.
load_dotenv()
# Configure a default root logging handler and let INFO-level messages through.
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
def load_best_model_data(sweep_id, weights):
    """
    Determine the checkpoint of the model with the lowest validation loss.

    Two lookup strategies are supported:

    * ``sweep_id`` given: the best run of the W&B sweep is fetched through the
      W&B public API (entity and project taken from the ``WANDB_USER`` and
      ``WANDB_PROJECT`` environment variables). The single ``*ckpt`` file whose
      name contains ``run=<run_id>`` is then looked up in
      ``{models_path}/{dataset_name}_{model_name[0]}/``, using the run's config.
    * ``sweep_id`` is None: every ``*ckpt`` file in ``{models_path}/{weights}/``
      is inspected, and the one with the smallest ``val_loss=<value>`` encoded
      in its file name wins. The run id is parsed from the text between
      ``run=`` and ``-epoch=`` in that file name.

    :param sweep_id: W&B sweep id, or None to pick the best checkpoint by file name
    :param weights: sub-directory of ``models_path`` holding the checkpoints
        (only used when ``sweep_id`` is None)
    :return: tuple ``(best_checkpoint_path, run_id)``
    :raises ValueError: if the sweep's best run does not match exactly one
        checkpoint, if no checkpoints are found in the weights directory, or if
        a checkpoint file name lacks the ``val_loss=`` / ``run=`` / ``-epoch=``
        markers or has an unparsable loss value
    """
    if sweep_id is not None:
        wandb_project_name = os.environ.get("WANDB_PROJECT")
        wandb_user = os.environ.get("WANDB_USER")
        api = wandb.Api()
        sweep = api.sweep(f"{wandb_user}/{wandb_project_name}/{sweep_id}")
        # Get best run parameters
        best_run = sweep.best_run()
        run_id = best_run.id
        run = api.run(f"{wandb_user}/{wandb_project_name}/runs/{run_id}")
        dataset = run.config["dataset_name"]
        model = run.config["model_name"][0]
        experiment = f"{dataset}_{model}"
        checkpoints_path = f"{models_path}/{experiment}/*run={run_id}*ckpt"
        logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
        matching_models = glob.glob(checkpoints_path)
        if len(matching_models) != 1:
            raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id)
        best_checkpoint_path = matching_models[0]
    else:
        checkpoints_path = f"{models_path}/{weights}/*ckpt"
        logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
        checkpoints = glob.glob(checkpoints_path)
        if not checkpoints:
            # np.argmin on an empty list would fail with an unhelpful message.
            raise ValueError(f"No checkpoints found matching: {checkpoints_path}")
        val_losses = []
        for ckpt in checkpoints:
            # Parse only the file name so directory names cannot be mistaken
            # for the markers; the trailing 5 characters are the "ckpt" suffix
            # plus its separator, as in the original slicing.
            name = os.path.basename(ckpt)
            start = name.index("val_loss=") + len("val_loss=")
            val_losses.append(float(name[start:-5]))
        best_checkpoint_path = checkpoints[int(np.argmin(val_losses))]
        best_name = os.path.basename(best_checkpoint_path)
        run_id_st = best_name.index("run=") + len("run=")
        run_id_end = best_name.index("-epoch=")
        run_id = best_name[run_id_st:run_id_end]
    return best_checkpoint_path, run_id
# Default worker count, used when the configuration does not specify one.
# Taken from the BENCHMARK_DEFAULT_WORKERS environment variable; falls back to 12.
default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", "").strip()
if not default_workers:
    # An unset variable and an empty one (e.g. "BENCHMARK_DEFAULT_WORKERS=" in
    # .env) are treated alike instead of crashing on int("").
    logging.warning(
        "BENCHMARK_DEFAULT_WORKERS not set. "
        "Will use 12 workers if not specified otherwise in configuration."
    )
    default_workers = 12
else:
    try:
        default_workers = int(default_workers)
    except ValueError as err:
        # Name the offending variable rather than surfacing a bare int() error.
        raise ValueError(
            f"BENCHMARK_DEFAULT_WORKERS must be an integer, got {default_workers!r}"
        ) from err
def load_sweep_config(sweep_fname):
    """
    Read a W&B sweep definition from a YAML file stored in ``configs_path``.

    If the file does not exist, an error is logged and the process exits with
    status 1.

    :param sweep_fname: file name of the sweep YAML, relative to ``configs_path``
    :return: the parsed YAML document (expected to be the sweep config dictionary)
    """
    yaml_path = f"{configs_path}/{sweep_fname}"
    try:
        with open(yaml_path, "r") as stream:
            # NOTE(review): FullLoader is acceptable for local, trusted config
            # files; prefer yaml.safe_load if these ever come from outside.
            return yaml.load(stream, Loader=yaml.FullLoader)
    except FileNotFoundError:
        logging.error(f"Could not find sweep config file: {sweep_fname}. "
                      f"Please make sure the file exists and is in {configs_path} directory.")
        sys.exit(1)
def log_metrics(results_file):
    """
    Copy MAE metrics from a results CSV into the summaries of their W&B runs.

    For each distinct value of the CSV's ``version`` column (used as the W&B
    run id), every column whose name contains ``"mae"`` is written to that
    run's summary, taking the value from the run's first row. The summary is
    then saved. The W&B entity and project are read from the ``WANDB_USER``
    and ``WANDB_PROJECT`` environment variables.

    Updating is best-effort: a failure for one run is logged with its traceback
    and the remaining runs are still processed.

    :param results_file: path to the CSV file with calculated metrics
    :return: None
    """
    api = wandb.Api()
    wandb_project_name = os.environ.get("WANDB_PROJECT")
    wandb_user = os.environ.get("WANDB_USER")
    results = pd.read_csv(results_file)
    # The set of metric columns is the same for every run; compute it once.
    mae_columns = [col for col in results.columns if "mae" in col]
    for run_id in results["version"].unique():
        try:
            run = api.run(f"{wandb_user}/{wandb_project_name}/{run_id}")
            run_results = results[results["version"] == run_id]
            metrics_to_log = {col: run_results[col].values[0] for col in mae_columns}
            for col, value in metrics_to_log.items():
                run.summary[col] = value
            run.summary.update()
            logging.info(f"Logged metrics for run: {run_id}, {metrics_to_log}")
        except Exception:
            # Deliberate best-effort boundary: report which run failed (with
            # traceback) through logging instead of printing, then continue.
            logging.exception(f"Failed to log metrics for run: {run_id}")
def set_dataset(dataset_name):
    """
    Switch the JSON config stored at ``config_path`` to another dataset.

    The ``dataset_name`` and ``data_path`` entries are overwritten, with the
    latter set to ``datasets/<dataset_name>/seisbench_format/``. All other
    entries are kept. The file is rewritten in place with 4-space indentation.

    :param dataset_name: dataset name, also used to build the data path
    :return: None
    """
    with open(config_path, "r+") as cfg_file:
        settings = json.load(cfg_file)
        settings["dataset_name"], settings["data_path"] = (
            dataset_name,
            f"datasets/{dataset_name}/seisbench_format/",
        )
        # Rewrite from the beginning, then cut off any leftover bytes of the
        # previous (possibly longer) content.
        cfg_file.seek(0)
        json.dump(settings, cfg_file, indent=4)
        cfg_file.truncate()