From 318a344c1579de1beb35598b08338569ac63027a Mon Sep 17 00:00:00 2001 From: Krystyna Milian Date: Mon, 26 Feb 2024 23:43:31 +0100 Subject: [PATCH] extended pipeline arguments --- scripts/pipeline.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/scripts/pipeline.py b/scripts/pipeline.py index 6b637a5..34e7492 100644 --- a/scripts/pipeline.py +++ b/scripts/pipeline.py @@ -79,26 +79,30 @@ def main(): parser.add_argument("--basic_phase_ae_config", type=str, required=False) parser.add_argument("--eqtransformer_config", type=str, required=False) parser.add_argument("--dataset", type=str, required=False) + available_models = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"] + parser.add_argument("--models", nargs='*', required=False, choices=available_models, default=available_models, + help="Models to train and evaluate (default: all)") + parser.add_argument("--collect_results", action="store_true", help="Collect and log results without training") args = parser.parse_args() - if args.dataset is not None: - util.set_dataset(args.dataset) - importlib.reload(config_loader) + if not args.collect_results: - logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.") + if args.dataset is not None: + util.set_dataset(args.dataset) + importlib.reload(config_loader) - # generate labels - logger.info("Started generating labels for the dataset.") - generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3", config_loader.sampling_rate, - None) + logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.") - # find the best hyperparams for the models - logger.info("Started training the models.") - for model_name in ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]: - if config_loader.dataset_name == "lumineos" and model_name == "EQTransformer": - break - sweep_id = find_the_best_params(model_name, args) - generate_predictions(sweep_id, model_name) + # generate labels + logger.info("Started generating labels for the dataset.") + generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3", config_loader.sampling_rate, + None) + + # find the best hyperparams for the models + logger.info("Started training the models.") + for model_name in args.models: + sweep_id = find_the_best_params(model_name, args) + generate_predictions(sweep_id, model_name) # collect results logger.info("Collecting results.")