2023-08-29 09:59:31 +02:00
|
|
|
"""
|
|
|
|
-----------------
|
|
|
|
Copyright © 2023 ACK Cyfronet AGH, Poland.
|
|
|
|
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
|
|
|
|
-----------------
|
|
|
|
|
|
|
|
This script runs the pipeline for the training and evaluation of the models.
|
|
|
|
"""
|
|
|
|
import logging
|
|
|
|
import time
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
import util
|
|
|
|
import generate_eval_targets
|
|
|
|
import hyperparameter_sweep
|
|
|
|
import eval
|
|
|
|
import collect_results
|
2023-10-12 13:25:34 +02:00
|
|
|
import importlib
|
|
|
|
import config_loader
|
2023-08-29 09:59:31 +02:00
|
|
|
|
2023-09-26 10:50:46 +02:00
|
|
|
# Let INFO-level records through the root logger.
logging.getLogger().setLevel(logging.INFO)
|
2023-08-29 09:59:31 +02:00
|
|
|
# Logger shared by every step of this pipeline script.
logger = logging.getLogger("pipeline")
|
|
|
|
|
|
|
|
|
|
|
|
def load_sweep_config(model_name, args):
    """Load the hyperparameter sweep configuration for ``model_name``.

    A config file given on the command line for this model takes precedence;
    when none was given (or the model has no dedicated option) the default
    entry from ``config_loader.sweep_files`` is used.

    :param model_name: name of the model, e.g. "PhaseNet", "GPD",
        "BasicPhaseAE" or "EQTransformer".
    :param args: parsed command-line arguments carrying the optional
        ``*_config`` attributes.
    :return: the result of ``util.load_sweep_config`` for the chosen file.
    """
    # Command-line attribute that can override the default config, per model.
    override_attrs = {
        "PhaseNet": "phasenet_config",
        "GPD": "gpd_config",
        "BasicPhaseAE": "basic_phase_ae_config",
        "EQTransformer": "eqtransformer_config",
    }

    attr_name = override_attrs.get(model_name)
    # Only touch ``args`` for the option that belongs to this model.
    sweep_fname = getattr(args, attr_name) if attr_name is not None else None

    if sweep_fname is None:
        # use the default sweep config for the model
        sweep_fname = config_loader.sweep_files[model_name]

    logger.info(f"Loading sweep config: {sweep_fname}")
    return util.load_sweep_config(sweep_fname)
|
|
|
|
|
|
|
|
|
|
|
|
def find_the_best_params(model_name, args):
    """Run a hyperparameter sweep for ``model_name`` and wait for it to end.

    Starts the sweep through ``hyperparameter_sweep.start_sweep`` and then
    polls the returned runner every 30 seconds until it reports that all runs
    have finished.

    :param model_name: name of the model to tune.
    :param args: parsed command-line arguments, forwarded to
        ``load_sweep_config``.
    :return: the ``sweep_id`` of the finished sweep.
    """
    logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")

    config = load_sweep_config(model_name, args)
    runner = hyperparameter_sweep.start_sweep(config)

    # Block until every run of the sweep has finished.
    while not runner.all_runs_finished():
        logger.info("Waiting for sweep runs to finish...")
        # Pause for 30 seconds before asking the runner again.
        time.sleep(30)

    logger.info(f"Finished the sweep: {runner.sweep_id}")
    return runner.sweep_id
|
|
|
|
|
|
|
|
|
|
|
|
def generate_predictions(sweep_id, model_name):
    """Run the evaluation step for the given sweep and model.

    Calls ``eval.main`` with the experiment name
    ``"<dataset_name>_<model_name>"`` as weights, the targets path from
    ``config_loader`` and the "dev" and "test" sets.

    :param sweep_id: identifier of the finished sweep, forwarded to
        ``eval.main``.
    :param model_name: name of the model the predictions are generated for.
    """
    eval_kwargs = {
        "weights": f"{config_loader.dataset_name}_{model_name}",
        "targets": config_loader.targets_path,
        "sets": 'dev,test',
        "batchsize": 128,
        "num_workers": 4,
        # sampling_rate is not passed to eval.main here.
        "sweep_id": sweep_id,
    }
    eval.main(**eval_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run the whole pipeline: labels, sweeps, predictions and results.

    Steps, in order:

    1. Parse the command line (optional per-model sweep configs and an
       optional dataset name).
    2. Generate the evaluation targets for the dataset.
    3. For each model, run the hyperparameter sweep and generate predictions.
    4. Collect the results into ``pred/results.csv`` and log the metrics.
    """
    parser = argparse.ArgumentParser()
    # All options are optional strings; the order is kept for the help output.
    for option in (
        "--phasenet_config",
        "--gpd_config",
        "--basic_phase_ae_config",
        "--eqtransformer_config",
        "--dataset",
    ):
        parser.add_argument(option, type=str, required=False)
    args = parser.parse_args()

    if args.dataset is not None:
        util.set_dataset(args.dataset)
        # Presumably config_loader reads the dataset at import time, hence the
        # reload after switching it -- TODO confirm against config_loader.
        importlib.reload(config_loader)

    logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")

    # generate labels
    logger.info("Started generating labels for the dataset.")
    generate_eval_targets.main(
        config_loader.data_path,
        config_loader.targets_path,
        "2,3",
        config_loader.sampling_rate,
        None,
    )

    # find the best hyperparams for the models
    logger.info("Started training the models.")
    for model_name in ("GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"):
        # NOTE(review): ``break`` ends the whole loop; it only behaves like
        # "skip this model" because EQTransformer is the last entry above.
        if model_name == "EQTransformer" and config_loader.dataset_name == "lumineos":
            break
        sweep_id = find_the_best_params(model_name, args)
        generate_predictions(sweep_id, model_name)

    # collect results
    logger.info("Collecting results.")
    results_path = "pred/results.csv"
    collect_results.traverse_path("pred", results_path)
    logger.info(f"Results saved in {results_path}")

    # log calculated metrics (MAE) on w&b
    logger.info("Logging MAE metrics on w&b.")
    util.log_metrics(results_path)

    logger.info("Pipeline finished")
|
|
|
|
|
2023-08-29 09:59:31 +02:00
|
|
|
|
|
|
|
# Run the pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|