"""
-----------------
Copyright © 2023 ACK Cyfronet AGH, Poland.

This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
-----------------

This script runs the pipeline for the training and evaluation of the models.
"""
import glob
import json
import logging
import os
import sys

import numpy as np
import pandas as pd
import wandb
import yaml
from dotenv import load_dotenv

from config_loader import models_path, configs_path, config_path

load_dotenv()

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
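
# load_dotenv() above pulls credentials/config from a local .env file. A minimal
# example with hypothetical values (the variable names are the ones this script
# actually reads):
#
#   WANDB_USER=my-user
#   WANDB_PROJECT=my-project
#   BENCHMARK_DEFAULT_WORKERS=8
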
def load_best_model_data(sweep_id, weights):
    """
    Determines the model checkpoint with the lowest validation loss.

    If sweep_id is provided, the best run is looked up in W&B; otherwise the
    best model is determined from the validation losses encoded in the
    checkpoint filenames in the weights directory.

    :param sweep_id: id of the W&B sweep to query, or None
    :param weights: name of the checkpoints directory inside models_path, used
        when sweep_id is None
    :return: tuple of (best_checkpoint_path, run_id)
    """
    if sweep_id is not None:
        wandb_project_name = os.environ.get("WANDB_PROJECT")
        wandb_user = os.environ.get("WANDB_USER")
        api = wandb.Api()
        sweep = api.sweep(f"{wandb_user}/{wandb_project_name}/{sweep_id}")

        # Get best run parameters
        best_run = sweep.best_run()
        run_id = best_run.id

        run = api.run(f"{wandb_user}/{wandb_project_name}/runs/{run_id}")
        dataset = run.config["dataset_name"]
        model = run.config["model_name"][0]
        experiment = f"{dataset}_{model}"

        checkpoints_path = f"{models_path}/{experiment}/*run={run_id}*ckpt"
        logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
        matching_models = glob.glob(checkpoints_path)
        if len(matching_models) != 1:
            raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id)
        best_checkpoint_path = matching_models[0]

    else:
        checkpoints_path = f"{models_path}/{weights}/*ckpt"
        logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")

        checkpoints = glob.glob(checkpoints_path)
        val_losses = []

        for ckpt in checkpoints:
            # Parse the float between "val_loss=" and the trailing ".ckpt"
            # (9 = len("val_loss="), -5 strips ".ckpt").
            i = ckpt.index("val_loss=")
            val_losses.append(float(ckpt[i + 9:-5]))

        best_checkpoint_path = checkpoints[np.argmin(val_losses)]
        # Recover the run id from the "run=<id>-epoch=" segment of the filename.
        run_id_st = best_checkpoint_path.index("run=") + 4
        run_id_end = best_checkpoint_path.index("-epoch=")
        run_id = best_checkpoint_path[run_id_st:run_id_end]

    return best_checkpoint_path, run_id
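
# Minimal usage sketch; the directory name below is a hypothetical value, not
# something defined in this script:
#
#   best_ckpt, best_run_id = load_best_model_data(sweep_id=None, weights="dataset_GPD")
#   logging.info(f"Best checkpoint: {best_ckpt} (run {best_run_id})")
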
default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", None)
if default_workers is None:
    logging.warning(
        "BENCHMARK_DEFAULT_WORKERS not set. "
        "Will use 12 workers if not specified otherwise in configuration."
    )
    default_workers = 12
else:
    default_workers = int(default_workers)

def load_sweep_config(sweep_fname):
    """
    Loads a sweep config from a yaml file.

    :param sweep_fname: sweep yaml filename, expected to be in configs_path
    :return: dictionary containing the sweep config
    """
    sweep_config_path = f"{configs_path}/{sweep_fname}"

    try:
        with open(sweep_config_path, "r") as file:
            sweep_config = yaml.load(file, Loader=yaml.FullLoader)
    except FileNotFoundError:
        logging.error(f"Could not find sweep config file: {sweep_fname}. "
                      f"Please make sure the file exists and is in {configs_path} directory.")
        sys.exit(1)

    return sweep_config
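
# A minimal sweep yaml this loader would accept, following the standard W&B
# sweep schema (filename and parameter values are hypothetical):
#
#   # configs_path/sweep_example.yaml
#   method: bayes
#   metric:
#     name: val_loss
#     goal: minimize
#   parameters:
#     learning_rate:
#       values: [0.01, 0.001]
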
def log_metrics(results_file):
    """
    Pushes the calculated metrics to the summary of the corresponding W&B runs.

    :param results_file: csv file with calculated metrics
    :return:
    """
    api = wandb.Api()
    wandb_project_name = os.environ.get("WANDB_PROJECT")
    wandb_user = os.environ.get("WANDB_USER")

    results = pd.read_csv(results_file)
    for run_id in results["version"].unique():
        try:
            run = api.run(f"{wandb_user}/{wandb_project_name}/{run_id}")
            metrics_to_log = {}
            run_results = results[results["version"] == run_id]
            for col in run_results.columns:
                if "mae" in col:
                    metrics_to_log[col] = run_results[col].values[0]
                    run.summary[col] = run_results[col].values[0]

            run.summary.update()
            logging.info(f"Logged metrics for run: {run_id}, {metrics_to_log}")

        except Exception as e:
            logging.error(f"An error occurred: {e}, {type(e).__name__}, {e.args}")
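
# Expected shape of the results csv: a "version" column holding W&B run ids and
# one or more metric columns whose names contain "mae" (the column names below
# are hypothetical):
#
#   version,p_mae,s_mae
#   abc123xy,0.12,0.25
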
def set_dataset(dataset_name):
    """
    Sets the dataset name and the matching data path in the json config file.

    :param dataset_name: name of the dataset directory under datasets/
    :return:
    """
    with open(config_path, "r+") as f:
        config = json.load(f)
        config["dataset_name"] = dataset_name
        config["data_path"] = f"datasets/{dataset_name}/seisbench_format/"

        f.seek(0)  # rewind to overwrite the file in place
        json.dump(config, f, indent=4)
        f.truncate()
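
# Hypothetical example: set_dataset("bogdanka") rewrites exactly two keys of the
# json config and leaves everything else untouched:
#
#   {"dataset_name": "bogdanka",
#    "data_path": "datasets/bogdanka/seisbench_format/", ...}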