extended pipeline arguments

This commit is contained in:
Krystyna Milian 2024-02-26 23:43:31 +01:00
parent bb2e136d42
commit 318a344c15

View File

@ -79,8 +79,14 @@ def main():
parser.add_argument("--basic_phase_ae_config", type=str, required=False)
parser.add_argument("--eqtransformer_config", type=str, required=False)
parser.add_argument("--dataset", type=str, required=False)
available_models = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]
parser.add_argument("--models", nargs='*', required=False, choices=available_models, default=available_models,
help="Models to train and evaluate (default: all)")
parser.add_argument("--collect_results", action="store_true", help="Collect and log results without training")
args = parser.parse_args()
if not args.collect_results:
if args.dataset is not None:
util.set_dataset(args.dataset)
importlib.reload(config_loader)
@ -94,9 +100,7 @@ def main():
# find the best hyperparams for the models
logger.info("Started training the models.")
for model_name in ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]:
if config_loader.dataset_name == "lumineos" and model_name == "EQTransformer":
break
for model_name in args.models:
sweep_id = find_the_best_params(model_name, args)
generate_predictions(sweep_id, model_name)