""" ----------------- Copyright © 2023 ACK Cyfronet AGH, Poland. This work was partially funded by EPOS Project funded in frame of PL-POIR4.2 ----------------- This script runs the pipeline for the training and evaluation of the models. """ import logging import time import argparse import util import generate_eval_targets import hyperparameter_sweep import eval import collect_results from config_loader import data_path, targets_path, sampling_rate, dataset_name, sweep_files logging.root.setLevel(logging.INFO) logger = logging.getLogger('pipeline') def load_sweep_config(model_name, args): if model_name == "PhaseNet" and args.phasenet_config is not None: sweep_fname = args.phasenet_config elif model_name == "GPD" and args.gpd_config is not None: sweep_fname = args.gpd_config else: # use the default sweep config for the model sweep_fname = sweep_files[model_name] logger.info(f"Loading sweep config: {sweep_fname}") return util.load_sweep_config(sweep_fname) def find_the_best_params(model_name, args): # find the best hyperparams for the model_name logger.info(f"Starting searching for the best hyperparams for the model: {model_name}") sweep_config = load_sweep_config(model_name, args) sweep_runner = hyperparameter_sweep.start_sweep(sweep_config) # wait for all runs to finish all_finished = sweep_runner.all_runs_finished() while not all_finished: logger.info("Waiting for sweep runs to finish...") # Sleep for a few seconds before checking again time.sleep(30) all_finished = sweep_runner.all_runs_finished() logger.info(f"Finished the sweep: {sweep_runner.sweep_id}") return sweep_runner.sweep_id def generate_predictions(sweep_id, model_name): experiment_name = f"{dataset_name}_{model_name}" eval.main(weights=experiment_name, targets=targets_path, sets='dev,test', batchsize=128, num_workers=4, # sampling_rate=sampling_rate, sweep_id=sweep_id ) def main(): parser = argparse.ArgumentParser() parser.add_argument("--phasenet_config", type=str, required=False) parser.add_argument("--gpd_config", type=str, required=False) args = parser.parse_args() # generate labels logger.info("Started generating labels for the dataset.") generate_eval_targets.main(data_path, targets_path, "2,3", sampling_rate, None) # find the best hyperparams for the models logger.info("Started training the models.") for model_name in ["GPD", "PhaseNet"]: sweep_id = find_the_best_params(model_name, args) generate_predictions(sweep_id, model_name) # collect results logger.info("Collecting results.") collect_results.traverse_path("pred", "pred/results.csv") logger.info("Results saved in pred/results.csv") if __name__ == "__main__": main()