93 lines
2.7 KiB
Python
93 lines
2.7 KiB
Python
"""
|
|
-----------------
|
|
Copyright © 2023 ACK Cyfronet AGH, Poland.
|
|
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
|
|
-----------------
|
|
|
|
This script runs the pipeline for the training and evaluation of the models.
|
|
"""
|
|
import logging
|
|
import time
|
|
import argparse
|
|
|
|
import util
|
|
import generate_eval_targets
|
|
import hyperparameter_sweep
|
|
import eval
|
|
import collect_results
|
|
from config_loader import data_path, targets_path, sampling_rate, dataset_name, sweep_files
|
|
|
|
# Named logger used by every function in this script. Its threshold is INFO so
# that INFO-level messages are not filtered out by the logger itself.
_PIPELINE_LOGGER_NAME = "pipeline"
logger = logging.getLogger(_PIPELINE_LOGGER_NAME)
logger.setLevel(logging.INFO)
|
|
|
|
|
|
def load_sweep_config(model_name, args):
    """Resolve and load the sweep configuration for ``model_name``.

    A config file given on the command line takes precedence
    (``--phasenet_config`` for "PhaseNet", ``--gpd_config`` for "GPD").
    Otherwise the model's default entry in ``sweep_files`` is used.

    Args:
        model_name: Name of the model, e.g. "PhaseNet" or "GPD".
        args: Parsed command-line arguments. The override attribute is read
            only for the model it belongs to.

    Returns:
        Whatever ``util.load_sweep_config`` returns for the chosen file.

    Raises:
        KeyError: if no override applies and ``model_name`` is not a key of
            ``sweep_files``.
    """
    # Map each model that supports a CLI override to the args attribute holding it.
    override_attr_by_model = {
        "PhaseNet": "phasenet_config",
        "GPD": "gpd_config",
    }
    attr_name = override_attr_by_model.get(model_name)
    chosen = getattr(args, attr_name) if attr_name is not None else None

    if chosen is None:
        # No command-line override: fall back to the model's default sweep file.
        chosen = sweep_files[model_name]

    logger.info(f"Loading sweep config: {chosen}")
    return util.load_sweep_config(chosen)
|
|
|
|
|
|
def find_the_best_params(model_name, args, *, poll_interval=30):
    """Run a hyperparameter sweep for ``model_name`` and block until it finishes.

    Args:
        model_name: Model whose sweep should be run. It is passed to
            ``load_sweep_config`` to pick the sweep config file.
        args: Parsed command-line arguments, forwarded to
            ``load_sweep_config`` for optional config-file overrides.
        poll_interval: Seconds to sleep between checks of whether all sweep
            runs have finished. Defaults to 30, the previously hard-coded value.

    Returns:
        The ``sweep_id`` of the runner returned by
        ``hyperparameter_sweep.start_sweep``.
    """
    logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")

    sweep_config = load_sweep_config(model_name, args)
    sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)

    # Poll until every run of the sweep reports completion.
    # NOTE(review): there is no overall timeout, so a run that never finishes
    # blocks the pipeline indefinitely.
    while not sweep_runner.all_runs_finished():
        logger.info("Waiting for sweep runs to finish...")
        time.sleep(poll_interval)

    logger.info(f"Finished the sweep: {sweep_runner.sweep_id}")
    return sweep_runner.sweep_id
|
|
|
|
|
|
def generate_predictions(sweep_id, model_name, *, sets="dev,test", batchsize=128, num_workers=4):
    """Run ``eval.main`` for ``model_name`` using the results of ``sweep_id``.

    The weights/experiment name passed to ``eval.main`` is
    ``f"{dataset_name}_{model_name}"``, and the targets come from ``targets_path``.

    Args:
        sweep_id: Identifier of the finished sweep, forwarded as ``sweep_id``.
        model_name: Model name, used to build the experiment name.
        sets: Comma-separated dataset splits to evaluate (default "dev,test").
        batchsize: Batch size forwarded to ``eval.main`` (default 128).
        num_workers: Worker count forwarded to ``eval.main`` (default 4).
    """
    # NOTE: ``eval`` here is the project module imported at file top, which
    # shadows the ``eval`` builtin within this file.
    experiment_name = f"{dataset_name}_{model_name}"
    # NOTE(review): ``sampling_rate`` from config_loader is not forwarded here,
    # so eval.main presumably uses its own default. Confirm this is intended.
    eval.main(
        weights=experiment_name,
        targets=targets_path,
        sets=sets,
        batchsize=batchsize,
        num_workers=num_workers,
        sweep_id=sweep_id,
    )
|
|
|
|
|
|
def main():
    """Entry point: run the full training/evaluation pipeline.

    Steps:
      1. Call ``generate_eval_targets.main`` on ``data_path`` / ``targets_path``.
      2. For each model ("GPD", then "PhaseNet"), run a hyperparameter sweep and
         generate predictions from it.
      3. Call ``collect_results.traverse_path("pred", "pred/results.csv")`` to
         gather the results (presumably into that CSV file).
    """
    # The "pipeline" logger has a level but no handler. Without any configured
    # handler, logging falls back to its last-resort handler (WARNING and above),
    # so the pipeline's INFO progress messages would be dropped. basicConfig()
    # is a no-op if the root logger already has handlers. No level is passed,
    # so other libraries' verbosity is unchanged.
    logging.basicConfig()

    parser = argparse.ArgumentParser(
        description="Run the training and evaluation pipeline for the models."
    )
    parser.add_argument(
        "--phasenet_config",
        type=str,
        required=False,
        help="Sweep config file for PhaseNet; overrides the default from sweep_files.",
    )
    parser.add_argument(
        "--gpd_config",
        type=str,
        required=False,
        help="Sweep config file for GPD; overrides the default from sweep_files.",
    )
    args = parser.parse_args()

    # Generate labels.
    # NOTE(review): the meaning of "2,3" and the trailing None is defined by
    # generate_eval_targets.main. Confirm there before changing them.
    generate_eval_targets.main(data_path, targets_path, "2,3", sampling_rate, None)

    # Find the best hyperparams for each model, then predict with the result.
    for model_name in ("GPD", "PhaseNet"):
        sweep_id = find_the_best_params(model_name, args)
        generate_predictions(sweep_id, model_name)

    # Collect results.
    collect_results.traverse_path("pred", "pred/results.csv")
|
|
|
|
|
|
# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|