# platform-demo-scripts/scripts/pipeline.py
#
# 93 lines
# 2.7 KiB
# Python

"""
-----------------
Copyright © 2023 ACK Cyfronet AGH, Poland.
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
-----------------
This script runs the pipeline for the training and evaluation of the models.
"""
import logging
import time
import argparse
import util
import generate_eval_targets
import hyperparameter_sweep
import eval
import collect_results
from config_loader import data_path, targets_path, sampling_rate, dataset_name, sweep_files
# Module-level logger shared by every step of the pipeline; records at
# INFO level and above are accepted by this logger.
logger = logging.getLogger("pipeline")
logger.setLevel(logging.INFO)
def load_sweep_config(model_name, args):
    """Load the sweep configuration to use for ``model_name``.

    A config file given on the command line (``args.phasenet_config`` for
    "PhaseNet", ``args.gpd_config`` for "GPD") takes precedence. When no
    override is set, or the model has no command-line option, the file
    registered for the model in ``sweep_files`` is used.

    Args:
        model_name: Name of the model; used as the key into ``sweep_files``.
        args: Parsed command-line arguments carrying the optional overrides.

    Returns:
        Whatever ``util.load_sweep_config`` returns for the chosen file.
    """
    # Maps a model name to the CLI attribute that may override its config.
    override_attrs = {"PhaseNet": "phasenet_config", "GPD": "gpd_config"}

    attr_name = override_attrs.get(model_name)
    # Only touch the attribute that belongs to the requested model.
    sweep_fname = getattr(args, attr_name) if attr_name is not None else None

    if sweep_fname is None:
        # No override available: fall back to the model's default config.
        sweep_fname = sweep_files[model_name]

    logger.info(f"Loading sweep config: {sweep_fname}")
    return util.load_sweep_config(sweep_fname)
def find_the_best_params(model_name, args):
    """Start a hyperparameter sweep for ``model_name`` and block until it ends.

    The sweep config is resolved with ``load_sweep_config`` and handed to
    ``hyperparameter_sweep.start_sweep``. The returned runner is then polled
    every 30 seconds until ``all_runs_finished()`` reports completion.

    Args:
        model_name: Name of the model whose hyperparameters are searched.
        args: Parsed command-line arguments (may carry config overrides).

    Returns:
        The ``sweep_id`` attribute of the sweep runner.
    """
    logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")

    runner = hyperparameter_sweep.start_sweep(load_sweep_config(model_name, args))

    # Poll the runner; sleep between checks so the wait is not a busy loop.
    while not runner.all_runs_finished():
        logger.info("Waiting for sweep runs to finish...")
        time.sleep(30)

    logger.info(f"Finished the sweep: {runner.sweep_id}")
    return runner.sweep_id
def generate_predictions(sweep_id, model_name):
    """Call ``eval.main`` for the experiment belonging to ``model_name``.

    The experiment name passed as ``weights`` is
    ``"<dataset_name>_<model_name>"``; targets are read from
    ``targets_path`` and the ``dev`` and ``test`` sets are requested.

    Args:
        sweep_id: Identifier of the finished sweep, forwarded to ``eval.main``.
        model_name: Name of the model, used to build the experiment name.
    """
    eval_options = {
        "weights": f"{dataset_name}_{model_name}",
        "targets": targets_path,
        "sets": "dev,test",
        "batchsize": 128,
        "num_workers": 4,
        # NOTE: sampling_rate is not forwarded to eval.main here.
        "sweep_id": sweep_id,
    }
    eval.main(**eval_options)
def main():
    """Run the whole pipeline from the command line.

    Steps, in order:
      1. Parse the optional ``--phasenet_config`` / ``--gpd_config`` options.
      2. Generate the evaluation targets via ``generate_eval_targets.main``.
      3. For "GPD" and then "PhaseNet": run the hyperparameter sweep and,
         once it has finished, call ``generate_predictions`` for that sweep.
      4. Call ``collect_results.traverse_path`` on the ``pred`` directory
         with ``pred/results.csv`` (presumably the output file).
    """
    parser = argparse.ArgumentParser()
    # Both options are optional paths to model-specific sweep configs.
    for option in ("--phasenet_config", "--gpd_config"):
        parser.add_argument(option, type=str, required=False)
    args = parser.parse_args()

    # Generate labels.
    generate_eval_targets.main(data_path, targets_path, "2,3", sampling_rate, None)

    # Find the best hyperparameters for each model, then run evaluation.
    for model_name in ("GPD", "PhaseNet"):
        generate_predictions(find_the_best_params(model_name, args), model_name)

    # Collect results.
    collect_results.traverse_path("pred", "pred/results.csv")
# Run the pipeline only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()