# platform-demo-scripts/scripts/pipeline.py

"""
-----------------
Copyright © 2023 ACK Cyfronet AGH, Poland.
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
-----------------
This script runs the pipeline for the training and evaluation of the models.
"""
import logging
import time
import argparse
import util
import generate_eval_targets
import hyperparameter_sweep
import eval
import collect_results
import importlib
import config_loader
# Emit INFO and above from every logger in the process (sweep/eval helpers
# log through their own module loggers, all rooted here).
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('pipeline')
def load_sweep_config(model_name, args):
    """Load the W&B sweep configuration for *model_name*.

    A per-model CLI override (e.g. ``--phasenet_config``) wins when supplied;
    otherwise the default sweep file registered in ``config_loader.sweep_files``
    is used.
    """
    # Map each supported model to the CLI attribute that may override its
    # default sweep file. getattr is deferred so only the matched model's
    # attribute is ever read.
    override_attr = {
        "PhaseNet": "phasenet_config",
        "GPD": "gpd_config",
        "BasicPhaseAE": "basic_phase_ae_config",
        "EQTransformer": "eqtransformer_config",
    }.get(model_name)
    sweep_fname = getattr(args, override_attr) if override_attr else None
    if sweep_fname is None:
        # use the default sweep config for the model
        sweep_fname = config_loader.sweep_files[model_name]
    logger.info(f"Loading sweep config: {sweep_fname}")
    return util.load_sweep_config(sweep_fname)
def find_the_best_params(model_name, args):
    """Run a hyperparameter sweep for *model_name* and block until it finishes.

    Returns the sweep id of the completed sweep so the caller can evaluate
    its best run.
    """
    logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
    sweep_config = load_sweep_config(model_name, args)
    sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
    # Poll until every run belonging to the sweep has completed,
    # sleeping between checks to avoid hammering the backend.
    while not sweep_runner.all_runs_finished():
        logger.info("Waiting for sweep runs to finish...")
        time.sleep(30)
    logger.info(f"Finished the sweep: {sweep_runner.sweep_id}")
    return sweep_runner.sweep_id
def generate_predictions(sweep_id, model_name):
    """Generate dev/test predictions for the best checkpoint of a sweep.

    The experiment name follows the ``<dataset>_<model>`` convention used by
    the training side, so ``eval`` can locate the trained weights.
    """
    experiment_name = f"{config_loader.dataset_name}_{model_name}"
    eval.main(
        weights=experiment_name,
        targets=config_loader.targets_path,
        sets='dev,test',
        batchsize=128,
        num_workers=4,
        sweep_id=sweep_id,
    )
def main():
    """CLI entry point: optionally train/evaluate all requested models, then
    collect prediction results and log MAE metrics to W&B.

    With ``--collect_results`` the training/evaluation phase is skipped and
    only the result collection + metric logging runs.
    """
    available_models = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]
    parser = argparse.ArgumentParser()
    parser.add_argument("--phasenet_config", type=str, required=False)
    parser.add_argument("--gpd_config", type=str, required=False)
    parser.add_argument("--basic_phase_ae_config", type=str, required=False)
    parser.add_argument("--eqtransformer_config", type=str, required=False)
    parser.add_argument("--dataset", type=str, required=False)
    parser.add_argument("--models", nargs='*', required=False, choices=available_models, default=available_models,
                        help="Models to train and evaluate (default: all)")
    parser.add_argument("--collect_results", action="store_true", help="Collect and log results without training")
    args = parser.parse_args()

    if not args.collect_results:
        if args.dataset is not None:
            # Switch the active dataset, then reload config_loader so its
            # module-level constants (paths, sampling rate) reflect it.
            util.set_dataset(args.dataset)
            importlib.reload(config_loader)
        logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")

        # generate labels
        logger.info("Started generating labels for the dataset.")
        generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3",
                                   config_loader.sampling_rate, None)

        # find the best hyperparams for the models
        logger.info("Started training the models.")
        for model_name in args.models:
            best_sweep_id = find_the_best_params(model_name, args)
            generate_predictions(best_sweep_id, model_name)

    # collect results (runs regardless of --collect_results)
    logger.info("Collecting results.")
    results_path = "pred/results.csv"
    collect_results.traverse_path("pred", results_path)
    logger.info(f"Results saved in {results_path}")

    # log calculated metrics (MAE) on w&b
    logger.info("Logging MAE metrics on w&b.")
    util.log_metrics(results_path)
    logger.info("Pipeline finished")


if __name__ == "__main__":
    main()