extended pipeline arguments
This commit is contained in:
parent
bb2e136d42
commit
318a344c15
@ -79,26 +79,30 @@ def main():
|
|||||||
parser.add_argument("--basic_phase_ae_config", type=str, required=False)
|
parser.add_argument("--basic_phase_ae_config", type=str, required=False)
|
||||||
parser.add_argument("--eqtransformer_config", type=str, required=False)
|
parser.add_argument("--eqtransformer_config", type=str, required=False)
|
||||||
parser.add_argument("--dataset", type=str, required=False)
|
parser.add_argument("--dataset", type=str, required=False)
|
||||||
|
available_models = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]
|
||||||
|
parser.add_argument("--models", nargs='*', required=False, choices=available_models, default=available_models,
|
||||||
|
help="Models to train and evaluate (default: all)")
|
||||||
|
parser.add_argument("--collect_results", action="store_true", help="Collect and log results without training")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.dataset is not None:
|
if not args.collect_results:
|
||||||
util.set_dataset(args.dataset)
|
|
||||||
importlib.reload(config_loader)
|
|
||||||
|
|
||||||
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
|
if args.dataset is not None:
|
||||||
|
util.set_dataset(args.dataset)
|
||||||
|
importlib.reload(config_loader)
|
||||||
|
|
||||||
# generate labels
|
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
|
||||||
logger.info("Started generating labels for the dataset.")
|
|
||||||
generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3", config_loader.sampling_rate,
|
|
||||||
None)
|
|
||||||
|
|
||||||
# find the best hyperparams for the models
|
# generate labels
|
||||||
logger.info("Started training the models.")
|
logger.info("Started generating labels for the dataset.")
|
||||||
for model_name in ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]:
|
generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3", config_loader.sampling_rate,
|
||||||
if config_loader.dataset_name == "lumineos" and model_name == "EQTransformer":
|
None)
|
||||||
break
|
|
||||||
sweep_id = find_the_best_params(model_name, args)
|
# find the best hyperparams for the models
|
||||||
generate_predictions(sweep_id, model_name)
|
logger.info("Started training the models.")
|
||||||
|
for model_name in args.models:
|
||||||
|
sweep_id = find_the_best_params(model_name, args)
|
||||||
|
generate_predictions(sweep_id, model_name)
|
||||||
|
|
||||||
# collect results
|
# collect results
|
||||||
logger.info("Collecting results.")
|
logger.info("Collecting results.")
|
||||||
|
Loading…
Reference in New Issue
Block a user