"""
-----------------
Copyright © 2023 ACK Cyfronet AGH, Poland.
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
-----------------

This script runs the pipeline for the training and evaluation of the models.
"""

import logging
import time
import argparse
import util
import generate_eval_targets
import hyperparameter_sweep
import eval
import collect_results
import importlib
import config_loader
import input_validate
import data

logging.root.setLevel(logging.INFO)
logger = logging.getLogger('pipeline')

# Maps a model name to the name of the CLI attribute that may carry a
# user-supplied sweep config path for that model.
_SWEEP_OVERRIDE_ATTRS = {
    "PhaseNet": "phasenet_config",
    "GPD": "gpd_config",
    "BasicPhaseAE": "basic_phase_ae_config",
    "EQTransformer": "eqtransformer_config",
}

# Models accepted by the ``--models`` option; also used as its default value.
_AVAILABLE_MODELS = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]


def load_sweep_config(model_name, args):
    """Load the sweep configuration for ``model_name``.

    A path given on the command line for this model takes precedence;
    otherwise the entry from ``config_loader.sweep_files`` is used.

    :param model_name: name of the model whose sweep config is wanted
    :param args: parsed command-line arguments carrying the ``*_config``
        attributes
    :return: whatever ``util.load_sweep_config`` returns for the chosen file
    """
    override_attr = _SWEEP_OVERRIDE_ATTRS.get(model_name)
    # The attribute is only read for the matching model, so ``args`` needs
    # no more attributes than the model being requested.
    sweep_fname = getattr(args, override_attr) if override_attr is not None else None

    if sweep_fname is None:
        # use the default sweep config for the model
        sweep_fname = config_loader.sweep_files[model_name]

    logger.info(f"Loading sweep config: {sweep_fname}")
    return util.load_sweep_config(sweep_fname)


def validate_pipeline_input(args):
    """Validate the sweep config of every selected model, then the dataset.

    :param args: parsed command-line arguments; ``args.models`` lists the
        models to check
    """
    # validate input parameters
    for model_name in args.models:
        input_validate.validate_sweep_config(
            load_sweep_config(model_name, args), model_name
        )

    # validate dataset
    data.validate_custom_dataset(config_loader.data_path)


def find_the_best_params(sweep_config):
    """Start a hyperparameter sweep and block until all of its runs finish.

    :param sweep_config: sweep configuration; its
        ``['parameters']['model_name']`` entry is used for logging
    :return: the id of the finished sweep
    """
    model_name = sweep_config['parameters']['model_name']
    logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")

    sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)

    # Poll the runner, sleeping between checks, until every run is done.
    while not sweep_runner.all_runs_finished():
        logger.info("Waiting for sweep runs to finish...")
        time.sleep(30)

    logger.info(f"Finished the sweep: {sweep_runner.sweep_id}")
    return sweep_runner.sweep_id


def generate_predictions(sweep_id, model_name):
    """Run the evaluation script on the dev and test sets for one model.

    :param sweep_id: id of the sweep, forwarded to ``eval.main``
    :param model_name: model name; combined with the dataset name to form
        the experiment name passed as ``weights``
    """
    experiment_name = f"{config_loader.dataset_name}_{model_name}"

    eval.main(
        weights=experiment_name,
        targets=config_loader.targets_path,
        sets='dev,test',
        batchsize=128,
        num_workers=4,
        # sampling_rate=sampling_rate,
        sweep_id=sweep_id,
    )


def _build_arg_parser():
    """Create the command-line parser of the pipeline."""
    parser = argparse.ArgumentParser()

    # Optional string options: one sweep config per model, plus the dataset.
    for option in (
        "--phasenet_config",
        "--gpd_config",
        "--basic_phase_ae_config",
        "--eqtransformer_config",
        "--dataset",
    ):
        parser.add_argument(option, type=str, required=False)

    parser.add_argument("--models", nargs='*', required=False,
                        choices=_AVAILABLE_MODELS,
                        default=_AVAILABLE_MODELS,
                        help="Models to train and evaluate (default: all)")
    parser.add_argument("--collect_results", action="store_true",
                        help="Collect and log results without training")
    return parser


def _train_and_predict(args):
    """Validate inputs, generate labels, then sweep and evaluate each model."""
    if args.dataset is not None:
        util.set_dataset(args.dataset)
        # Reload so that the config_loader attributes used below are
        # re-evaluated after the dataset switch.
        importlib.reload(config_loader)

    validate_pipeline_input(args)

    logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")

    # generate labels
    logger.info("Started generating labels for the dataset.")
    generate_eval_targets.main(
        config_loader.data_path,
        config_loader.targets_path,
        "2,3",
        config_loader.sampling_rate,
        None,
    )

    # find the best hyperparams for the models
    logger.info("Started training the models.")
    for model_name in args.models:
        sweep_id = find_the_best_params(load_sweep_config(model_name, args))
        generate_predictions(sweep_id, model_name)


def main():
    """Parse the command line and run the pipeline.

    Unless ``--collect_results`` is given, the training stage runs first.
    Results are then collected into ``pred/results.csv`` and the metrics
    are logged.
    """
    args = _build_arg_parser().parse_args()

    # NOTE(review): the block nesting was reconstructed from a source whose
    # whitespace was collapsed; confirm that dataset selection and
    # validation belong to the training branch only.
    if not args.collect_results:
        _train_and_predict(args)

    # collect results
    logger.info("Collecting results.")
    results_path = "pred/results.csv"
    collect_results.traverse_path("pred", results_path)
    logger.info(f"Results saved in {results_path}")

    # log calculated metrics (MAE) on w&b
    logger.info("Logging MAE metrics on w&b.")
    util.log_metrics(results_path)

    logger.info("Pipeline finished")


if __name__ == "__main__":
    main()